[feat]: CUTLASS block scaled group gemm for SM100 #19757

Merged 18 commits into vllm-project:main from dmoss/blockscaled_cutlass_group_gemm on Jul 4, 2025

Conversation

@djmmoss (Contributor) commented Jun 17, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This PR adds a CUTLASS block scaled group GEMM implementation for SM100. This is intended to be used as an alternative to DeepGEMM for Blackwell devices.
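
For readers unfamiliar with the operation, the sketch below emulates what a block-scaled grouped GEMM computes in plain PyTorch. It is an illustration only: the 1x128 activation-group / 128x128 weight-block layout is the usual DeepSeek-style blockwise FP8 convention and is an assumption of this sketch, not a description of the new kernel's exact interface.

import torch

def reference_block_scaled_grouped_mm(a_fp8, a_scale, b_fp8, b_scale, expert_offsets):
    """Emulation only: dequantize per block, then run one GEMM per expert group.

    Assumed shapes (DeepSeek-style blockwise FP8; K and N multiples of 128):
      a_fp8:  (M_total, K) FP8 activations,  a_scale: (M_total, K // 128)
      b_fp8:  (E, N, K)    FP8 weights,      b_scale: (E, N // 128, K // 128)
      expert_offsets: (E + 1,) row offsets of each expert's tokens in a_fp8
    """
    outs = []
    for g in range(b_fp8.shape[0]):
        s, e = int(expert_offsets[g]), int(expert_offsets[g + 1])
        # Broadcast each scale over its 128-wide (or 128x128) block, then dequantize.
        a = a_fp8[s:e].float() * a_scale[s:e].repeat_interleave(128, dim=1)
        b = b_fp8[g].float() * b_scale[g].repeat_interleave(128, dim=0).repeat_interleave(128, dim=1)
        outs.append(a @ b.t())  # (tokens_g, K) @ (K, N) -> (tokens_g, N)
    return torch.cat(outs, dim=0)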

Test Plan

Adds unit tests for the new function: tests/kernels/moe/test_cutlass_grouped_gemm.py

Running the test:

python -m pytest tests/kernels/moe/test_cutlass_grouped_gemm.py 

Test Result

$ python -m pytest tests/kernels/moe/test_cutlass_grouped_gemm.py 
=============================================================================================================== test session starts ================================================================================================================
platform linux -- Python 3.12.3, pytest-8.1.1, pluggy-1.5.0
rootdir: /home/scratch.dmoss_gpu_1/repos/vllm
configfile: pyproject.toml
plugins: anyio-4.8.0, shard-0.1.2, xdoctest-1.0.2, flakefinder-1.1.0, xdist-3.6.1, hypothesis-6.127.9, rerunfailures-15.0, typeguard-4.3.0
collected 6 items                                                                                                                                                                                                                                  
Running 6 items in this shard

tests/kernels/moe/test_cutlass_grouped_gemm.py ......                                                                                                                                                                                        [100%]

================================================================================================================ 6 passed in 3.31s =================================================================================================================

(Optional) Documentation Update

@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @djmmoss, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a specialized CUTLASS block-scaled grouped GEMM kernel optimized for NVIDIA's SM100 (Blackwell) architecture. This new implementation serves as an optimized alternative to existing DeepGEMM kernels, particularly for FP8 quantized Mixture-of-Experts (MoE) operations. The changes involve integrating the kernel into the build system, exposing it through the Python API, and updating the MoE layer dispatch logic to leverage this new, hardware-specific optimization, all controlled by a new environment variable.

Highlights

  • New Kernel Implementation: Introduces a new CUTLASS block-scaled grouped GEMM kernel specifically optimized for SM100 (Blackwell) architectures. This kernel is designed as an alternative to DeepGEMM for FP8 quantized Mixture-of-Experts (MoE) models, aiming for improved performance on the latest NVIDIA hardware.
  • Build System Integration: Updates CMakeLists.txt to conditionally build the new kernel. Compilation is enabled only when the CUDA compiler version is 12.8 or greater and the target architecture includes SM100, controlled by a new ENABLE_CUTLASS_MOE_SM100 flag.
  • Python API Exposure: The new cutlass_blockwise_scaled_grouped_mm operation is exposed to the Python API. This involves declaring the function in csrc/ops.h, binding it in csrc/torch_bindings.cpp, and providing a Python wrapper in vllm/_custom_ops.py.
  • MoE Layer Integration: The new kernel is integrated into the fused_experts logic within vllm/model_executor/layers/fused_moe/. It can be conditionally used for FP8 quantized MoE layers based on problem size alignment and a new environment variable, orchestrating input and intermediate FP8 quantization.
  • Feature Control: A new environment variable, VLLM_USE_CUTLASS_BLOCKSCALED_GROUPED_GEMM, has been added in vllm/envs.py. This allows users to explicitly enable or disable the use of this new, hardware-specific kernel (a brief usage sketch follows this list).
  • Comprehensive Testing: A new unit test file, tests/kernels/moe/test_cutlass_grouped_gemm.py, has been added. This file includes helper functions for FP8 quantization and validates the correctness and accuracy of the new grouped GEMM kernel against a baseline across various problem sizes.
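
As a usage sketch for the Feature Control item above: the flag is an ordinary opt-in environment variable, so enabling it only requires setting it in the environment of the process that runs vLLM. Note that later in this review the flag is dropped in favor of enabling the kernel by default when it is available.

import os

# Opt in to the CUTLASS block-scaled grouped GEMM path. The flag name comes from
# vllm/envs.py in this PR; the flag is later removed, making the kernel the default
# on supported hardware.
os.environ["VLLM_USE_CUTLASS_BLOCKSCALED_GROUPED_GEMM"] = "1"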

@mergify mergify bot added the ci/build label Jun 17, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a new CUTLASS block-scaled grouped GEMM kernel specifically for SM100 (Blackwell) architectures, intended as an alternative to DeepGEMM for FP8 MoE models. The changes include new CUDA C++ kernels, CMake build system updates, Python bindings, and corresponding tests.

Overall, the implementation looks promising. Key areas for attention include ensuring the CMake build correctly handles shared source files like moe_data.cu to prevent build failures, clarifying some conditions and messages, and addressing a potential issue in the test logic related to uninitialized tensor usage.

@@ -1182,6 +1187,18 @@ def fused_experts(hidden_states: torch.Tensor,
a2_scale=a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 and N > 512
Contributor:

The N > 512 check for DeepGEMM was only there for performance reasons. Does the Triton kernel actually beat CUTLASS for N <= 512?

Contributor Author:

Removed; at the moment CUTLASS is performing better.

return True


def run_cutlass_block_scaled_fused_experts(
Contributor:

Can you also integrate this with triton_deep_gemm_moe.py (maybe we should change this name) so it can be used with EP?

expert_offsets[:-1],
)

assert calc_diff(ref_out, out) < 1e-3, f"Cutlass grouped gemm is not accurate"
Contributor:

Can you use torch.testing.assert_close here?

Contributor Author:

Done, although given the sizes used in the test there are some outliers compared to the FP32 baseline, which is why the atol is fairly lenient.
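
For context, a minimal sketch of the kind of check being discussed; the tensors and tolerances below are illustrative placeholders, not the shapes or values used in the PR's test:

import torch

ref_out = torch.randn(128, 256)                    # stand-in FP32 baseline
out = ref_out + 1e-2 * torch.randn_like(ref_out)   # stand-in kernel output

# rtol/atol are per-element: |out - ref_out| <= atol + rtol * |ref_out|.
# A relatively loose atol absorbs the occasional outlier vs. the FP32 baseline.
torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=5e-2)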

vllm/envs.py Outdated
Comment on lines 814 to 817
# Allow use of Cutlass Blockwise Scaled Grouped GEMM kernels for fused moe ops.
"VLLM_USE_CUTLASS_BLOCKSCALED_GROUPED_GEMM":
lambda: bool(int(os.getenv(
"VLLM_USE_CUTLASS_BLOCKSCALED_GROUPED_GEMM", "0"))),
Member:

Do we need an environment variable? I would think we want to use it by default if it is available.

Contributor Author:

Changed it to the default behavior.
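
A minimal sketch of the "on by default when available" pattern being described; the capability check below is an assumption of this sketch (SM100 corresponds to compute capability 10.x), not the exact helper vLLM uses:

import torch

def cutlass_block_scaled_grouped_gemm_available() -> bool:
    # Assumed gate: a CUDA device with compute capability 10.x (SM100 / Blackwell).
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 10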

Comment on lines 1190 to 1201
elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)):
assert apply_router_weight_on_input is False
return run_cutlass_block_scaled_fused_experts(
a=hidden_states,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
topk_weights=topk_weights,
topk_ids=topk_ids
)
Contributor:

Would it make sense for the usage of CUTLASS to be decided on the quantized method level? For example, the same way CUTLASS MoE is picked in CompressedTensorsMoEMethod's get_moe_method() function

Contributor:

I agree with @ElizaWszola. I think it would be better to move the dispatching decisions to the MoEMethod level.

Contributor Author:

This PR from @ElizaWszola already does the integration; I personally don't mind waiting for that PR to be merged and then updating this PR in a similar fashion.
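
To make the reviewers' suggestion concrete, here is a hypothetical sketch of quant-method-level dispatch in the spirit of CompressedTensorsMoEMethod's get_moe_method(); the class names and selection criteria below are illustrative only, not vLLM's actual API:

import torch

class TritonBlockScaledMoEMethod:      # hypothetical fallback implementation
    pass

class CutlassBlockScaledMoEMethod:     # hypothetical SM100 CUTLASS-backed implementation
    pass

def get_moe_method(use_fp8_w8a8: bool):
    # Decide once when the MoE method is constructed, instead of branching
    # inside fused_experts() on every call.
    major, _minor = torch.cuda.get_device_capability()
    if use_fp8_w8a8 and major == 10:
        return CutlassBlockScaledMoEMethod()
    return TritonBlockScaledMoEMethod()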

Contributor:

Can you also add a test for the full fused MoE operation?

@tlrmchlsmth (Collaborator) left a comment

Are there any benchmarks for these kernels that you could report @djmmoss?

typename LayoutSFA,
typename LayoutSFB,
typename ScaleConfig>
__global__ void get_ggemm_starts(
Collaborator:

Does this do the same thing as __get_group_gemm_starts_blockscale_fp8 in #19983? Checking to see what we can consolidate between the two PRs

static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

using ArchTag = cutlass::arch::Sm100;
Collaborator:

Could we make ArchTag a template parameter? And then reuse this class for both SM90 and SM100?

@@ -393,6 +393,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
{stride_tag});
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);

// cutlass blockwise scaledgroup GEMM
Collaborator:

Suggested change
// cutlass blockwise scaledgroup GEMM
// cutlass blockwise scaled group GEMM

@djmmoss (Contributor Author) commented Jun 25, 2025

@tlrmchlsmth I had a look over #19983; I could likely integrate the SM100 changes fairly simply if the CompressedTensorsMoEMethod approach is preferred to the DeepGEMM-style integration. (I personally have zero preference here 👍)

In regard to performance, I'm mainly looking at DeepSeek-R1 on GB200. For a single node (TP4) you can expect roughly a 1.6x speedup in max throughput and a ~20% improvement in min latency. For two nodes (TP8), the max-throughput improvement is around 1.4x, with the same ~20% improvement in min latency.

nv-dmoss added 10 commits June 25, 2025 10:05
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
@djmmoss djmmoss force-pushed the dmoss/blockscaled_cutlass_group_gemm branch from 1b21ec5 to 4adbbaa Compare June 25, 2025 17:06
@mgoin (Member) commented Jun 26, 2025

FYI, I include moe_data for SM100 here: #20086

CMakeLists.txt Outdated
set(CUTLASS_REVISION "v3.9.2" CACHE STRING "CUTLASS revision to use")
set(CUTLASS_REVISION "f115c3f85467d5d9619119d1dbeb9c03c3d73864" CACHE STRING "CUTLASS revision to use")
Collaborator:

This just needs CUTLASS 4.0 right? It would be nice to wait for the tag -- is there an estimated date for that?

Contributor Author:

The tag was added last Friday; I've made the update 👍

@mgoin mgoin added this to the v0.9.2 milestone Jul 1, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
@djmmoss (Contributor Author) commented Jul 1, 2025

I'm also running lm_eval and will post the results here. Regarding 1, 2, 3, and 4: these are addressed in this PR. I can help with the SM100 integration and testing when that one is ready to go.

CMakeLists.txt Outdated
@@ -296,6 +296,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
"csrc/quantization/cutlass_w8a8/moe/moe_data.cu"
Member:

I don't think this is right; have you seen #20086, where I pulled moe_data.cu into its own case?

Contributor Author:

Ah, I missed that. I've fixed it now, but there are some issues related to the CUTLASS v4.0.0 upgrade that I'm working through.


mergify bot commented Jul 2, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @djmmoss.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 2, 2025
djmmoss added 2 commits July 2, 2025 18:44
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
@mergify mergify bot removed the needs-rebase label Jul 2, 2025
@djmmoss (Contributor Author) commented Jul 2, 2025

lm_eval results:

vllm (pretrained=/scratch/models/DeepSeek-R1,tensor_parallel_size=4,max_model_len=2048,gpu_memory_utilization=0.95,max_num_seqs=32,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9530|±  |0.0058|
|     |       |strict-match    |     5|exact_match|↑  |0.9538|±  |0.0058|

djmmoss added 2 commits July 2, 2025 21:49
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 3, 2025
@mgoin (Member) commented Jul 3, 2025

Hey @djmmoss, I enabled the full CI and there are failures in the CUTLASS MoE entrypoint tests: https://buildkite.com/vllm/ci/builds/23192/steps/canvas?jid=0197d1b9-6de4-4bef-8c61-59fa059d4c44

Signed-off-by: Duncan Moss <djm.moss@gmail.com>

mergify bot commented Jul 3, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @djmmoss.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@djmmoss (Contributor Author) commented Jul 3, 2025

I've pushed up the fix.

@mergify mergify bot added the needs-rebase label Jul 3, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
@mergify mergify bot removed the needs-rebase label Jul 4, 2025
@mgoin mgoin merged commit 3d184b9 into vllm-project:main Jul 4, 2025
99 checks passed
@djmmoss djmmoss deleted the dmoss/blockscaled_cutlass_group_gemm branch July 4, 2025 19:05
sfeng33 pushed a commit to sfeng33/vllm that referenced this pull request Jul 6, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Co-authored-by: Duncan Moss <dmoss@nvidia.com>
huydhn pushed a commit to huydhn/vllm that referenced this pull request Jul 8, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Co-authored-by: Duncan Moss <dmoss@nvidia.com>
Chen-zexi pushed a commit to Chen-zexi/vllm that referenced this pull request Jul 13, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Co-authored-by: Duncan Moss <dmoss@nvidia.com>
patrickvonplaten pushed a commit to patrickvonplaten/vllm that referenced this pull request Jul 15, 2025
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Co-authored-by: Duncan Moss <dmoss@nvidia.com>
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>