Skip to content

Conversation

Rohan138
Copy link
Contributor

@Rohan138 Rohan138 commented Sep 17, 2025

Purpose

AiterFlashAttentionBackend (enabled through VLLM_ROCM_USE_AITER_MHA=1) does not work correctly with cudagraph_mode = FULL, since it calls two AITER kernels for prefill (aiter.flash_attn_varlen_func) and decode (aiter.paged_attention_v1). The correct AttentionCGSupport and cudagraph_mode to use here are AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE and FULL_AND_PIECEWISE respectively.

Test Plan

For e.g. Llama 3.3 70B FP8 TP1 max_conc=256 on MI355, this fixes both OOMs as well as some accuracy issues we've been seeing.

Test Result

Serving command:

vllm serve amd/Llama-3.1-70B-Instruct-FP8-KV -tp 1 --max-num-batched-tokens 131072 --max-num-seqs 1024 --max-seq-len-to-capture 16384 --max-model-len 10240 --swap-space 64 --no-enable-prefix-caching --disable-log-requests --disable-uvicorn-access-log --trust-remote-code --gpu-memory-utilization 0.90 -O '{"compilation_config":"FULL"}'

CG capture:
image

LM-Eval:

lm_eval --model local-completions --model_args model=$MODEL,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=256,max_retries=10,max_gen_toks=2048 --batch_size auto --tasks gsm8k --num_fewshot 5 --limit 250  --output_path . --apply_chat_template 2>&1 | tee -a eval.log
INFO 09-17 22:05:58 [__init__.py:241] Automatically detected platform rocm.
2025-09-17:22:06:00 WARNING  [__main__:369]  --limit SHOULD ONLY BE USED FOR TESTING.REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.
2025-09-17:22:06:00 INFO     [__main__:446] Selected Tasks: ['gsm8k']
2025-09-17:22:06:00 INFO     [evaluator:202] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234
2025-09-17:22:06:00 INFO     [evaluator:240] Initializing local-completions model, with arguments: {'model': 'amd/Llama-3.1-70B-Instruct-FP8-KV', 'base_url':
        'http://0.0.0.0:8000/v1/completions', 'num_concurrent': 256, 'max_retries': 10, 'max_gen_toks': 2048}
2025-09-17:22:06:00 WARNING  [models.api_models:158] Automatic batch size is not supported for API models. Defaulting to batch size 1.
2025-09-17:22:06:00 INFO     [models.api_models:170] Using max length 2048 - 1
2025-09-17:22:06:00 INFO     [models.api_models:189] Using tokenizer huggingface
2025-09-17:22:06:04 INFO     [evaluator:305] gsm8k: Using gen_kwargs: {'until': ['Question:', '</s>', '<|im_end|>'], 'do_sample': False, 'temperature': 0.0}
2025-09-17:22:06:04 WARNING  [evaluator:324] Overwriting default num_fewshot of gsm8k from 5 to 5
2025-09-17:22:06:04 WARNING  [evaluator:480] Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details.
2025-09-17:22:06:04 INFO     [api.task:434] Building contexts for gsm8k on rank 0...
100%|██████████| 250/250 [00:00<00:00, 754.99it/s]
2025-09-17:22:06:04 INFO     [evaluator:574] Running generate_until requests
2025-09-17:22:06:05 WARNING  [models.api_models:756] Some contexts exceeded (max length: (2047) - max_gen_toks (2048). They were left truncated.
Requesting API: 100%|██████████| 250/250 [01:02<00:00,  4.03it/s]
fatal: detected dubious ownership in repository at '/home/ropotdar/Desktop/MAD-private'
To add an exception for this directory, call:

        git config --global --add safe.directory /home/ropotdar/Desktop/MAD-private
2025-09-17:22:07:07 INFO     [loggers.evaluation_tracker:209] Saving results aggregated
local-completions (model=amd/Llama-3.1-70B-Instruct-FP8-KV,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=256,max_retries=10,max_gen_toks=2048), gen_kwargs: (None), limit: 250.0, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.854|±  |0.0252|
|     |       |strict-match    |     5|exact_match|↑  |0.208|±  |0.0257|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@Rohan138 Rohan138 requested a review from gshtras as a code owner September 17, 2025 22:11
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors.

You ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added rocm Related to AMD ROCm v1 labels Sep 17, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly changes the cudagraph_support for the AiterFlashAttention backend to UNIFORM_SINGLE_TOKEN_DECODE. This is a necessary fix for using CUDA graphs, as this backend uses different kernels for prefill and decode operations. However, I've identified a critical logic bug in the AiterFlashAttentionImpl.forward method where both prefill and decode kernels are called for prefill requests, leading to incorrect outputs. I've added a detailed comment on this issue. While the change in this PR is a valid step, addressing the underlying bug is crucial for the backend's correctness.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This change to UNIFORM_SINGLE_TOKEN_DECODE is a correct step to fix CUDA graph compilation, as this backend uses different kernels for prefill and decode.

However, a more critical issue exists in the AiterFlashAttentionImpl.forward method. For batches containing prefill tokens (max_seqlen_q > 1), the implementation calls flash_attn_varlen_func and then unconditionally proceeds to call paged_attention_v1. This results in the output from the prefill kernel being overwritten, leading to incorrect attention results for prefill sequences.

To ensure correctness, the prefill and decode paths should be mutually exclusive. This can be achieved by restructuring the logic with an if/else statement.

Here is a suggested structure for the fix in AiterFlashAttentionImpl.forward:

        if max_seqlen_q > 1:
            # Prefill path
            torch.ops.vllm.flash_attn_varlen_func(
                query[:num_actual_tokens],
                key_cache,
                value_cache,
                out=output[:num_actual_tokens],
                # ... other args
            )
        else:
            # Decode path
            # ... prepare workspace_buffer
            torch.ops.aiter.paged_attention_v1(
                output[:num_actual_tokens],
                workspace_buffer,
                query[:num_actual_tokens],
                # ... other args
            )
        return output

Fixing this underlying logic bug is essential for the overall correctness of the AiterFlashAttentionBackend.

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
@Rohan138 Rohan138 force-pushed the aiter_fa_piecewise_prefill branch from e6bca59 to 891b519 Compare September 17, 2025 22:17
@gshtras gshtras added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 18, 2025
@gshtras gshtras enabled auto-merge (squash) September 18, 2025 17:10
@gshtras gshtras merged commit bbdc0f2 into vllm-project:main Sep 18, 2025
53 checks passed
@Rohan138 Rohan138 deleted the aiter_fa_piecewise_prefill branch September 18, 2025 20:00
debroy-rh pushed a commit to debroy-rh/vllm that referenced this pull request Sep 19, 2025
…ion (vllm-project#25104)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…ion (vllm-project#25104)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
…ion (vllm-project#25104)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: charlifu <charlifu@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants