[Perf] Optimize moe_align_block_size CUDA kernel #19572
Merged

Conversation

yewentao256 (Contributor) commented Jun 12, 2025

Purpose

Fixes #19517

The implementation is taken from https://github.com/sgl-project/sglang/blob/8b5f83ed3b7d2a49ad5c5cd5aa61c5d502f47dbc

Special thanks to the SGL developers!

Notes:

  • There is now only one moe_align_block_size kernel (see the semantics sketch below).
  • moe_align_block_size_triton and sgl_moe_align_block_size are marked for deprecation.
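For context on what this kernel computes: given the per-token expert assignments (topk_ids), moe_align_block_size pads each expert's token count up to a multiple of block_size and emits the block-sorted token indices, the expert id owning each block, and the total padded token count consumed by the fused MoE GEMMs. Below is a minimal PyTorch sketch of that output contract; it is an illustration only (names, the padding sentinel, and the Python loop are assumptions for clarity), not the optimized CUDA kernel in this PR.

```python
# Minimal PyTorch reference of the moe_align_block_size output contract.
# Illustration only: names, the padding sentinel, and the Python loop are
# assumptions for clarity, not the optimized CUDA kernel.
import torch


def moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int,
                             num_experts: int):
    flat = topk_ids.flatten()
    pad_sentinel = flat.numel()  # assumed filler value for padded slots

    # Tokens routed to each expert, rounded up to a multiple of block_size.
    counts = torch.bincount(flat, minlength=num_experts)
    padded = (counts + block_size - 1) // block_size * block_size
    num_tokens_post_pad = int(padded.sum())

    sorted_token_ids = torch.full((num_tokens_post_pad,), pad_sentinel,
                                  dtype=torch.int32, device=flat.device)
    expert_ids = torch.empty(num_tokens_post_pad // block_size,
                             dtype=torch.int32, device=flat.device)
    offsets = torch.cumsum(padded, dim=0) - padded  # exclusive prefix sum

    for e in range(num_experts):
        start = int(offsets[e])
        # Every block of block_size slots in this range belongs to expert e.
        first_block = start // block_size
        expert_ids[first_block:first_block + int(padded[e]) // block_size] = e
        # Positions (into the flattened topk_ids) of tokens routed to expert e.
        token_pos = (flat == e).nonzero(as_tuple=True)[0]
        sorted_token_ids[start:start + token_pos.numel()] = token_pos.to(torch.int32)

    return sorted_token_ids, expert_ids, num_tokens_post_pad
```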

Test

Tested on B200

Unit test

[Screenshot of unit test results, 2025-06-12]

Benchmark (with Triton)

| num_tokens | num_experts | topk | vLLM (origin) | vLLM (now) | Triton |
|---:|---:|---:|---:|---:|---:|
| 1 | 16 | 1 | 20.3 | 22.0 | 47.7 |
| 1 | 16 | 2 | 20.2 | 22.2 | 47.6 |
| 1 | 16 | 8 | 20.2 | 22.1 | 47.4 |
| 1 | 64 | 1 | 25.8 | 27.7 | 47.1 |
| 1 | 64 | 2 | 25.6 | 27.7 | 47.7 |
| 1 | 64 | 8 | 26.5 | 27.8 | 46.4 |
| 1 | 224 | 1 | 62.4 | 32.7 | 54.3 |
| 1 | 224 | 2 | 62.4 | 32.6 | 53.5 |
| 1 | 224 | 8 | 62.4 | 32.6 | 54.1 |
| 1 | 256 | 1 | 71.7 | 31.8 | 58.4 |
| 1 | 256 | 2 | 71.8 | 31.8 | 58.4 |
| 1 | 256 | 8 | 71.8 | 31.8 | 58.4 |
| 1 | 280 | 1 | 45.3 | 32.3 | 62.5 |
| 1 | 280 | 2 | 45.2 | 32.0 | 62.5 |
| 1 | 280 | 8 | 45.2 | 32.3 | 63.5 |
| 1 | 512 | 1 | 317.8 | 39.8 | 87.1 |
| 1 | 512 | 2 | 317.7 | 40.4 | 87.1 |
| 1 | 512 | 8 | 317.7 | 40.2 | 87.1 |
| 16 | 16 | 1 | 21.4 | 23.6 | 46.2 |
| 16 | 16 | 2 | 21.3 | 23.6 | 46.5 |
| 16 | 16 | 8 | 21.4 | 24.2 | 47.3 |
| 16 | 64 | 1 | 25.6 | 28.4 | 46.8 |
| 16 | 64 | 2 | 25.6 | 28.4 | 46.9 |
| 16 | 64 | 8 | 26.4 | 29.7 | 45.6 |
| 16 | 224 | 1 | 62.4 | 32.8 | 54.3 |
| 16 | 224 | 2 | 62.4 | 32.6 | 54.3 |
| 16 | 224 | 8 | 62.4 | 32.7 | 56.4 |
| 16 | 256 | 1 | 72.7 | 34.4 | 58.3 |
| 16 | 256 | 2 | 72.7 | 33.9 | 58.3 |
| 16 | 256 | 8 | 72.7 | 34.2 | 58.3 |
| 16 | 280 | 1 | 46.0 | 33.6 | 60.4 |
| 16 | 280 | 2 | 46.0 | 33.8 | 60.4 |
| 16 | 280 | 8 | 46.0 | 33.9 | 60.5 |
| 16 | 512 | 1 | 318.5 | 39.8 | 87.1 |
| 16 | 512 | 2 | 318.5 | 40.0 | 87.1 |
| 16 | 512 | 8 | 318.7 | 40.0 | 87.1 |
| 256 | 16 | 1 | 19.5 | 24.2 | 47.0 |
| 256 | 16 | 2 | 22.0 | 28.7 | 47.3 |
| 256 | 16 | 8 | 33.7 | 28.2 | 58.4 |
| 256 | 64 | 1 | 27.5 | 30.3 | 47.1 |
| 256 | 64 | 2 | 27.6 | 31.7 | 88.4 |
| 256 | 64 | 8 | 31.9 | 26.2 | 47.9 |
| 256 | 224 | 1 | 62.4 | 32.8 | 56.4 |
| 256 | 224 | 2 | 62.4 | 31.8 | 56.4 |
| 256 | 224 | 8 | 62.6 | 30.5 | 60.5 |
| 256 | 256 | 1 | 70.8 | 32.4 | 55.3 |
| 256 | 256 | 2 | 72.7 | 33.6 | 58.4 |
| 256 | 256 | 8 | 74.7 | 33.9 | 60.4 |
| 256 | 280 | 1 | 44.2 | 33.7 | 60.4 |
| 256 | 280 | 2 | 44.6 | 33.8 | 60.4 |
| 256 | 280 | 8 | 48.1 | 35.0 | 64.3 |
| 256 | 512 | 1 | 318.3 | 40.5 | 87.0 |
| 256 | 512 | 2 | 318.5 | 40.5 | 87.0 |
| 256 | 512 | 8 | 323.7 | 41.1 | 88.2 |
| 4096 | 16 | 1 | 48.1 | 27.8 | 78.8 |
| 4096 | 16 | 2 | 74.9 | 28.2 | 122.0 |
| 4096 | 16 | 8 | 261.1 | 44.6 | 381.0 |
| 4096 | 64 | 1 | 43.5 | 28.3 | 52.2 |
| 4096 | 64 | 2 | 66.6 | 33.4 | 62.4 |
| 4096 | 64 | 8 | 192.9 | 48.1 | 130.1 |
| 4096 | 224 | 1 | 68.6 | 33.8 | 64.6 |
| 4096 | 224 | 2 | 80.9 | 39.0 | 77.8 |
| 4096 | 224 | 8 | 167.7 | 54.3 | 99.5 |
| 4096 | 256 | 1 | 76.8 | 33.9 | 70.6 |
| 4096 | 256 | 2 | 85.7 | 38.0 | 76.7 |
| 4096 | 256 | 8 | 131.2 | 52.3 | 101.4 |
| 4096 | 280 | 1 | 50.1 | 36.7 | 70.7 |
| 4096 | 280 | 2 | 53.4 | 37.8 | 80.8 |
| 4096 | 280 | 8 | 104.1 | 52.3 | 105.4 |
| 4096 | 512 | 1 | 328.8 | 42.0 | 85.0 |
| 4096 | 512 | 2 | 345.1 | 44.6 | 95.2 |
| 4096 | 512 | 8 | 448.1 | 58.4 | 122.8 |
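For anyone reproducing the kernel numbers above (lower is better), a simple CUDA-event timing loop along these lines is sufficient. The `impl` argument is a placeholder callable standing in for either the vLLM or the Triton implementation; this is a sketch of the methodology, not the benchmark script added by this PR.

```python
# Sketch of a CUDA-event timing loop for comparing moe_align_block_size
# implementations. `impl` is a placeholder callable; this is not the
# benchmark script added by this PR.
import torch


def time_impl(impl, num_tokens, num_experts, topk, block_size=128,
              warmup=10, iters=100):
    topk_ids = torch.randint(0, num_experts, (num_tokens, topk),
                             dtype=torch.int32, device="cuda")
    for _ in range(warmup):
        impl(topk_ids, block_size, num_experts)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        impl(topk_ids, block_size, num_experts)
    end.record()
    torch.cuda.synchronize()
    # elapsed_time() returns milliseconds; report microseconds per call.
    return start.elapsed_time(end) * 1e3 / iters
```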

End-to-end throughput

Throughput (fp16)

vllm bench throughput --model Qwen/Qwen3-30B-A3B --load-format dummy --input-len 1000 --output-len 100

Throughput: 46.03 requests/s, 50547.85 total tokens/s, 4603.43 output tokens/s (B200, origin)
Throughput: 47.63 requests/s, 52312.68 total tokens/s, 4762.52 output tokens/s (B200, now)

Throughput (fp8)

vllm bench throughput --model Qwen/Qwen3-30B-A3B-FP8 --load-format dummy --input-len 1000 --output-len 100

Throughput: 42.28 requests/s, 46424.61 total tokens/s, 4228.17 output tokens/s (B200, origin)
Throughput: 44.17 requests/s, 48497.60 total tokens/s, 4417.43 output tokens/s (B200, now)


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

gemini-code-assist bot left a comment

Summary of Changes

Hello @yewentao256, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request optimizes the moe_align_block_size CUDA kernel, a critical component in Mixture-of-Experts (MoE) models. It replaces disparate implementations with a unified, high-performance C++ kernel, leading to notable speedups for configurations with many experts. The change also simplifies the codebase by deprecating older kernels and includes new tests and benchmarks to validate the improvements.

Highlights

  • Kernel Unification: Consolidated multiple CUDA kernel implementations for moe_align_block_size into a single, optimized C++ implementation. The previous moe_align_block_size_kernel, moe_align_block_size_global_mem_kernel, and sgl_moe_align_block_size_kernel have been replaced or refactored.
  • Performance Improvement: The new C++ kernel shows significant performance improvements, particularly for larger numbers of experts (e.g., 224, 256, 512), as demonstrated by the provided benchmark results on B200.
  • Deprecation: The Python moe_align_block_size_triton and the C++ sgl_moe_align_block_size functions are marked for deprecation, simplifying the codebase.
  • Testing and Benchmarking: Added new unit tests and a benchmark script specifically for moe_align_block_size to verify correctness and measure performance against the Triton implementation.

gemini-code-assist bot left a comment

Code Review

This pull request significantly refactors and optimizes the moe_align_block_size CUDA kernel. Key changes include:

  • Kernel Consolidation: The existing sgl_moe_align_block_size_kernel is enhanced and becomes the primary CUDA kernel, while older shared-memory and global-memory specific versions are removed. A new specialized kernel (moe_align_block_size_small_batch_expert_kernel) is introduced for small batch/expert scenarios.
  • Unified C++ API: The host-side C++ function moe_align_block_size now intelligently dispatches to either the small-batch kernel or the main two-kernel path (alignment count + sort).
  • Python API Simplification: The Python wrapper in vllm.model_executor.layers.fused_moe.moe_align_block_size is simplified, removing complex dispatch logic and relying on the C++ backend.
  • Deprecation: Older/redundant functions like sgl_moe_align_block_size (Python and C++) and moe_align_block_size_triton are marked for deprecation.
  • Performance: Benchmark results indicate substantial performance improvements with the new implementation.
  • Testing: New benchmark and unit test files are added to verify correctness and performance against a Triton implementation.

The changes appear well-structured and the CUDA kernel logic follows established parallel programming patterns. The simplification of the Python-level dispatch is a welcome improvement for maintainability.
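To make the "Python API Simplification" point above concrete: after this change, the wrapper essentially just allocates the three output buffers and calls the single C++/CUDA op, which dispatches internally. The sketch below is a rough approximation under an assumed op name and argument order; see the actual vllm/model_executor/layers/fused_moe/moe_align_block_size.py for the authoritative code.

```python
# Rough sketch of the simplified Python wrapper after this PR. The custom-op
# name and argument order are assumptions for illustration, not the
# authoritative vLLM source.
import torch
from vllm import _custom_ops as ops


def moe_align_block_size(topk_ids: torch.Tensor, block_size: int,
                         num_experts: int):
    # Worst case: every expert needs (block_size - 1) padding slots.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(max_num_tokens_padded, dtype=torch.int32,
                             device=topk_ids.device)
    max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size
    expert_ids = torch.empty(max_num_blocks, dtype=torch.int32,
                             device=topk_ids.device)
    num_tokens_post_pad = torch.empty(1, dtype=torch.int32,
                                      device=topk_ids.device)
    # Single C++/CUDA entry point; it chooses internally between the
    # small-batch-expert kernel and the count + sort path.
    ops.moe_align_block_size(topk_ids, num_experts, block_size,
                             sorted_ids, expert_ids, num_tokens_post_pad)
    return sorted_ids, expert_ids, num_tokens_post_pad
```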

yewentao256 force-pushed the wye-optimize-moe_align_block_size branch from 25d764d to b6afb8c on June 12, 2025 18:25
yewentao256 (Contributor, Author) commented:

@mgoin Please take a look

mgoin (Member) left a comment

Looks pretty good! I would like an eval for Qwen3 and an end-to-end benchmark on DeepSeek (dummy load is ok) to make sure everything is in order, in addition to CI. Then I think this is good to go

yewentao256 (Contributor, Author) commented:

> Looks pretty good! I would like an eval for Qwen3 and an end-to-end benchmark on DeepSeek (dummy load is ok) to make sure everything is in order, in addition to CI. Then I think this is good to go

This is what I got:

lm_eval --model vllm --model_args pretrained=Qwen/Qwen3-30B-A3B-FP8,max_model_len=32768 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

Origin
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.8264|±  |0.0104|
|     |       |strict-match    |     5|exact_match||0.8923|±  |0.0085|

Now
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.8378|±  |0.0102|
|     |       |strict-match    |     5|exact_match||0.8916|±  |0.0086|


vllm bench throughput --model deepseek-ai/deepseek-llm-7b-base --load-format dummy --input-len 1000 --output-len 100

Origin
Throughput: 45.27 requests/s, 49710.98 total tokens/s, 4527.47 output tokens/s
Throughput: 45.22 requests/s, 49653.42 total tokens/s, 4522.34 output tokens/s

Now

Throughput: 45.07 requests/s, 49545.56 total tokens/s, 4506.91 output tokens/s
Throughput: 45.57 requests/s, 50064.75 total tokens/s, 4557.25 output tokens/s

vllm bench throughput --model deepseek-ai/DeepSeek-R1-Distill-Llama-70B --load-format dummy --input-len 1000 --output-len 100

Origin
Throughput: 6.47 requests/s, 7110.63 total tokens/s, 647.04 output tokens/s

Now
Throughput: 6.51 requests/s, 7158.93 total tokens/s, 651.48 output tokens/s

yewentao256 (Contributor, Author) commented Jun 13, 2025

More benchmark tests added:

# E=64 (improvement)
vllm bench throughput --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --load-format dummy --input-len 1000 --output-len 100 --trust_remote_code
Origin
Throughput: 54.46 requests/s, 59832.15 total tokens/s, 5445.72 output tokens/s
Now
Throughput: 58.19 requests/s, 63906.07 total tokens/s, 5819.48 output tokens/s

# This should be the same since E=256
vllm bench throughput --model RedHatAI/DeepSeek-R1-0528-quantized.w4a16 --load-format dummy --input-len 1000 --output-len 100 -tp 4
Origin 
Throughput: 12.22 requests/s, 13414.60 total tokens/s, 1221.75 output tokens/s
Now
Throughput: 12.24 requests/s, 13452.07 total tokens/s, 1224.20 output tokens/s

Comment on lines 16 to 29
itertools.product(
    [32, 64, 128, 256],  # block_size
    [
        1,
        4,
        16,
        64,
        256,
        1024,
        4096,
    ],  # num_tokens
    [1, 4, 16, 64],  # topk
    [64, 160, 256, 257, 260, 264],  # num_experts
)),
Collaborator commented:

How long do these tests take? Should we explicitly list out the problem sizes to test?

And it would be very good to add some tests for non-power-of-two num_tokens, including some odd problem sizes

yewentao256 (Contributor, Author) replied:

[Screenshot of the updated test run, 2025-06-13]

Added the odd problem sizes; the full test run now takes roughly 2 minutes. An illustrative parameterization is sketched below.
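For illustration, a grid with odd and non-power-of-two sizes can be expressed in the same itertools.product style; the specific values below are examples rather than the exact list committed in this PR.

```python
# Illustrative pytest parameter grid including odd / non-power-of-two sizes;
# the exact values here are examples, not necessarily the committed list.
import itertools

CASES = list(
    itertools.product(
        [32, 64, 128, 256],              # block_size
        [1, 3, 7, 16, 222, 1023, 4096],  # num_tokens, incl. odd sizes
        [1, 2, 6, 8],                    # topk
        [64, 160, 256, 257, 260, 264],   # num_experts
    ))
```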

mgoin added the performance and ready labels on Jun 14, 2025
mgoin (Member) left a comment

Great work, thank you for your investigation

mgoin enabled auto-merge (squash) on June 14, 2025 20:51
simon-mo disabled auto-merge on June 17, 2025 18:49
simon-mo merged commit ffb2cd6 into vllm-project:main on June 17, 2025
90 of 97 checks passed
mgoin deleted the wye-optimize-moe_align_block_size branch on June 19, 2025 08:06
yeqcharlotte pushed a commit to yeqcharlotte/vllm that referenced this pull request on Jun 22, 2025
minpeter pushed a commit to minpeter/vllm that referenced this pull request on Jun 24, 2025
yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request on Jun 24, 2025
xjpang pushed a commit to xjpang/vllm that referenced this pull request on Jun 30, 2025
wseaton pushed a commit to wseaton/vllm that referenced this pull request on Jun 30, 2025