-
-
Notifications
You must be signed in to change notification settings - Fork 11.5k
[PERF] [Qwen3-next] Speed up gated RMSNorm #26207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a significant performance optimization for Gated RMSNorm by improving parallelism in the Triton kernel, especially for smaller hidden sizes. The new implementation processes multiple rows per thread block, which is a solid approach to increase GPU utilization. The benchmark results clearly demonstrate the effectiveness of this change, with substantial speedups for larger sequence lengths. The addition of a comprehensive test suite in tests/kernels/test_fla_layernorm_guard.py is excellent and ensures the correctness of this complex kernel across various configurations.
I have one suggestion to improve the thread-safety of the newly introduced device property caching mechanism. Overall, this is a great contribution that improves performance and is well-tested.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
|
@sighingnow is on vacation, I pinged some fla folks for review. |
yewentao256
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the work!
Could you also add report for accuracy lm_eval and e2e performance vllm bench ?
added result of |
| sm_count = _get_sm_count(device) | ||
| rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count)) | ||
| rows_per_block = min(rows_per_block, 4) | ||
| return rows_per_block |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rows_per_block could be huge when M is large, which may cause shared memory or register spilling issues. It would be better to set a maximum value for rows_per_block.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line
rows_per_block = min(rows_per_block, 4)
limit it with 4
|
This pull request has merge conflicts that must be resolved before it can be |
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
|
I'd like kindly remind about review of this PR |
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: 1994 <1994@users.noreply.github.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Dhruvil Bhatt <bhattdbh@amazon.com>
|
@vadiklyutiy Thanks for the effort! |
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: bbartels <benjamin@bartels.dev>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Purpose
Qwen3-next GDN attn contains a bit different version of RMSNorm - Gated RMSNorm (additionally multiply by
SiLU(y)).It is implemented in triton and previous implementation didn't extract enough parallelism due to small hidden size (128).
Add processing of several row in one block thread to speed up.
Test Plan
Added a lot of unit tests in
tests/kernels/test_fla_layernorm_guard.pyFunctional Test Result
lm_eval
Performance Test Result
Unit tests
B200
Compare of
layer_norm_fwdon input correspond to Qwen3-next with tp=4 with different sequence length.Script for benchmarking
E2E
B200