[ROCm] [Bugfix] Fix fused_qknorm_rope_kernel rocm compatibility
#28500
Conversation
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Code Review
This pull request correctly enables support for the fused_qknorm_rope_kernel on the ROCm platform. The changes are well-targeted and address the import errors and compatibility issues. Key changes include enabling the kernel compilation for ROCm, adding support for 64-bit warp masks required by __shfl_xor_sync on ROCm, and enabling BFloat16 data type support. The corresponding test files have also been updated to run on ROCm, ensuring the changes are validated. The implementation appears solid and I don't see any issues. Good work.
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
  for (int i = 0; i < numElemsPerThread; i++) {
-   elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
+   elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
    if (laneId < 16) {
Limit warp shuffle width to 32 on ROCm
When compiling for ROCm, __shfl_xor_sync defaults to a width equal to the hardware wavefront (64 lanes). The kernel logic assumes 32‑lane groups – warpsPerBlock, laneId, and the head layout all treat 32 threads as one warp – so the shuffle in the non‑interleaved branch must also be restricted to 32 lanes. Leaving it at the default 64 mixes data between two logical warps on AMD GPUs, corrupting the RoPE transformation for Neox models. Pass an explicit width of 32 (as is done in warpReduceSum) to keep the shuffle confined to the intended half‑warp.
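A minimal sketch of what the suggested change could look like, assuming a USE_ROCM-style build flag; FULL_MASK and the helper name are placeholders standing in for the kernel's own identifiers, not the PR's actual code:

```cpp
// Illustrative only: pass an explicit width of 32 so the XOR shuffle stays inside
// one 32-lane logical warp, even though a ROCm wavefront is 64 lanes wide.
#ifdef USE_ROCM
static constexpr unsigned long long FULL_MASK = 0xffffffffffffffffull;  // 64-bit mask on ROCm
#else
static constexpr unsigned FULL_MASK = 0xffffffffu;  // 32-bit mask on CUDA
#endif

__device__ inline float shfl_xor_half_warp(float val) {
  // laneMask = 16 pairs lane i with lane i^16; width = 32 confines the exchange
  // to the logical warp, matching what warpReduceSum already does.
  return __shfl_xor_sync(FULL_MASK, val, /*laneMask=*/16, /*width=*/32);
}
```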
Thanks for this fix, also cc @ProExpertProg
ProExpertProg left a comment
Thanks for the quick fix!
The LoRA TP test is failing on the latest nightly (https://buildkite.com/vllm/ci/builds/38575/steps/canvas?jid=019a7670-4fa7-4d31-b8f5-f1e1a4e30296). Force-merging now.
@tjtanaa @DarkLight1337 Also fix RDNA3
Hi @tjtanaa, we are getting issues when running with ROCm < 7 (6.4.2-120): vLLM fails with an error when compiling the fused kernel. Do you have a suggestion for how we could solve this? For example, using a macro to check whether we are on ROCm < 7.0, if that's possible.
I'm seeing the same issue. It should be fixed by PR #28682.
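For illustration, a version guard along the lines the question suggests could look roughly like the sketch below. Assumptions: a USE_ROCM-style build flag and the HIP_VERSION_MAJOR macro from <hip/hip_version.h>; the FUSED_QKNORM_ROPE_SUPPORTED name is hypothetical, and this is not necessarily how #28682 resolves it:

```cpp
// Hypothetical sketch: only compile the fused kernel's __shfl_xor_sync path when
// building against ROCm/HIP 7 or newer, so ROCm 6.x builds skip it instead of
// failing to compile. HIP_VERSION_MAJOR is provided by <hip/hip_version.h>.
#ifdef USE_ROCM
  #include <hip/hip_version.h>
  #if HIP_VERSION_MAJOR >= 7
    #define FUSED_QKNORM_ROPE_SUPPORTED 1
  #else
    #define FUSED_QKNORM_ROPE_SUPPORTED 0
  #endif
#else
  // CUDA builds are unaffected by the ROCm version check.
  #define FUSED_QKNORM_ROPE_SUPPORTED 1
#endif
```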
…lm-project#28500) Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: George D. Torres <gdavtor@gmail.com>
…lm-project#28500) Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: Bram Wasti <bwasti@meta.com>
Purpose
Bugfix for issue #28501.
PR #27165 introduced the fused_qknorm_rope_kernel kernel. Since it is not compiled on ROCm and the import statement is not handled gracefully, it throws an import error. The proposed bugfix fixes the compatibility issues and compiles the op for ROCm. It is currently validated and supported on MI300X.
The details of the fixes:
Use a 64-bit warp mask with __shfl_xor_sync on ROCm. This fix is compatible with ROCm 7 and above only.
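A rough sketch of the 64-bit mask idea described above, assuming a USE_ROCM-style guard; the identifiers are placeholders and the actual kernel code may differ:

```cpp
// Sketch: on ROCm the full-wavefront mask for __shfl_xor_sync is 64 bits wide,
// so the mask type and constant change, while an explicit width of 32 keeps each
// reduction confined to a 32-lane logical warp.
#ifdef USE_ROCM
using WarpMask = unsigned long long;
static constexpr WarpMask kFullMask = 0xffffffffffffffffull;
#else
using WarpMask = unsigned;
static constexpr WarpMask kFullMask = 0xffffffffu;
#endif

template <typename T>
__device__ inline T warpReduceSum32(T val) {
  // Butterfly (XOR) reduction over a 32-lane logical warp.
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(kFullMask, val, offset, /*width=*/32);
  }
  return val;
}
```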
Test Plan
Pass all of the unit tests that are related to these ops:
- tests/kernels/core/test_fused_qk_norm_rope.py
- tests/compile/test_qk_norm_rope_fusion.py

Command for the fusion pass:
Op replacement log:
(Worker_TP1 pid=135636) DEBUG 11-12 01:36:15 [compilation/qk_norm_rope_fusion.py:235] Fused QK Norm+RoPE on 1 sites

Test Result
All tests passed:
- tests/kernels/core/test_fused_qk_norm_rope.py: 8/8 tests passed
- tests/compile/test_qk_norm_rope_fusion.py: 16/16 tests passed

Quality Metrics (GSM8K)
Performance Metrics