[ROCm][Qwen3-32B] Fix AITER MHA accuracy issue caused by #25763 #28670
```diff
@@ -729,7 +729,7 @@ def forward(
     cu_seqlens_k=attn_metadata.prefill_metadata.query_start_loc,
     max_seqlen_q=attn_metadata.prefill_metadata.max_query_len,
     max_seqlen_k=attn_metadata.prefill_metadata.max_seq_len,
-    min_seqlen_q=attn_metadata.prefill_metadata.min_query_len,
+    min_seqlen_q=1,
     dropout_p=0.0,
     softmax_scale=self.scale,
     causal=True,
```
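For context on what the removed `min_query_len` metadata represents, here is a minimal sketch (not vLLM code; the helper name is hypothetical) of deriving the minimum query length from a cumulative `query_start_loc` array, following the standard `cu_seqlens` convention used by varlen flash-attention-style kernels:

```python
def min_query_len(query_start_loc):
    """query_start_loc[i+1] - query_start_loc[i] is the query length of
    request i (the cu_seqlens convention used by varlen attention kernels)."""
    return min(
        query_start_loc[i + 1] - query_start_loc[i]
        for i in range(len(query_start_loc) - 1)
    )

# Three prefill requests with query lengths 4, 7, and 2:
print(min_query_len([0, 4, 11, 13]))  # -> 2
```

The fix replaces this dynamically computed minimum with a constant `1`, a safe lower bound given the discussion below that the prefill/extend paths never see single-token queries.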
```diff
@@ -759,7 +759,7 @@ def forward(
     cu_seqlens_q=attn_metadata.extend_metadata.query_start_loc,
     max_seqlen_q=attn_metadata.extend_metadata.max_query_len,
     max_seqlen_k=attn_metadata.extend_metadata.max_seq_len,
-    min_seqlen_q=attn_metadata.extend_metadata.min_query_len,
+    min_seqlen_q=1,
     block_table=attn_metadata.block_table[
         num_decodes : num_decodes + num_extends
     ],
```

Contributor

Similar to the prefill case, this change correctly fixes the accuracy issue for the extend phase. Adding a comment here as well would be beneficial for future maintainability and to prevent regressions of this critical bug.

Collaborator

Could it be an issue with the metadata preparation steps?
This change correctly fixes the accuracy issue by hardcoding `min_seqlen_q=1`. To prevent this from being changed back to a dynamic value in the future, which could reintroduce this critical bug, it would be beneficial to add a comment explaining why this value is hardcoded. This will improve the long-term maintainability of the code.
LGTM, just a minor suggestion: maybe `reorder_batch_threshold` would be more appropriate here?
@sammysun0711 what do you think?
@tjtanaa, @ganyi1996ppo, thanks for your feedback.

For the metadata preparation process, I checked the metadata for both prefill and extend, and it seems reasonable:

vllm/vllm/v1/attention/backends/rocm_aiter_fa.py (line 341 in 622e610)
vllm/vllm/v1/attention/backends/rocm_aiter_fa.py (line 409 in 72f5119)

As for the `reorder_batch_threshold` suggested by @ganyi1996ppo: it is an `AiterFlashAttentionMetadata` variable with a default int value of `1`, used for `split_decodes_prefills_and_extends`:

vllm/vllm/v1/attention/backends/rocm_aiter_fa.py (line 304 in 72f5119)

It is not very intuitive to pass it as `min_seqlen_q`; could you please elaborate more on this? Thanks!
This is used to reorder the requests: a request with `query_len <= reorder_batch_threshold` will be deemed a decode and passed to paged attention for processing. So I think `reorder_batch_threshold` would be more intuitive to pass to the metadata than a hardcoded `1`.
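As an illustrative sketch (not the actual `split_decodes_prefills_and_extends` implementation, whose real counterpart also reorders the batch), the threshold-based routing described above could look like:

```python
def split_by_threshold(query_lens, seq_lens, reorder_batch_threshold=1):
    """Classify each request: decode if query_len <= threshold, extend if it
    has prior cached context (seq_len > query_len), otherwise full prefill.
    Hypothetical helper for illustration only."""
    num_decodes = num_extends = num_prefills = 0
    for q, s in zip(query_lens, seq_lens):
        if q <= reorder_batch_threshold:
            num_decodes += 1   # routed to the paged-attention decode path
        elif s > q:
            num_extends += 1   # extend: new tokens appended to cached context
        else:
            num_prefills += 1  # pure prefill: the whole sequence is new
    return num_decodes, num_extends, num_prefills

# With threshold 1, every request reaching the MHA prefill/extend path has
# query_len >= 2, which is why min_seqlen_q=1 is a safe lower bound there.
print(split_by_threshold([1, 1, 5, 8], [10, 20, 9, 8]))  # -> (2, 1, 1)
```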
And of course, `1` is correct enough for the MHA backend, since the prefill and extend paths will receive no request with `query_len == 1`; so passing `1` directly is also LGTM.
Per offline sync with @ganyi1996ppo, we agree that:

- We can use `min_seqlen_q=1` in this PR to fix the accuracy issue, since the AITER MHA backend currently receives no request with `query_len == 1` on either the prefill or the extend path. This usage is also consistent with the code before the #25763 refactor.
- `reorder_batch_threshold` aims to address case 3 (`UNIFORM`: uniform multi-token queries) mentioned in [ROCm][Perf] New design on ROCm AITER MHA backend Implementation #25763 (comment), which is not yet supported by the AITER attention backend. Its value may also change at runtime in the MTP case, so it is better to leave it for future work.

@tjtanaa, what do you think?
@sammysun0711 Sure. Let's go with this approach. Thank you for the updates.