Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ def forward(
cu_seqlens_k=attn_metadata.prefill_metadata.query_start_loc,
max_seqlen_q=attn_metadata.prefill_metadata.max_query_len,
max_seqlen_k=attn_metadata.prefill_metadata.max_seq_len,
min_seqlen_q=attn_metadata.prefill_metadata.min_query_len,
min_seqlen_q=1,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This change correctly fixes the accuracy issue by hardcoding min_seqlen_q=1. To prevent this from being changed back to a dynamic value in the future, which could reintroduce this critical bug, it would be beneficial to add a comment explaining why this value is hardcoded. This will improve the long-term maintainability of the code.

Suggested change
min_seqlen_q=1,
# AITeR MHA kernel requires min_seqlen_q >= 1. Using
# `attn_metadata.prefill_metadata.min_query_len` can be 0,
# causing accuracy issues. See #28598.
min_seqlen_q=1,

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just a minor suggestion, maybe reorder_batch_threshold would be more appropriate here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sammysun0711 what do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tjtanaa, @ganyi1996ppo, thanks for your feedback.
For metadata preparation process, I checked metadata for prefill and extend, it seems reasonable:

For @ganyi1996ppo suggested reorder_batch_threshold, it is AiterFlashAttentionMetadata's variable with default int value 1 used for split_decodes_prefills_and_extends:

split_ret = split_decodes_prefills_and_extends(

It is not very intuitive to pass it min_seqlen_q, could you please elaborate more about this? Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is used to reorder the request, for the request that query_len <= reorder_batch_threshold will be deemed as decode and be passed to the paged attention to process. So I think reorder_batch_threshold can be more intuitive to passed to the metadata rather than 1.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And of course 1 is correct enough for the MHA backend, since prefill or extend path will receive no request with query_len == 1, so, pass 1 directly is also LGTM.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per offline sync with @ganyi1996ppo, we agree that:

  1. We can use min_seq_len=1 in this PR to fix accuracy issue, since currently Aiter MHA will not receive request with query_len == 1 for both prefill and extent path. And the usage is also consistent with code before PR25763 refactor.

  2. reorder_batch_threshold aim to address case3: UNIFORM: uniform multi-token queries mentioned in [ROCm][Perf] New design on ROCm AITER MHA backend Implementation #25763 (comment) (Not supported by Aiter attention backend yet), and its value may change during runtime for MTP case, so it will be better to leave it for the future work.

@tjtanaa, what do you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sammysun0711 Sure. Let's go with this approach. Thank you for the updates.

dropout_p=0.0,
softmax_scale=self.scale,
causal=True,
Expand Down Expand Up @@ -759,7 +759,7 @@ def forward(
cu_seqlens_q=attn_metadata.extend_metadata.query_start_loc,
max_seqlen_q=attn_metadata.extend_metadata.max_query_len,
max_seqlen_k=attn_metadata.extend_metadata.max_seq_len,
min_seqlen_q=attn_metadata.extend_metadata.min_query_len,
min_seqlen_q=1,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the prefill case, this change correctly fixes the accuracy issue for the extend phase. Adding a comment here as well would be beneficial for future maintainability and to prevent regressions of this critical bug.

Suggested change
min_seqlen_q=1,
# AITeR MHA kernel requires min_seqlen_q >= 1. Using
# `attn_metadata.extend_metadata.min_query_len` can be 0,
# causing accuracy issues. See #28598.
min_seqlen_q=1,

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be an issue with the metadata preparation steps?

block_table=attn_metadata.block_table[
num_decodes : num_decodes + num_extends
],
Expand Down