[ROCm][AITER][Bugfix] Switch AITER to use PIECEWISE_AND_FULL compilation #25104
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI. You ask your reviewers to trigger select CI tests on top of `fastcheck` CI.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run CI, PR reviewers can either add the `ready` label to the PR or enable auto-merge.
If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request correctly changes the `cudagraph_support` for the AiterFlashAttention backend to `UNIFORM_SINGLE_TOKEN_DECODE`. This is a necessary fix for using CUDA graphs, as this backend uses different kernels for prefill and decode operations. However, I've identified a critical logic bug in the `AiterFlashAttentionImpl.forward` method where both prefill and decode kernels are called for prefill requests, leading to incorrect outputs. I've added a detailed comment on this issue. While the change in this PR is a valid step, addressing the underlying bug is crucial for the backend's correctness.
This change to `UNIFORM_SINGLE_TOKEN_DECODE` is a correct step to fix CUDA graph compilation, as this backend uses different kernels for prefill and decode.
However, a more critical issue exists in the `AiterFlashAttentionImpl.forward` method. For batches containing prefill tokens (`max_seqlen_q > 1`), the implementation calls `flash_attn_varlen_func` and then unconditionally proceeds to call `paged_attention_v1`. This results in the output from the prefill kernel being overwritten, leading to incorrect attention results for prefill sequences.
To ensure correctness, the prefill and decode paths should be mutually exclusive. This can be achieved by restructuring the logic with an `if/else` statement.
Here is a suggested structure for the fix in `AiterFlashAttentionImpl.forward`:
if max_seqlen_q > 1:
    # Prefill path
    torch.ops.vllm.flash_attn_varlen_func(
        query[:num_actual_tokens],
        key_cache,
        value_cache,
        out=output[:num_actual_tokens],
        # ... other args
    )
else:
    # Decode path
    # ... prepare workspace_buffer
    torch.ops.aiter.paged_attention_v1(
        output[:num_actual_tokens],
        workspace_buffer,
        query[:num_actual_tokens],
        # ... other args
    )
return output
Fixing this underlying logic bug is essential for the overall correctness of the `AiterFlashAttentionBackend`.
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Force-pushed from e6bca59 to 891b519
…ion (vllm-project#25104) Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
…ion (vllm-project#25104) Signed-off-by: Rohan138 <rohanpotdar138@gmail.com> Signed-off-by: charlifu <charlifu@amd.com>
Purpose
`AiterFlashAttentionBackend` (enabled through `VLLM_ROCM_USE_AITER_MHA=1`) does not work correctly with `cudagraph_mode = FULL`, since it calls two different AITER kernels for prefill (`aiter.flash_attn_varlen_func`) and decode (`aiter.paged_attention_v1`). The correct `AttentionCGSupport` and `cudagraph_mode` to use here are `AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE` and `FULL_AND_PIECEWISE`, respectively.
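For reference, the change being described amounts to declaring a reduced CUDA-graph support level on the AITER attention backend. The sketch below is illustrative only: the class name `AiterFlashAttentionMetadataBuilder` and the import path are assumptions about the vLLM v1 backend layout, not the exact diff in this PR.

```python
# Illustrative sketch only: the import path and class name are assumptions
# about the vLLM v1 attention backend layout, not the exact change in this PR.
from vllm.v1.attention.backends.utils import AttentionCGSupport


class AiterFlashAttentionMetadataBuilder:  # stands in for the real builder class
    # Advertise that full CUDA graphs are only safe for uniform single-token
    # decode batches; mixed/prefill batches run through piecewise-compiled
    # graphs instead (i.e. FULL_AND_PIECEWISE behavior).
    cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
```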
Test Plan
For example, with Llama 3.3 70B FP8, TP1, and max_conc=256 on MI355, this fixes both the OOMs and some accuracy issues we've been seeing.
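The following is not the author's actual serving setup, just a hedged sketch of how this configuration could be exercised from the offline API, assuming a recent vLLM build where `CompilationConfig` exposes `cudagraph_mode` and the `LLM` constructor accepts `compilation_config`; the model path and batch limits are placeholders, not the values used in the test.

```python
# Hypothetical usage sketch; import paths, kwargs, and values are assumptions,
# not the author's exact test setup.
import os

os.environ["VLLM_ROCM_USE_AITER_MHA"] = "1"  # enable the AITER MHA backend on ROCm

from vllm import LLM
from vllm.config import CompilationConfig, CUDAGraphMode

llm = LLM(
    model="<llama-3.3-70b-fp8-checkpoint>",  # placeholder model path
    tensor_parallel_size=1,
    max_num_seqs=256,  # roughly corresponds to max_conc=256 in the test
    compilation_config=CompilationConfig(
        cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE,
    ),
)
print(llm.generate(["Hello"])[0].outputs[0].text)
```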
Test Result
Serving command:
CG capture:

LM-Eval:
Essential Elements of an Effective PR Description Checklist
(Optional) Documentation update, such as updating `supported_models.md` and `examples` for a new model.