
Conversation

Contributor

@sammysun0711 sammysun0711 commented Nov 13, 2025

Purpose

This PR aims to fix the Qwen3-32B accuracy issue with AITER MHA reported in #28598 by passing the correct min_seqlen_q value to aiter.flash_attn_varlen_func in the pure prefill and extend phases.

After further investigation, I found that #25763 introduced this issue during the refactor of the AITER MHA backend with the default AITER version 9716b1b8:

ARG AITER_BRANCH="9716b1b8"

Before the refactor, min_seqlen_q=1 was passed to aiter.flash_attn_varlen_func in the prefill phase:

output = aiter.flash_attn_varlen_func(
    q=q,
    k=k,
    v=v,
    cu_seqlens_q=cu_seqlens_q,
    max_seqlen_q=max_seqlen_q,
    min_seqlen_q=1,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_k=max_seqlen_k,
    softmax_scale=softmax_scale,
    causal=True,
    alibi_slopes=alibi_slopes,
    window_size=window_size,
    out=out,
)

#25763 split the attention path into pure prefill/extend/decode phases, with min_seqlen_q set from attn_metadata.prefill_metadata.min_query_len and attn_metadata.extend_metadata.min_query_len in the following lines:

min_seqlen_q=attn_metadata.prefill_metadata.min_query_len,

min_seqlen_q=attn_metadata.extend_metadata.min_query_len,

If a wrong min_seqlen_q value is passed to aiter.flash_attn_varlen_func, it triggers the following pre-condition check in the CK fmha_fwd_pagedkv_kernel, which returns without executing the kernel and hence causes the accuracy issue: https://github.com/ROCm/composable_kernel/blob/9f33b7cfd3df3fcfd540f7633b0abd7019935761/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp#L1175-L1181
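
For reference, a minimal sketch of the corrected prefill-phase call after this fix (illustrative only; argument names follow the snippets above and the review diff below, and the real backend call carries additional context):

output = aiter.flash_attn_varlen_func(
    q=q,
    k=k,
    v=v,
    cu_seqlens_q=cu_seqlens_q,
    max_seqlen_q=attn_metadata.prefill_metadata.max_query_len,
    # The fix: pass 1 (as before the refactor) instead of
    # attn_metadata.prefill_metadata.min_query_len, whose value triggered the
    # CK pre-condition check linked above and made the kernel return without
    # executing.
    min_seqlen_q=1,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_k=attn_metadata.prefill_metadata.max_seq_len,
    softmax_scale=softmax_scale,
    causal=True,
    alibi_slopes=alibi_slopes,
    window_size=window_size,
    out=out,
)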

Test Plan

Accuracy test: lm-eval with gsm8k on MI308X
Base docker image: rocm/vllm-dev/nightly_main_20251113

#!/bin/bash
rm -rf /root/.cache/vllm/
echo "Qwen/Qwen3-32B"
export VLLM_RPC_TIMEOUT=1800000
export SAFETENSORS_FAST_GPU=1
export MODEL_PATH=Qwen/Qwen3-32B

VLLM_ROCM_USE_AITER=1 \
vllm serve $MODEL_PATH \
-tp 2 \
--trust-remote-code \
--disable-log-requests

lm_eval command

#!/bin/bash
lm_eval \
--model local-completions \
--tasks gsm8k \
--model_args model=Qwen/Qwen3-32B,base_url=http://127.0.0.1:8000/v1/completions \
--batch_size 100

Test Result

AITER MHA before fix:

Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr
gsm8k | 3       | flexible-extract | 5      | exact_match | 0.4769 | ± 0.0138
gsm8k | 3       | strict-match     | 5      | exact_match | 0.5428 | ± 0.0137

AITER MHA after fix:

Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr
gsm8k | 3       | flexible-extract | 5      | exact_match | 0.6293 | ± 0.0133
gsm8k | 3       | strict-match     | 5      | exact_match | 0.7210 | ± 0.0124

Triton Unified Attention (default)

Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr
gsm8k | 3       | flexible-extract | 5      | exact_match | 0.6300 | ± 0.0133
gsm8k | 3       | strict-match     | 5      | exact_match | 0.7392 | ± 0.0124

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added qwen Related to Qwen models rocm Related to AMD ROCm v1 labels Nov 13, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request fixes a critical accuracy bug in the AITER MHA backend on ROCm, which was introduced in a previous refactoring. The issue was caused by passing a potentially zero value for min_seqlen_q to the attention kernel, which expects a value of at least 1. The fix hardcodes this parameter to 1 for both the prefill and extend phases, restoring the correct behavior and aligning accuracy with other backends. My review includes suggestions to add comments to the code explaining this hardcoded value, which will help prevent future regressions and improve maintainability.

        max_seqlen_q=attn_metadata.prefill_metadata.max_query_len,
        max_seqlen_k=attn_metadata.prefill_metadata.max_seq_len,
-       min_seqlen_q=attn_metadata.prefill_metadata.min_query_len,
+       min_seqlen_q=1,
Contributor


high

This change correctly fixes the accuracy issue by hardcoding min_seqlen_q=1. To prevent this from being changed back to a dynamic value in the future, which could reintroduce this critical bug, it would be beneficial to add a comment explaining why this value is hardcoded. This will improve the long-term maintainability of the code.

Suggested change
min_seqlen_q=1,
# AITER MHA kernel requires min_seqlen_q >= 1. The value from
# `attn_metadata.prefill_metadata.min_query_len` can be 0,
# causing accuracy issues. See #28598.
min_seqlen_q=1,

Contributor


LGTM, just a minor suggestion, maybe reorder_batch_threshold would be more appropriate here?

Collaborator


@sammysun0711 what do you think?

Contributor Author


@tjtanaa, @ganyi1996ppo, thanks for your feedback.
Regarding the metadata preparation process, I checked the metadata for prefill and extend, and it seems reasonable.

Regarding the reorder_batch_threshold suggested by @ganyi1996ppo: it is an AiterFlashAttentionMetadata variable with a default int value of 1, used for split_decodes_prefills_and_extends:

split_ret = split_decodes_prefills_and_extends(

It is not very intuitive to pass it as min_seqlen_q; could you please elaborate on this? Thanks!

Contributor


This is used to reorder the requests: any request with query_len <= reorder_batch_threshold will be deemed a decode and passed to paged attention for processing. So I think passing reorder_batch_threshold to the metadata would be more intuitive than passing 1.
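
To make the classification concrete, a rough sketch of the idea described above (hypothetical helper and names, not the actual split_decodes_prefills_and_extends implementation in vLLM):

# Requests whose query_len is <= reorder_batch_threshold are treated as
# decodes and handled by paged attention; longer requests go to the
# prefill/extend (MHA) path.
def classify_requests(query_lens, reorder_batch_threshold=1):
    decodes, prefills_or_extends = [], []
    for req_idx, query_len in enumerate(query_lens):
        if query_len <= reorder_batch_threshold:
            decodes.append(req_idx)
        else:
            prefills_or_extends.append(req_idx)
    return decodes, prefills_or_extends

# Example: with the default threshold of 1, only single-token queries are
# decodes, so the prefill/extend path never sees query_len == 1 and
# min_seqlen_q=1 is a safe lower bound there.
decodes, others = classify_requests([1, 1, 7, 128])
# decodes == [0, 1], others == [2, 3]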

Contributor


And of course 1 is correct enough for the MHA backend, since the prefill and extend paths will never receive a request with query_len == 1, so passing 1 directly is also fine. LGTM.

Contributor Author


Per offline sync with @ganyi1996ppo, we agree that:

  1. We can use min_seqlen_q=1 in this PR to fix the accuracy issue, since Aiter MHA currently will not receive requests with query_len == 1 on either the prefill or the extend path. This usage is also consistent with the code before the #25763 refactor.

  2. reorder_batch_threshold aims to address case 3 (UNIFORM: uniform multi-token queries) mentioned in [ROCm][Perf] New design on ROCm AITER MHA backend Implementation #25763 (comment) (not yet supported by the Aiter attention backend), and its value may change at runtime in the MTP case, so it is better to leave it for future work.

@tjtanaa, what do you think?

Collaborator


@sammysun0711 Sure. Let's go with this approach. Thank you for the updates.

        max_seqlen_q=attn_metadata.extend_metadata.max_query_len,
        max_seqlen_k=attn_metadata.extend_metadata.max_seq_len,
-       min_seqlen_q=attn_metadata.extend_metadata.min_query_len,
+       min_seqlen_q=1,
Contributor


high

Similar to the prefill case, this change correctly fixes the accuracy issue for the extend phase. Adding a comment here as well would be beneficial for future maintainability and to prevent regressions of this critical bug.

Suggested change
min_seqlen_q=1,
# AITER MHA kernel requires min_seqlen_q >= 1. The value from
# `attn_metadata.extend_metadata.min_query_len` can be 0,
# causing accuracy issues. See #28598.
min_seqlen_q=1,

Collaborator


Could it be an issue with the metadata preparation steps?

Collaborator

tjtanaa commented Nov 13, 2025

@ganyi1996ppo @wuhuikx Could you take a look at this? Is this compatible with new AITER version as well? Thank you.

Collaborator

sarckk commented Nov 13, 2025

+1 we've been seeing this issue and this PR fixes the accuracy issue we were observing

@ganyi1996ppo
Contributor

LGTM

@sammysun0711
Contributor Author

@ganyi1996ppo @wuhuikx Could you take a look at this? Is this compatible with new AITER version as well? Thank you.

It seems that neither the AITER MHA function API nor the underlying CK pre-condition check changes with the new AITER version.
I checked this PR against the AITER main commit (0.1.7.post2.dev8+g779c9f60a) on MI308X.

Here are the Qwen/Qwen3-32B lm-eval results with gsm8k (the kernels previously built with the old AITER version were cleaned up before the test):

Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr
gsm8k | 3       | flexible-extract | 5      | exact_match | 0.6285 | ± 0.0133
gsm8k | 3       | strict-match     | 5      | exact_match | 0.7453 | ± 0.0120

@tjtanaa tjtanaa added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 14, 2025
@tjtanaa tjtanaa enabled auto-merge (squash) November 14, 2025 17:07
Collaborator

sarckk commented Nov 16, 2025

@sammysun0711 @tjtanaa could you rebase again?

…_attn_varlen_func for pure prefill and extend phase

Signed-off-by: Xiake Sun <xiake.sun@amd.com>
auto-merge was automatically disabled November 17, 2025 02:10

Head branch was pushed to by a user without write access

@sammysun0711 sammysun0711 force-pushed the fix_aiter_mha_min_seqlen_q branch from 5ead57c to 91bc7e8 Compare November 17, 2025 02:10
@22quinn 22quinn merged commit 60e089f into vllm-project:main Nov 17, 2025
46 checks passed
