
Conversation

@charlifu (Contributor) commented Sep 25, 2025

This PR adds fusion passes for AITER that fuse layernorm + FP8 block quant and SiLU + FP8 block quant.
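To make the fused operation concrete, here is a pure-Python sketch of the per-block (group) FP8 quantization that forms the second half of both fusions. The function name, the illustrative block size, and the use of plain lists instead of tensors are assumptions for readability, not vLLM's or AITER's actual API; real kernels operate on GPU tensors and cast to `float8_e4m3`.

```python
FP8_E4M3_MAX = 448.0  # max representable magnitude of float8_e4m3fn

def block_quant_fp8(row, block_size=4):
    """Quantize a 1-D list of floats per block: each block shares one scale.

    Returns (quantized_values, per_block_scales). Real kernels would cast
    the scaled values to float8; here we only scale, to show the math.
    """
    scales, quantized = [], []
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        amax = max(abs(v) for v in block) or 1e-12  # avoid divide-by-zero
        scale = amax / FP8_E4M3_MAX
        scales.append(scale)
        quantized.extend(v / scale for v in block)
    return quantized, scales
```

The fusion passes in this PR pattern-match a layernorm (or SiLU) followed by this kind of quantization and replace the pair with a single AITER kernel, avoiding one round trip through global memory.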

@gemini-code-assist bot left a comment


Code Review

This pull request adds support for fusing layernorm and SiLU operations with FP8 block quantization for AITER on ROCm. The changes introduce new fusion patterns for torch.compile and the corresponding custom operator implementations. They also extend AITER support to non-MI300 ROCm devices by providing a Triton-based fallback for GEMM operations.

My review found a critical issue in the type hints for the newly added custom operator implementations in vllm/model_executor/layers/layernorm.py. The residual parameter is incorrectly typed as torch.Tensor instead of Optional[torch.Tensor], which will lead to a TypeError during compilation and tracing. I've provided suggestions to fix this. The rest of the changes look good and are consistent with the goal of the pull request.

Comment on lines +116 to +117
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:

Severity: critical

The residual parameter is typed as torch.Tensor, but it can be None when called from the AiterRMSGroupQuantFP8Pattern fusion pass. This will cause a TypeError during compilation. Please change the type hint to Optional[torch.Tensor] to reflect that it can be None.

Suggested change
-    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
+    x: torch.Tensor, residual: Optional[torch.Tensor], weight: torch.Tensor,
     variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
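The reviewer's point generalizes: custom-op registration infers an argument schema from the annotations, so an argument annotated as a bare tensor type will reject `None` at trace time even if the function body handles it. The following torch-free sketch (hypothetical function names, `float` standing in for `torch.Tensor`) shows how a schema check driven by type hints distinguishes the two annotations.

```python
from typing import Optional, get_type_hints

def rms_norm_bad(x: float, residual: float) -> float:
    # Body tolerates None, but the annotation claims it never happens.
    return x + (residual or 0.0)

def rms_norm_good(x: float, residual: Optional[float]) -> float:
    return x + (residual or 0.0)

def accepts_none(fn, arg_name):
    """Mimic a schema check: does the annotation permit None?"""
    hint = get_type_hints(fn)[arg_name]
    # Optional[T] resolves to Union[T, None], whose __args__ contains NoneType.
    return type(None) in getattr(hint, "__args__", ())
```

A tracer performing the equivalent of `accepts_none` would raise a `TypeError` for the first signature whenever the fusion pass calls the op with `residual=None`, which is exactly the failure mode the review flags.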

Comment on lines +156 to +157
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:

Severity: critical

The residual parameter in this fake implementation is typed as torch.Tensor, but it can be None. This will cause a TypeError during fake-tensor tracing for the fusion pass. Please change the type hint to Optional[torch.Tensor].

Suggested change
-    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
+    x: torch.Tensor, residual: Optional[torch.Tensor], weight: torch.Tensor,
     variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:

haoyangli-amd and others added 21 commits September 25, 2025 16:16
…ect#24649)

Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com>
Co-authored-by: Haoyang Li <haoyang.li@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
…mark_serving_multi_turn) (vllm-project#23255)

Signed-off-by: daniels <daniels@pliops.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: rouchenzi <ruochenwen@gmail.com>
Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com>
Co-authored-by: Bowen Wang <abmfy@icloud.com>
Signed-off-by: charlifu <charlifu@amd.com>
…llm-project#24969)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: windsonsea <haifeng.yao@daocloud.io>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Signed-off-by: charlifu <charlifu@amd.com>
…tract_tool_call_required_streaming (vllm-project#24668)

Signed-off-by: Shijun Yin <shijun.yin@outlook.com>
Signed-off-by: charlifu <charlifu@amd.com>
…#25065)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: root <root@cw-dfw-h100-001-305-026.cm.cluster>
Signed-off-by: charlifu <charlifu@amd.com>
…vllm-project#25046)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Aidyn-A <aidyn.b.aitzhan@gmail.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Dylan Maloy <34420038+dolpm@users.noreply.github.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: charlifu <charlifu@amd.com>
…mentation. (vllm-project#24957)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Signed-off-by: charlifu <charlifu@amd.com>
…d warning. (vllm-project#25010)

Signed-off-by: samzong <samzong.lu@gmail.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: charlifu <charlifu@amd.com>
…project#24970)

Signed-off-by: samzong <samzong.lu@gmail.com>
Signed-off-by: charlifu <charlifu@amd.com>
@mergify bot added labels Sep 25, 2025: ci/build, deepseek (Related to DeepSeek models), frontend, llama (Related to Llama models), multi-modality (Related to multi-modality (#4194)), new-model (Requests to new models), performance (Performance-related issues), qwen (Related to Qwen models), gpt-oss (Related to GPT-OSS models), structured-output, speculative-decoding, v1
@mergify mergify bot added the tpu Related to Google TPUs label Sep 25, 2025

mergify bot commented Sep 25, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @charlifu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 25, 2025
@ProExpertProg (Collaborator) commented:
Thanks for submitting this, this is really exciting!

I'm currently overhauling custom op matching in #24604. We also recently added a torch implementation of group quant; could you compare its performance with AITER? Could you also compare the performance of the fused AITER kernel against the fused torch.compile kernel for rmsnorm + quant? Happy to help out with instructions, but overall:
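The comparison requested above could be scripted with a small timing harness. This is a hypothetical sketch: the candidate callables would be the real torch and AITER kernels on a ROCm machine, and for GPU kernels you would additionally need `torch.cuda.synchronize()` before reading the clock so queued work is actually finished when timed.

```python
import time

def bench(fn, *args, warmup=3, iters=10):
    """Return mean wall-clock seconds per call after a few warmup calls."""
    for _ in range(warmup):
        fn(*args)  # warmup: trigger compilation/caching outside the timed loop
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - start) / iters

# Hypothetical usage, comparing two stand-in implementations:
#   t_torch = bench(torch_group_quant, x)
#   t_aiter = bench(aiter_group_quant, x)
#   print(f"torch: {t_torch:.3e}s  aiter: {t_aiter:.3e}s")
```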

@charlifu charlifu closed this Sep 25, 2025
@github-project-automation github-project-automation bot moved this from To Triage to Done in gpt-oss Issues & Enhancements Sep 25, 2025
@charlifu charlifu deleted the amd/aiter_fusion_pass branch September 25, 2025 16:44
@charlifu (Contributor, Author) commented:

New PR #25693

@charlifu charlifu changed the title [Rocm][torch.compile] Adding layernorm + fp8 block quant and silu + fp8 block quant for Aiter New PR number #25693 Sep 25, 2025
Labels: ci/build, deepseek (Related to DeepSeek models), documentation (Improvements or additions to documentation), frontend, gpt-oss (Related to GPT-OSS models), kv-connector, llama (Related to Llama models), multi-modality (Related to multi-modality (#4194)), needs-rebase, new-model (Requests to new models), performance (Performance-related issues), qwen (Related to Qwen models), rocm (Related to AMD ROCm), speculative-decoding, structured-output, tool-calling, tpu (Related to Google TPUs), v1
Projects
Status: Done