
Conversation

@Daisy-Ma-coder (Contributor) commented Sep 19, 2025

Purpose

Add VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH to the environment variables so users can control the CUDA graph max_num_splits at the CLI level.

While applying #23958, I realized that the _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH value is copied from flash_attn (code ref) and is noted as something to be tuned if needed, so it seems worth surfacing it to the front end.
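
As a rough illustration (not the exact PR diff; it assumes the usual vllm/envs.py pattern of one lazily evaluated lambda per variable, as shown in the review diff further down):

import os

# Sketch: register the variable in vllm/envs.py so it can be set when
# launching the server; it falls back to 16, the value copied from flash_attn.
environment_variables = {
    "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
        lambda: int(
            os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", "16")),
}

# The attention backend then uses the parsed value as its max_num_splits
# when capturing CUDA graphs.
max_num_splits = environment_variables[
    "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH"]()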

Test Plan

Tested based off the Docker image vllm/vllm-openai:v0.10.2 with this PR applied:

VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH=64 VLLM_ATTENTION_BACKEND=FLASH_ATTN_MLA vllm serve deepseek-ai/DeepSeek-V3 \
    --port 3000 \
    --tensor-parallel-size 8 \
    --max-model-len 32768 \
    --max-num-seqs 8 \
    --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}'

Test Result

max_num_splits = 16 (default)

============ Serving Benchmark Result ============
Successful requests:                     20        
Benchmark duration (s):                  6.54      
Total input tokens:                      634860    
Total generated tokens:                  938       
Request throughput (req/s):              3.06      
Output token throughput (tok/s):         143.37    
Total Token throughput (tok/s):          97181.20  
---------------Time to First Token----------------
Mean TTFT (ms):                          1734.83   
Median TTFT (ms):                        1381.53   
P99 TTFT (ms):                           4708.08   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          61.00     
Median TPOT (ms):                        34.75     
P99 TPOT (ms):                           136.90    
---------------Inter-token Latency----------------
Mean ITL (ms):                           32.47     
Median ITL (ms):                         24.77     
P99 ITL (ms):                            138.39    
==================================================

max_num_splits = 32


============ Serving Benchmark Result ============
Successful requests:                     20        
Benchmark duration (s):                  6.09      
Total input tokens:                      634860    
Total generated tokens:                  1124      
Request throughput (req/s):              3.28      
Output token throughput (tok/s):         184.58    
Total Token throughput (tok/s):          104439.30 
---------------Time to First Token----------------
Mean TTFT (ms):                          1791.15   
Median TTFT (ms):                        982.61    
P99 TTFT (ms):                           4108.43   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          54.94     
Median TPOT (ms):                        32.48     
P99 TPOT (ms):                           156.46    
---------------Inter-token Latency----------------
Mean ITL (ms):                           29.34     
Median ITL (ms):                         24.66     
P99 ITL (ms):                            146.61    
==================================================

max_num_splits = 64

============ Serving Benchmark Result ============
Successful requests:                     20        
Benchmark duration (s):                  7.37      
Total input tokens:                      634860    
Total generated tokens:                  1133      
Request throughput (req/s):              2.71      
Output token throughput (tok/s):         153.73    
Total Token throughput (tok/s):          86293.77  
---------------Time to First Token----------------
Mean TTFT (ms):                          2379.76   
Median TTFT (ms):                        1592.55   
P99 TTFT (ms):                           5235.32   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          44.26     
Median TPOT (ms):                        30.98     
P99 TPOT (ms):                           178.50    
---------------Inter-token Latency----------------
Mean ITL (ms):                           32.38     
Median ITL (ms):                         24.71     
P99 ITL (ms):                            169.26    
==================================================

Quality check

pip install lm_eval  # inside docker container

Command:

VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH=64 VLLM_ATTENTION_BACKEND=FLASH_ATTN_MLA lm_eval \
  --model vllm \
  --model_args '{
    "pretrained": "deepseek-ai/DeepSeek-V2-Lite-Chat",
    "tensor_parallel_size": 8,
    "dtype": "auto",
    "gpu_memory_utilization": 0.9,
    "trust_remote_code": true,
    "max_model_len": 16384,
    "compilation_config": {
      "cudagraph_mode": "FULL_DECODE_ONLY"
    }
  }' \
  --task gsm8k \
  --num_fewshot 5 \
  --batch_size auto
vllm ({'pretrained': 'deepseek-ai/DeepSeek-V2-Lite-Chat', 'tensor_parallel_size': 8, 'dtype': 'auto', 'gpu_memory_utilization': 0.9, 'trust_remote_code': True, 'max_model_len': 16384, 'compilation_config': {'cudagraph_mode': 'FULL_DECODE_ONLY'}}), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6649|±  |0.0130|
|     |       |strict-match    |     5|exact_match|↑  |0.6535|±  |0.0131|

Flash Attention Quality Check on Mixtral-8x7B

lm_eval \
  --model vllm \
  --model_args '{
    "pretrained": "RedHatAI/Mixtral-8x7B-Instruct-v0.1",
    "tensor_parallel_size": 8,
    "dtype": "auto",
    "gpu_memory_utilization": 0.9,
    "trust_remote_code": true,
    "max_model_len": 16384,
    "compilation_config": {
      "cudagraph_mode": "FULL"
    }
  }' \
  --task gsm8k \
  --num_fewshot 5 \
  --batch_size auto

vllm ({'pretrained': 'RedHatAI/Mixtral-8x7B-Instruct-v0.1', 'tensor_parallel_size': 8, 'dtype': 'auto', 'gpu_memory_utilization': 0.9, 'trust_remote_code': True, 'max_model_len': 16384, 'compilation_config': {'cudagraph_mode': 'FULL'}}), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6406|±  |0.0132|
|     |       |strict-match    |     5|exact_match|↑  |0.6384|±  |0.0132|


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify bot added the v1 label Sep 19, 2025
@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces a new environment variable VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH to allow users to configure the max_num_splits for FlashAttention with CUDA graphs. The changes correctly add the environment variable definition, parsing logic, and integrate it into the attention backend. However, there is a logical flaw in how the environment variable is defined and consumed. The current implementation in vllm/envs.py causes the check in vllm/v1/attention/backends/mla/flashattn_mla.py to always be true, leading to dead code. My review provides suggestions to align the implementation with the existing pattern for optional integer environment variables in the codebase, which will fix the logical issue and improve code clarity and consistency.

@@ -118,6 +118,7 @@
 VLLM_SERVER_DEV_MODE: bool = False
 VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
 VLLM_MLA_DISABLE: bool = False
+VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 16
@gemini-code-assist bot (Contributor) commented (severity: high):

For consistency with other optional integer environment variables like VLLM_FLASH_ATTN_VERSION, it's better to define this as Optional[int] and handle the default value in the consumer module (flashattn_mla.py). This makes the intent clearer that the variable is optional and has a fallback. This change is related to another suggested change for the lambda function of this environment variable.

Suggested change:
-VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 16
+VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: Optional[int] = None

vllm/envs.py (outdated)
Comment on lines 955 to 956
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", "16")),
@gemini-code-assist bot (Contributor) commented (severity: high):

To allow checking if the environment variable was explicitly set, this lambda should return None when the variable is not present. The current implementation always returns an integer, which causes a logical flaw in flashattn_mla.py. Using maybe_convert_int without a default for os.getenv is the standard pattern in this file for optional integer variables like VLLM_FLASH_ATTN_VERSION.

Suggested change:
-"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
-    lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", "16")),
+"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
+    lambda: maybe_convert_int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH")),

vllm/v1/attention/backends/mla/flashattn_mla.py
Comment on lines 101 to 107
if envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH is not None:
    logger.info_once("Getting flash attention max num splits for "
                     "cuda graph from environment variable, value=%s",
                     envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)
    self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
else:
    self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
@gemini-code-assist bot (Contributor) commented (severity: high):

There is a logical flaw here. The if condition envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH is not None will always evaluate to True because of how the environment variable is defined in vllm/envs.py. The lambda function for it always returns an integer (defaulting to 16 if not set), never None. This makes the else block unreachable (dead code).

To fix this, you should modify vllm/envs.py to follow the pattern of other optional integer environment variables. Specifically:

  1. Change the type hint to Optional[int] = None.
  2. Change the lambda to use maybe_convert_int(os.getenv(...)) without a default, so it returns None if the variable is not set.

With those changes in vllm/envs.py, this block of code will work as intended. I've added separate comments in vllm/envs.py with the specific suggestions.
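
Concretely, a minimal self-contained sketch of the combined fix (the maybe_convert_int below is a stand-in for the helper vllm/envs.py already uses for optional integers such as VLLM_FLASH_ATTN_VERSION; not the exact code):

import os
from typing import Optional

def maybe_convert_int(value: Optional[str]) -> Optional[int]:
    # Stand-in: None means the variable was not set; otherwise parse it.
    return None if value is None else int(value)

environment_variables = {
    "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
        lambda: maybe_convert_int(
            os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH")),
}

# Consumer side (flashattn_mla.py): with the env var Optional, the else
# branch is reachable again and the in-file constant is the fallback.
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16

env_value = environment_variables[
    "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH"]()
max_num_splits = (env_value if env_value is not None
                  else _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)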

@MatthewBonanni (Contributor) commented Sep 19, 2025

Thanks for the contribution! Could you also update the non-MLA flash attention backend to use this env var? Regarding gemini's comments, I think you can get rid of _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH entirely and let the env var manage the default
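
Roughly, that simplification leaves only a direct read of the env var in both backends; a sketch under that assumption (vllm must be installed for the import, and the class name is just a placeholder, not the real builder class):

import vllm.envs as envs

class FlashAttnBackendSketch:
    def __init__(self) -> None:
        # With the default (16) owned by the env var itself, there is no
        # _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH constant and no is-None
        # check; both the MLA and non-MLA backends read the same value.
        self.max_num_splits = (
            envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)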

@Daisy-Ma-coder (Contributor, Author) replied:

> Thanks for the contribution! Could you also update the non-MLA flash attention backend to use this env var? Regarding gemini's comments, I think you can get rid of _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH entirely and let the env var manage the default

Thanks Matt, updated.

@LucasWilkinson (Collaborator) left a comment

LGTM! Thanks for the contribution!

qqma and others added 8 commits September 20, 2025 13:58
… users have control over cuda graph max_num_splits in cli level

Signed-off-by: qqma <qqma@amazon.com>
Signed-off-by: qqma <qqma@amazon.com>
Signed-off-by: qqma <qqma@amazon.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: qqma <qqma@amazon.com>
Signed-off-by: qqma <qqma@amazon.com>
…LA flash attention

Signed-off-by: qqma <qqma@amazon.com>
Signed-off-by: qqma <qqma@amazon.com>
Signed-off-by: qqma <qqma@amazon.com>

mergify bot commented Sep 20, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Daisy-Ma-coder.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot removed the tpu (Related to Google TPUs) and needs-rebase labels Sep 20, 2025
@simon-mo added the ready (ONLY add when PR is ready to merge/full CI is needed) label Sep 20, 2025
@simon-mo (Collaborator) left a comment

Is the API process count and rank related to this PR? Seems like a bad merge.

Inline comment (Collaborator): is this still needed?

Inline comment (Collaborator): is this still needed?

@Daisy-Ma-coder (Contributor, Author) replied:

Oh, it seems like the already-merged PR https://github.com/vllm-project/vllm/pull/23717/files got included in mine somehow. Let me try to fix it.

@github-project-automation bot moved this from To Triage to In progress in gpt-oss Issues & Enhancements Sep 20, 2025
…t#23717)"

This reverts commit 6e64b12.

Signed-off-by: qqma <qqma@amazon.com>
qqma added 2 commits September 20, 2025 18:33
Signed-off-by: qqma <qqma@amazon.com>
Signed-off-by: qqma <qqma@amazon.com>
@github-project-automation bot moved this from In progress to Ready in gpt-oss Issues & Enhancements Sep 22, 2025
@simon-mo merged commit cfbee3d into vllm-project:main Sep 22, 2025
42 checks passed
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…v variables (vllm-project#25274)

Signed-off-by: qqma <qqma@amazon.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: qqma <qqma@amazon.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
…v variables (vllm-project#25274)

Signed-off-by: qqma <qqma@amazon.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: qqma <qqma@amazon.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: charlifu <charlifu@amd.com>
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
…v variables (#25274)

Signed-off-by: qqma <qqma@amazon.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: qqma <qqma@amazon.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Labels: ci/build, documentation (Improvements or additions to documentation), frontend, gpt-oss (Related to GPT-OSS models), kv-connector, multi-modality (Related to multi-modality, #4194), performance (Performance-related issues), qwen (Related to Qwen models), ready (ONLY add when PR is ready to merge/full CI is needed), structured-output, v1
Projects
Status: Done