
Conversation

@tjtanaa (Collaborator) commented Nov 11, 2025

Purpose

Bugfix for issue #28501.

PR #27165 introduced the fused_qknorm_rope_kernel. Since the kernel is not compiled on ROCm and the op lookup at import time is not guarded, vLLM fails with the following error:

(EngineCore_DP0 pid=1115)   File "/app/rocmvllm/fix-fastsafetensors/vllm/compilation/decorators.py", line 293, in __init__                         
(EngineCore_DP0 pid=1115)     TorchCompileWrapperWithCustomDispatcher.__init__(                                                                    
(EngineCore_DP0 pid=1115)   File "/app/rocmvllm/fix-fastsafetensors/vllm/compilation/wrapper.py", line 42, in __init__                             
(EngineCore_DP0 pid=1115)     backend = vllm_config.compilation_config.init_backend(vllm_config)                                                   
(EngineCore_DP0 pid=1115)               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                   
(EngineCore_DP0 pid=1115)   File "/app/rocmvllm/fix-fastsafetensors/vllm/config/compilation.py", line 770, in init_backend                         
(EngineCore_DP0 pid=1115)     from vllm.compilation.backends import VllmBackend                                                                    
(EngineCore_DP0 pid=1115)   File "/app/rocmvllm/fix-fastsafetensors/vllm/compilation/backends.py", line 40, in <module>                            
(EngineCore_DP0 pid=1115)     from .pass_manager import PostGradPassManager                                                                        
(EngineCore_DP0 pid=1115)   File "/app/rocmvllm/fix-fastsafetensors/vllm/compilation/pass_manager.py", line 20, in <module>                        
(EngineCore_DP0 pid=1115)     from .qk_norm_rope_fusion import QKNormRoPEFusionPass                                                                
(EngineCore_DP0 pid=1115)   File "/app/rocmvllm/fix-fastsafetensors/vllm/compilation/qk_norm_rope_fusion.py", line 24, in <module>                 
(EngineCore_DP0 pid=1115)     FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default                                                           
(EngineCore_DP0 pid=1115)                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                   
(EngineCore_DP0 pid=1115)   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1364, in __getattr__                                
(EngineCore_DP0 pid=1115)     raise AttributeError(                                                                                                
(EngineCore_DP0 pid=1115) AttributeError: '_OpNamespace' '_C' object has no attribute 'fused_qk_norm_rope'  

The proposed bugfix addresses the compatibility issues and compiles the op for ROCm. It is currently validated and supported on MI300X.

The details of the fixes:

  1. Enable compilation for ROCm platform
  2. Support 64-bit warp masks (vs 32-bit on CUDA) for __shfl_xor_sync
  3. Enable BFloat16 support for ROCm 7.0+

This fix is compatible with ROCm 7 and above only. A minimal sketch of items 2 and 3 is shown below.
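
To make items 2 and 3 concrete, here is a minimal sketch, assuming vLLM's usual USE_ROCM define and that HIP_VERSION_MAJOR tracks the ROCm major version; FINAL_MASK and warpReduceSum mirror names from the kernel, while VLLM_BF16_SUPPORTED is an illustrative stand-in rather than the merged code:

    #ifdef USE_ROCM
      #include <hip/hip_version.h>  // defines HIP_VERSION_MAJOR
      // Item 2: a ROCm wavefront has 64 lanes, so the full shuffle mask is a
      // 64-bit value rather than CUDA's 32-bit 0xffffffff.
      #define FINAL_MASK 0xffffffffffffffffull
      // Item 3: gate bf16 instantiations on ROCm 7.0+ (illustrative gate only).
      #define VLLM_BF16_SUPPORTED (HIP_VERSION_MAJOR >= 7)
    #else
      // A CUDA warp has 32 lanes; a 32-bit mask covers every lane.
      #define FINAL_MASK 0xffffffffu
      #define VLLM_BF16_SUPPORTED 1
    #endif

    template <typename T>
    __inline__ __device__ T warpReduceSum(T val) {
      // Butterfly reduction over a 32-lane logical warp; the explicit width of
      // 32 keeps the shuffle confined to 32 lanes even on a 64-lane wavefront.
    #pragma unroll
      for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(FINAL_MASK, val, offset, 32);
      }
      return val;
    }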

Test Plan

Pass all of the unit tests related to this op:

  1. tests/kernels/core/test_fused_qk_norm_rope.py
  2. tests/compile/test_qk_norm_rope_fusion.py

Serving command used to exercise the fusion pass:

#!/bin/bash
rm -rf /root/.cache/vllm/
echo "Qwen/Qwen3-30B-A3B-FP8"
export VLLM_RPC_TIMEOUT=1800000
export SAFETENSORS_FAST_GPU=1
export MODEL_PATH=Qwen/Qwen3-30B-A3B-FP8
vllm serve $MODEL_PATH \
-tp 2 \
--trust-remote-code \
--disable-log-requests \
--compilation-config '{"pass_config": {"enable_qk_norm_rope_fusion": 1, "enable_noop": true}, "custom_ops": ["+rms_norm", "+rotary_embedding"]}'

Op replacement log:

(Worker_TP1 pid=135636) DEBUG 11-12 01:36:15 [compilation/qk_norm_rope_fusion.py:235] Fused QK Norm+RoPE on 1 sites

Test Result

All tests passed:

  1. tests/kernels/core/test_fused_qk_norm_rope.py 8/8 tests passed
  2. tests/compile/test_qk_norm_rope_fusion.py 16/16 tests passed

Quality Metrics (GSM8K)

| Metric | Without Fusion Pass | With Fusion Pass |
| --- | --- | --- |
| Flexible-extract exact_match | 0.8340 ± 0.0102 | 0.8279 ± 0.0104 |
| Strict-match exact_match | 0.8886 ± 0.0087 | 0.8863 ± 0.0087 |

Performance Metrics

| Metric | Without Fusion Pass | With Fusion Pass | Change |
| --- | --- | --- | --- |
| Benchmark duration (s) | 38.87 | 38.60 | -0.27 |
| Request throughput (req/s) | 0.82 | 0.83 | +0.01 |
| Output token throughput (tok/s) | 843.04 | 848.90 | +5.86 |
| Peak output token throughput (tok/s) | 896.00 | 904.00 | +8.00 |
| Total token throughput (tok/s) | 1686.09 | 1697.81 | +11.72 |

Latency Metrics

| Metric | Without Fusion Pass | With Fusion Pass | Change |
| --- | --- | --- | --- |
| Mean TTFT (ms) | 324.17 | 336.16 | +11.99 |
| Median TTFT (ms) | 356.31 | 360.60 | +4.29 |
| P99 TTFT (ms) | 390.08 | 388.64 | -1.44 |
| Mean TPOT (ms) | 9.16 | 9.09 | -0.07 |
| Median TPOT (ms) | 9.11 | 9.04 | -0.07 |
| P99 TPOT (ms) | 9.41 | 9.36 | -0.05 |
| Mean ITL (ms) | 9.19 | 9.11 | -0.08 |
| Median ITL (ms) | 9.09 | 9.00 | -0.09 |
| P99 ITL (ms) | 9.40 | 9.38 | -0.02 |


Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
mergify bot added the "rocm" (Related to AMD ROCm) label Nov 11, 2025
@gemini-code-assist bot left a comment

Code Review

This pull request correctly enables support for the fused_qknorm_rope_kernel on the ROCm platform. The changes are well-targeted and address the import errors and compatibility issues. Key changes include enabling the kernel compilation for ROCm, adding support for 64-bit warp masks required by __shfl_xor_sync on ROCm, and enabling BFloat16 data type support. The corresponding test files have also been updated to run on ROCm, ensuring the changes are validated. The implementation appears solid and I don't see any issues. Good work.

@tjtanaa tjtanaa marked this pull request as draft November 12, 2025 00:00
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@tjtanaa tjtanaa marked this pull request as ready for review November 12, 2025 01:54
@tjtanaa tjtanaa requested a review from gshtras November 12, 2025 01:54
@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 251 to 253
  for (int i = 0; i < numElemsPerThread; i++) {
-   elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
+   elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
    if (laneId < 16) {


P1: Limit warp shuffle width to 32 on ROCm

When compiling for ROCm, __shfl_xor_sync defaults to a width equal to the hardware wavefront (64 lanes). The kernel logic assumes 32‑lane groups – warpsPerBlock, laneId, and the head layout all treat 32 threads as one warp – so the shuffle in the non‑interleaved branch must also be restricted to 32 lanes. Leaving it at the default 64 mixes data between two logical warps on AMD GPUs, corrupting the RoPE transformation for Neox models. Pass an explicit width of 32 (as is done in warpReduceSum) to keep the shuffle confined to the intended half‑warp.
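
A minimal sketch of the suggested change, reusing FINAL_MASK from the sketch in the PR description above; elements, elements2, and numElemsPerThread stand in for the kernel's variables, and whether this exact form was adopted is not shown here:

    // Pass an explicit width so the XOR shuffle stays inside a 32-lane logical
    // warp; a 64-lane ROCm wavefront then no longer mixes data across two
    // logical warps in the non-interleaved (Neox) branch.
    for (int i = 0; i < numElemsPerThread; i++) {
      elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16, /*width=*/32);
    }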


@izhuhaoran (Contributor) commented

Thanks for this fix, also cc @ProExpertProg

@ProExpertProg (Collaborator) left a comment

Thanks for the quick fix!

@ProExpertProg ProExpertProg enabled auto-merge (squash) November 12, 2025 02:28
github-actions bot added the "ready" (ONLY add when PR is ready to merge/full CI is needed) label Nov 12, 2025
@DarkLight1337 (Member) commented

The LoRA TP test is failing on latest nightly https://buildkite.com/vllm/ci/builds/38575/steps/canvas?jid=019a7670-4fa7-4d31-b8f5-f1e1a4e30296

Force-merging now

@vllm-bot vllm-bot merged commit edb59a9 into vllm-project:main Nov 12, 2025
90 of 92 checks passed
@JartX (Contributor) commented Nov 12, 2025

@tjtanaa @DarkLight1337 Also fix RDNA3

@tjtanaa tjtanaa deleted the bugfix-fused_qknorm_rope branch November 12, 2025 17:40
@liuzijing2014 (Collaborator) commented Nov 13, 2025

Hi @tjtanaa, we are getting issues when running with ROCm < 7 (6.4.2-120); vLLM fails with the following errors when compiling the fused kernel:

    qknorm_rope_kernel.hip:248:18: error: too few arguments to function call, expected at least 1, have 0
      248 |       __syncwarp();
          |       ~~~~~~~~~~ ^
    buck-out/v2/gen/fbcode/73b675c30aa38b90/vllm/trunk/__vllm_cpp_lib_hipify_gen__/out/csrc/fused_qknorm_rope_kernel.hip:268:7: error: use of undeclared identifier '__syncwarp'; did you mean '__sync_swap'?
      268 |       __syncwarp();
          |       ^~~~~~~~~~
          |       __sync_swap
    buck-out/v2/gen/fbcode/73b675c30aa38b90/vllm/trunk/__vllm_cpp_lib_hipify_gen__/out/csrc/fused_qknorm_rope_kernel.hip:248:7: note: '__sync_swap' declared here
      248 |       __syncwarp();
          |       ^
    buck-out/v2/gen/fbcode/73b675c30aa38b90/vllm/trunk/__vllm_cpp_lib_hipify_gen__/out/csrc/fused_qknorm_rope_kernel.hip:268:18: error: too few arguments to function call, expected at least 1, have 0
      268 |       __syncwarp();
          |       ~~~~~~~~~~ ^
    4 errors generated when compiling for gfx942.

Do you have a suggestion for how we could solve this? For example, using a macro to check whether we are on ROCm < 7.0, if that's possible.

cc @houseroad @842974287

@SageMoore (Contributor) commented

I'm seeing the same issue. It should be fixed in this PR #28682
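
For reference, one possible shape of the version guard asked about above, as a hypothetical sketch only; the macro name VLLM_SYNCWARP and the wave-barrier fallback are assumptions, not the change that landed in #28682:

    #ifdef USE_ROCM
      #include <hip/hip_version.h>  // defines HIP_VERSION_MAJOR
      #if HIP_VERSION_MAJOR < 7
        // Older HIP headers do not provide an argument-less __syncwarp(). A
        // CDNA wavefront executes in lockstep, so a compiler-level wave barrier
        // stands in here (an assumption, not a vetted replacement).
        #define VLLM_SYNCWARP() __builtin_amdgcn_wave_barrier()
      #else
        #define VLLM_SYNCWARP() __syncwarp()
      #endif
    #else
      #define VLLM_SYNCWARP() __syncwarp()
    #endif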

geodavic pushed a commit to geodavic/vllm that referenced this pull request Nov 16, 2025
…lm-project#28500)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: George D. Torres <gdavtor@gmail.com>
bwasti pushed a commit to bwasti/vllm that referenced this pull request Nov 17, 2025
…lm-project#28500)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: Bram Wasti <bwasti@meta.com>