Skip to content

Triton-fused DeepseekScalingRotaryEmbedding #19771

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

ekuznetsov139
Copy link
Contributor

This PR uses Triton to fuse DeepseekScalingRotaryEmbedding operation with is_neox_style=False (observed in DeepSeek-V3). It substantially reduces the number of distinct kernels in DeepSeek-V2/V3 profile (the number of kernel launches per execution of DeepseekV2MLAAttention.forward is reduced approximately from 37 to 25).

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @ekuznetsov139, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've implemented a significant performance optimization for the DeepseekScalingRotaryEmbedding operation, particularly for DeepSeek-V2/V3 models. By leveraging Triton to fuse this operation, specifically for non-NeoX style rotary embeddings, I aim to reduce the overhead associated with multiple kernel launches, thereby improving the overall efficiency of attention mechanisms.

Highlights

  • Performance Enhancement: I've implemented a Triton-fused kernel for DeepseekScalingRotaryEmbedding when is_neox_style is False, targeting DeepSeek-V3 models.
  • Kernel Reduction: This change achieves a substantial reduction in the number of distinct kernel launches, from approximately 37 to 25, during the execution of DeepseekV2MLAAttention.forward.
  • Triton Kernel Integration: I've introduced deepseek_scaling_rotary_emb_kernel_gptj, a new Triton kernel, to perform the fused rotary embedding computation for query and key tensors.
  • Conditional Execution: The Triton-fused path is conditionally applied only when running on a CUDA device and for non-NeoX style rotary embeddings, ensuring compatibility and falling back to the original PyTorch implementation otherwise.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@ekuznetsov139 ekuznetsov139 force-pushed the deepseek_triton_rotary_embedding branch from 3042f1e to 59cd1b8 Compare June 17, 2025 21:00
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This PR introduces a Triton-fused kernel for DeepseekScalingRotaryEmbedding to improve performance. The changes involve adding a new Triton kernel and modifying the forward pass. Key suggestions include adding tests, improving kernel clarity, and addressing production considerations.

Comment on lines 763 to 770
def deepseek_scaling_rotary_emb_kernel_gptj(cos_sin, q,
stride1: int,
stride2: int,
stride_cs: int,
dim1: int,
dim2: int,
dim3: int,
BLOCK_SIZE: tl.constexpr):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider adding a docstring to the deepseek_scaling_rotary_emb_kernel_gptj function to explain its purpose, arguments, and any specific implementation details. This will improve readability and maintainability.

dim2=q.shape[1],
dim3=q.shape[2]//2,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The num_warps parameter is hardcoded to 1. Consider benchmarking different values to find the optimal setting for various hardware configurations. Add a comment explaining the choice if num_warps=1 is indeed optimal.

@ekuznetsov139 ekuznetsov139 force-pushed the deepseek_triton_rotary_embedding branch 3 times, most recently from 1a65790 to 90059dc Compare June 20, 2025 13:43
Signed-off-by: Eugene Kuznetsov <eugene.kuznetsov@amd.com>
@ekuznetsov139 ekuznetsov139 force-pushed the deepseek_triton_rotary_embedding branch from 90059dc to 0c3e746 Compare June 20, 2025 14:47
@mergify mergify bot added the deepseek Related to DeepSeek models label Jul 2, 2025
@ekuznetsov139
Copy link
Contributor Author

Bump...

@simon-mo
Copy link
Collaborator

simon-mo commented Aug 8, 2025

can this just be done/generated by torch compile?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
deepseek Related to DeepSeek models
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants