Vectorize RMS norm variance using vectorize_read_with_alignment #26234
Conversation
Code Review
This pull request introduces vectorization for the variance calculation in RMS norm kernels, which is a great performance improvement. The changes in rms_norm_kernel and rms_norm_static_fp8_quant_kernel replace the manual loop with vectorize_read_with_alignment.
My review includes suggestions to further optimize the vectorized operation by using packed data types and conversions for half and bfloat16 inputs. This can provide an additional performance boost by leveraging more efficient hardware instructions, aligning with the goal of this PR.
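A minimal sketch of the packed-conversion idea being suggested (my illustration, not code from this PR): for half input, two elements can be loaded as one __half2 and converted to float2 with a single packed instruction before accumulating into the variance.

```cuda
#include <cuda_fp16.h>

// Sketch: per-thread sum of squares over a row of half values, two
// elements at a time via __half2 loads and __half22float2 conversions.
// Assumes hidden_size is even and the row pointer is 4-byte aligned.
__device__ float row_sum_squares_half2(const __half* row, int hidden_size) {
  float acc = 0.f;
  const __half2* row2 = reinterpret_cast<const __half2*>(row);
  for (int i = threadIdx.x; i < hidden_size / 2; i += blockDim.x) {
    float2 v = __half22float2(row2[i]);  // one packed conversion, two lanes
    acc += v.x * v.x + v.y * v.y;
  }
  return acc;  // caller still needs a block-level reduction across threads
}
```

The same pattern applies to bfloat16 via __nv_bfloat162 and its packed-to-float2 conversion.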
@yewentao256 Thanks for your patience. In this PR, I'm only vectorizing the RMS norm variance computation using vectorize_read_with_alignment. It seems that to unlock more perf from vectorization on the write path, we'll need to update the util to support passing the index used to access the weights. Could you check whether that sounds accurate? (vllm/csrc/layernorm_kernels.cu, line 39 at 119f006)
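For context, here is a self-contained sketch of the read pattern such an alignment-aware helper follows (names, signature, and alignment math are mine for illustration; this is not vLLM's actual util). Note that the per-element callback receives only the value, not its index, which is why the write path, where each element also needs its offset into the weight vector, would require extending the util as described above.

```cuda
#include <cstdint>

// Vector of N scalars with the alignment needed for one wide load.
template <typename scalar_t, int N>
struct alignas(sizeof(scalar_t) * N) VecN {
  scalar_t val[N];
};

// Sketch: apply scalar_op to every element of in[0..len), using
// VEC_SIZE-wide loads for the aligned bulk and scalar loads for the
// misaligned head and the tail remainder.
template <int VEC_SIZE, typename scalar_t, typename Op>
__device__ void vectorized_read_sketch(const scalar_t* in, int len,
                                       int tid, int stride, Op scalar_op) {
  using vec_t = VecN<scalar_t, VEC_SIZE>;
  const uintptr_t addr = reinterpret_cast<uintptr_t>(in);
  // Scalar elements before the first vec_t-aligned address.
  int head = (addr % sizeof(vec_t) == 0)
                 ? 0
                 : (sizeof(vec_t) - addr % sizeof(vec_t)) / sizeof(scalar_t);
  if (head > len) head = len;
  for (int i = tid; i < head; i += stride) scalar_op(in[i]);

  // Aligned bulk, VEC_SIZE elements per load.
  const int num_vecs = (len - head) / VEC_SIZE;
  const vec_t* vin = reinterpret_cast<const vec_t*>(in + head);
  for (int i = tid; i < num_vecs; i += stride) {
    vec_t v = vin[i];
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) scalar_op(v.val[j]);
  }

  // Tail remainder handled element-wise.
  for (int i = head + num_vecs * VEC_SIZE + tid; i < len; i += stride)
    scalar_op(in[i]);
}
```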
Why do we have a performance regression in many cases?
Thanks for the review! My understanding is that the slowdowns come from the overhead of the vectorized path (indirection, packing/unpacking). One way we could mitigate this is by gating the vectorized path to the cases where we see gains, such as large token counts and hidden sizes (T >= 1024 and H >= 8192); see the sketch below. Further benchmarking could be done to find a sweet spot. Open to any feedback you may have on this.
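A hypothetical host-side gate illustrating that heuristic (the thresholds come from the comment above; this dispatch was discussed but ultimately not adopted, and the kernel names are placeholders, not vLLM symbols):

```cuda
#include <algorithm>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Placeholder declarations standing in for the scalar and vectorized
// RMS norm kernel variants.
__global__ void rms_norm_scalar_kernel(const half* in, half* out,
                                       int hidden_size);
__global__ void rms_norm_vectorized_kernel(const half* in, half* out,
                                           int hidden_size);

void rms_norm_launch(const half* in, half* out, int num_tokens,
                     int hidden_size, cudaStream_t stream) {
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  // Heuristic from the discussion: only take the vectorized path on
  // shapes where the benchmarks showed a win.
  if (num_tokens >= 1024 && hidden_size >= 8192) {
    rms_norm_vectorized_kernel<<<grid, block, 0, stream>>>(in, out,
                                                           hidden_size);
  } else {
    rms_norm_scalar_kernel<<<grid, block, 0, stream>>>(in, out, hidden_size);
  }
}
```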
I don't think it's a good idea. Relying on data from just one machine seems risky.
@yewentao256 Would be great to get your take on this. Since the performance gains are small and there's pushback on heuristics, I'm leaning toward leaving it as is. I suspect vectorizing writes would show the same pattern. Also, adding a flag or dispatcher logic feels heavy for this. Is there an approach you'd prefer?
Thanks for the work! I don't think we want to introduce complexity here, so let's just keep it like this. The performance improvement looks good to me.
```diff
- for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-   const float x = (float)input[blockIdx.x * input_stride + idx];
+ constexpr int VEC_SIZE = 8;
```
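The replacement plausibly reads along these lines (a sketch reconstructed from the surrounding context, assuming the util takes a pointer, length, thread id, stride, and a per-element callback; not a verbatim quote of the PR):

```cuda
constexpr int VEC_SIZE = 8;
// Accumulate the sum of squares over the row with vectorized loads;
// the lambda is invoked once per element regardless of load width.
vectorize_read_with_alignment<VEC_SIZE>(
    input + blockIdx.x * input_stride, hidden_size, threadIdx.x, blockDim.x,
    [&variance](const scalar_t& val) {
      const float x = static_cast<float>(val);
      variance += x * x;
    });
```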
Just curious, have you tried different values of VEC_SIZE? This would affect perf as well.
@yewentao256 Update: I ran benchmarks for VEC_SIZE in {4, 8, 16}. I'm finding VEC_SIZE=8 performs best overall: it is stable on small shapes while providing speedups on large ones. In general, gains shrink for smaller hidden sizes. For example, here's what I see for H=8192 using half (see the timing sketch after the list):
• 1024 tokens → V4 17.1 us, V8 15.9 us, V16 16.0 us
• 4096 tokens → V4 70.2 us, V8 61.4 us, V16 59.9 us
• 8192 tokens → V4 132 us, V8 115 us, V16 112 us
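Microsecond numbers like these are typically collected by averaging many launches between CUDA events; a generic sketch of that approach (my illustration, not the benchmark script actually used for the numbers above):

```cuda
#include <cuda_runtime.h>

// Generic kernel-timing sketch with CUDA events: warm up, then average
// over many launches to get stable microsecond-level numbers.
template <typename LaunchFn>
float time_kernel_us(LaunchFn launch, int warmup = 10, int iters = 100) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  for (int i = 0; i < warmup; ++i) launch();  // exclude one-time costs
  cudaEventRecord(start);
  for (int i = 0; i < iters; ++i) launch();
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);
  float ms = 0.f;
  cudaEventElapsedTime(&ms, start, stop);  // total milliseconds
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms * 1000.f / iters;  // average microseconds per launch
}
```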
```diff
- const float x = (float)input[blockIdx.x * input_stride + idx];
+ const scalar_t* input_row = input + blockIdx.x * input_stride;

+ constexpr int VEC_SIZE = 8;
```
Same as above.
Appreciate the review! I've been low on bandwidth this week, but I plan to test different configurations this weekend and follow up.
LGTM, thanks for the work!
Thanks for the many reviews and guidance! Would you be open to me exploring vectorization on the write path next, or do you think there's a more valuable follow-up task to take on from here?
@bbeckca Could you please merge from main to fix the CI error?
Documentation preview: https://vllm--26234.org.readthedocs.build/en/26234/
Feel free to explore both; as long as we can get better performance and not hurt accuracy, we can land it.
@yewentao256 I'm encountering a variety of CI issues despite merging from main, and they don't seem related to my changes. Does this suggest I should wait for the issues to be fixed on main before retrying, or is there some way to bypass them?
No worries, we can force-merge with unrelated CI failures. Let's try again.
@bbeckca Could you take a look at buildkite/ci/pr/blackwell-test?

```
=========================== FAILURES ===========================
_____ test_all_reduce_fusion_pass_replace[dtype0-16-8-8-TestAllReduceRMSNormModel] _____
```

It fails with a pydantic validation error, perhaps related. The version error in quantization-test should be fixed in main later.
Sounds good. Took a quick look, and this seems unrelated.
Purpose
Uses vectorize_read_with_alignment for the variance computation in rms_norm_kernel and rms_norm_static_fp8_quant_kernel.
Test Result
Accuracy
Performance