[ROCm][Qwen3-32B] Fix AITER MHA accuracy issue caused by #25763 #28670
Conversation
Code Review
This pull request fixes a critical accuracy bug in the AITER MHA backend on ROCm, introduced in a previous refactoring. The issue was caused by passing a potentially zero value for `min_seqlen_q` to the attention kernel, which expects a value of at least 1. The fix hardcodes this parameter to 1 for both the prefill and extend phases, restoring correct behavior and aligning accuracy with other backends. My review includes suggestions to add comments explaining the hardcoded value, which will help prevent future regressions and improve maintainability.
```diff
                 max_seqlen_q=attn_metadata.prefill_metadata.max_query_len,
                 max_seqlen_k=attn_metadata.prefill_metadata.max_seq_len,
-                min_seqlen_q=attn_metadata.prefill_metadata.min_query_len,
+                min_seqlen_q=1,
```
This change correctly fixes the accuracy issue by hardcoding `min_seqlen_q=1`. To prevent this from being changed back to a dynamic value in the future, which could reintroduce this critical bug, it would be beneficial to add a comment explaining why this value is hardcoded. This will improve the long-term maintainability of the code.
```diff
-                min_seqlen_q=1,
+                # AITER MHA kernel requires min_seqlen_q >= 1. Using
+                # `attn_metadata.prefill_metadata.min_query_len` can be 0,
+                # causing accuracy issues. See #28598.
+                min_seqlen_q=1,
```
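Beyond a comment, the constraint could also be enforced in code. A minimal, hypothetical sketch (plain Python, not vLLM's actual helper) of why a dynamically computed minimum needs clamping for a kernel that requires `min_seqlen_q >= 1`:

```python
def safe_min_seqlen_q(query_lens):
    """Clamp a dynamically computed minimum query length to the
    kernel's documented lower bound of 1.

    Hypothetical helper for illustration only: a batch's query_lens
    may contain 0 (e.g. an empty slot), so taking the raw minimum
    can yield 0 and trip the kernel's pre-condition check.
    """
    return max(1, min(query_lens))

# A batch with one empty slot: the raw minimum is 0, exactly the
# kind of value that makes the kernel bail out without computing.
lens = [0, 4, 7]
print(min(lens))                # 0 -- unsafe to pass as min_seqlen_q
print(safe_min_seqlen_q(lens))  # 1 -- satisfies the kernel's bound
```

Clamping keeps the call site safe even if the metadata fields change again later.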
LGTM, just a minor suggestion, maybe reorder_batch_threshold would be more appropriate here?
@sammysun0711 what do you think?
@tjtanaa, @ganyi1996ppo, thanks for your feedback.
I checked the metadata preparation for the prefill and extend phases, and it seems reasonable:

```python
min_query_len=query_lens_for_prefill.min().item(),
min_query_len=query_lens_for_extend.min().item(),
```

As for `reorder_batch_threshold` suggested by @ganyi1996ppo: it is an `AiterFlashAttentionMetadata` variable with a default int value of 1, used by `split_decodes_prefills_and_extends`:

```python
split_ret = split_decodes_prefills_and_extends(
```

It is not very intuitive to pass it as `min_seqlen_q`; could you please elaborate? Thanks!
This is used to reorder the requests: any request with `query_len <= reorder_batch_threshold` is deemed a decode and passed to paged attention. So I think passing `reorder_batch_threshold` to the metadata would be more intuitive than a literal 1.
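The reordering rule described here can be sketched in a few lines (hypothetical standalone code, not the actual `split_decodes_prefills_and_extends` implementation):

```python
def split_by_threshold(query_lens, reorder_batch_threshold=1):
    """Split request indices by query length: requests with
    query_len <= threshold are treated as decodes (handled by paged
    attention); the rest go to the prefill/extend path."""
    decodes = [i for i, q in enumerate(query_lens)
               if q <= reorder_batch_threshold]
    others = [i for i, q in enumerate(query_lens)
              if q > reorder_batch_threshold]
    return decodes, others

decodes, others = split_by_threshold([1, 1, 5, 9])
print(decodes)  # [0, 1] -- single-token decode requests
print(others)   # [2, 3] -- multi-token prefill/extend requests
# Every request on the prefill/extend path has
# query_len >= reorder_batch_threshold + 1, which is why a constant
# min_seqlen_q=1 is a safe lower bound on that path.
```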
And of course 1 is correct for the MHA backend, since the prefill and extend paths will receive no request with `query_len == 1`, so passing 1 directly is also LGTM.
Per offline sync with @ganyi1996ppo, we agree that:
- We can use `min_seqlen_q=1` in this PR to fix the accuracy issue, since currently AITER MHA will not receive a request with `query_len == 1` on either the prefill or extend path. This usage is also consistent with the code before the #25763 refactor.
- `reorder_batch_threshold` aims to address case 3 (`UNIFORM: uniform multi-token queries`) mentioned in [ROCm][Perf] New design on ROCm AITER MHA backend Implementation #25763 (comment) (not supported by the AITER attention backend yet), and its value may change at runtime in the MTP case, so it is better left as future work.
@tjtanaa, what do you think?
@sammysun0711 Sure. Let's go with this approach. Thank you for the updates.
```diff
                 max_seqlen_q=attn_metadata.extend_metadata.max_query_len,
                 max_seqlen_k=attn_metadata.extend_metadata.max_seq_len,
-                min_seqlen_q=attn_metadata.extend_metadata.min_query_len,
+                min_seqlen_q=1,
```
Similar to the prefill case, this change correctly fixes the accuracy issue for the extend phase. Adding a comment here as well would be beneficial for future maintainability and to prevent regressions of this critical bug.
```diff
-                min_seqlen_q=1,
+                # AITER MHA kernel requires min_seqlen_q >= 1. Using
+                # `attn_metadata.extend_metadata.min_query_len` can be 0,
+                # causing accuracy issues. See #28598.
+                min_seqlen_q=1,
```
Could it be an issue with the metadata preparation steps?
@ganyi1996ppo @wuhuikx Could you take a look at this? Is this compatible with the new AITER version as well? Thank you.

+1, we've been seeing this issue and this PR fixes the accuracy regression we were observing.

LGTM
It seems neither the AITER MHA function API nor the underlying CK pre-condition check changed with the new AITER version.

@sammysun0711 @tjtanaa could you rebase again?
Purpose
This PR aims to fix the Qwen3-32B accuracy issue with AITER MHA reported in #28598 by passing a correct `min_seqlen_q` value to `aiter.flash_attn_varlen_func` in the pure prefill and extend phases.

After further investigation, I found that #25763 introduced this issue during the refactor of the AITER MHA backend with the default AITER version 9716b1b8 (`docker/Dockerfile.rocm_base`, line 10 at 3eb0c26).

Before the refactor, `min_seqlen_q=1` was passed to `aiter.flash_attn_varlen_func` in the prefill phase (`vllm/v1/attention/backends/rocm_aiter_fa.py`, lines 184 to 198 at 2918c1b).

#25763 split attention into pure prefill/extend/decode phases, with `min_seqlen_q` set from `attn_metadata.prefill_metadata.min_query_len` and `attn_metadata.extend_metadata.min_query_len` in the following lines: `vllm/v1/attention/backends/rocm_aiter_fa.py`, line 756 and line 786 at dc93717.
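Those metadata fields feed `min_seqlen_q` directly. A hypothetical sketch (plain Python, not the CK kernel) of how a guarded kernel silently skips its work when such a pre-condition fails:

```python
def guarded_attention_sketch(q_rows, min_seqlen_q):
    """Model of a kernel with a pre-condition check: if the check
    fails, it returns early WITHOUT computing, so the output buffer
    is never written -- no error raised, just wrong results
    propagating downstream as an accuracy issue."""
    if min_seqlen_q < 1:      # stand-in for the CK pre-condition
        return None           # early return: silent accuracy loss
    # Stand-in for the real attention math.
    return [sum(row) for row in q_rows]

print(guarded_attention_sketch([[1.0, 2.0]], min_seqlen_q=0))  # None
print(guarded_attention_sketch([[1.0, 2.0]], min_seqlen_q=1))  # [3.0]
```

This is why the bug manifested as bad model output rather than a crash: the early return leaves the output unwritten instead of raising.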
If a wrong `min_seqlen_q` value is passed to `aiter.flash_attn_varlen_func`, it triggers the following pre-condition check in CK's `fmha_fwd_pagedkv_kernel`, which returns without executing the kernel and hence causes the accuracy issue: https://github.com/ROCm/composable_kernel/blob/9f33b7cfd3df3fcfd540f7633b0abd7019935761/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp#L1175-L1181

Test Plan
Accuracy test: lm-eval with gsm8k on MI308X
Base docker image: rocm/vllm-dev/nightly_main_20251113
lm_eval command
Test Result
AITER MHA before fix:
AITER MHA after fix:
Triton Unified Attention (default)