Skip to content

Conversation

@mgoin
Copy link
Member

@mgoin mgoin commented Sep 25, 2025

Purpose

By looking at get_best_config in DeepGEMM's https://github.com/deepseek-ai/DeepGEMM/blob/79f48ee15a82dd5fad5cd9beaa393c1f755e6b55/csrc/jit_kernels/heuristics/common.hpp we realized that we can greatly reduce the number of batch sizes we have to run for each weight shape when precompiling DeepGEMM. This greatly improves startup time when hitting warm DeepGEMM cache because we need to run many fewer gemms.

Test Plan

Test that all the kernels necessary are generated before and after this PR

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a heuristic-based approach to speed up the DeepGEMM warmup process. By selectively choosing M values for GEMM operations instead of iterating through all possibilities, it significantly reduces the number of warmup iterations. The changes look good, but there is a significant issue with hardcoding device 0 when querying for device properties, which could lead to incorrect warmup behavior in multi-GPU environments. I've added comments with suggestions to address this.

@varun-sundar-rabindranath
Copy link
Contributor

varun-sundar-rabindranath commented Sep 25, 2025

Nice changes @mgoin!

I verified with command
vllm bench throughput --model deepseek-ai/DeepSeek-V3 --hf-overrides.num_hidden_layers=4 --load-format=dummy --enforce-eager --input-len=10000 --output-len=500 --num-prompts=1
that the kernels generated are the same between main and this PR.

Verified the same for Qwen/Qwen3-30B-A3B-FP8

@varun-sundar-rabindranath
Copy link
Contributor

Like we talked over chat - I think the default should be processing all the token sizes so we can be sure that in deployment scenarios we aren't jitting in the hotpath. We can introduce an env var VLLM_RELAX_DEEP_GEMM_WARMUP and hide _generate_optimal_warmup_m_values behind that (like something people can opt-in).
Thanks 🙌

Copy link
Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks @mgoin

@varun-sundar-rabindranath
Copy link
Contributor

@tlrmchlsmth @yewentao256 can you please take a look when you find some time. Thanks 🙌

Copy link
Member

@yewentao256 yewentao256 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the work!

I am thinking if we add too many env variables and make it complicated to use.
Would full warm up at the first time and we automatically skip second warm up if cache is there work?

Comment on lines +39 to +42
def ceil_div(a: int, b: int) -> int:
return (a + b - 1) // b
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from vllm.utils import cdiv

@varun-sundar-rabindranath
Copy link
Contributor

Would full warm up at the first time and we automatically skip second warm up if cache is there work?

@yewentao256 We considered this and decided against it for a first pass as it is hard to detect if the cache has a JIT for a particular GEMM shape.

  1. About the detection logic itself, it is insufficient to use just the model name as key, we would also need to consider parallelization configs to generate a reasonable key. This is can get unruly. [we could probably reuse the hash that we generate for torch compile - I thought of this as I was typing and could be useful]
  2. Updates in deep_gemm: If deep_gemm updates and we skip warmup due to some local detection logic, we will run with old kernels. [Maybe we can add the deep_gemm commit hash to torch compile hash]

I am thinking of VLLM_RELAX_DEEP_GEMM_WARMUP as an env var for power users. For normal users, I think the default recommendation should be to use neither VLLM_RELAX_DEEP_GEMM_WARMUP nor VLLM_SKIP_DEEP_GEMM_WARMUP as it guarantees optimal performance.

vllm/envs.py Outdated
Comment on lines 1149 to 1083
"VLLM_RELAX_DEEP_GEMM_WARMUP":
lambda: bool(int(os.getenv("VLLM_RELAX_DEEP_GEMM_WARMUP", "0"))),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it too dangerous to default to relaxed?

mgoin and others added 5 commits October 10, 2025 16:27
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
@mgoin mgoin added ready ONLY add when PR is ready to merge/full CI is needed deepseek Related to DeepSeek models labels Oct 11, 2025
@vllm-bot vllm-bot merged commit 0d21b9b into vllm-project:main Oct 13, 2025
46 of 48 checks passed
VladOS95-cyber pushed a commit to VladOS95-cyber/vllm that referenced this pull request Oct 13, 2025
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: Vladislav <vladislav.bronzov@gmail.com>
1994 pushed a commit to 1994/vllm that referenced this pull request Oct 14, 2025
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: 1994 <1994@users.noreply.github.com>
Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: Dhruvil Bhatt <bhattdbh@amazon.com>
bbartels pushed a commit to bbartels/vllm that referenced this pull request Oct 16, 2025
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: bbartels <benjamin@bartels.dev>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Zhathw pushed a commit to Zhathw/vllm that referenced this pull request Nov 12, 2025
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

deepseek Related to DeepSeek models ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants