
Conversation

@bbeckca (Contributor) commented Oct 4, 2025

Purpose

Uses `vectorize_read_with_alignment` for the variance computation in `rms_norm_kernel` and `rms_norm_static_fp8_quant_kernel`, replacing the manual per-element loop with vectorized reads.
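
For context, a minimal sketch of the change (the helper's signature here is paraphrased from vLLM's vectorization utils and may not match the source exactly):

```cuda
// Before: strided scalar loop, one element per iteration.
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
  const float x = (float)input[blockIdx.x * input_stride + idx];
  variance += x * x;
}

// After: the helper issues aligned vector loads and calls a scalar
// callback per element; VEC_SIZE = 8 was chosen by benchmarking (see the
// review thread below).
const scalar_t* input_row = input + blockIdx.x * input_stride;
constexpr int VEC_SIZE = 8;
vectorize_read_with_alignment<VEC_SIZE>(
    input_row, hidden_size, threadIdx.x, blockDim.x,
    [&variance](const scalar_t& val) {
      const float x = static_cast<float>(val);
      variance += x * x;
    });
```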

Test Result

Accuracy

```bash
lm_eval --model vllm \
  --model_args "pretrained=Qwen/Qwen3-30B-A3B-FP8,max_model_len=32768,enforce_eager=True" \
  --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
```

**Before**
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8279|±  |0.0104|
|     |       |strict-match    |     5|exact_match|↑  |0.8878|±  |0.0087|

**After**
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8332|±  |0.0103|
|     |       |strict-match    |     5|exact_match|↑  |0.8893|±  |0.0086|

pytest tests/kernels/core/test_layernorm.py

===================================================================== test session starts =====================================================================
platform linux -- Python 3.12.3, pytest-8.4.2, pluggy-1.6.0
rootdir: /home/bbeckca/projects/vllm
configfile: pyproject.toml
plugins: anyio-4.11.0
collected 1683 items

tests/kernels/core/test_layernorm.py .................................................................................................................. [  6%]
....................................................................................................................................................... [ 15%]
....................................................................................................................................................... [ 24%]
....................................................................................................................................................... [ 33%]
....................................................................................................................................................... [ 42%]
....................................................................................................................................................... [ 51%]
....................................................................................................................................................... [ 60%]
....................................................................................................................................................... [ 69%]
....................................................................................................................................................... [ 78%]
....................................................................................................................................................... [ 87%]
....................................................................................................................................................... [ 96%]
...........................................................                                                                                             [100%]

============================================================== 1683 passed in 634.19s (0:10:34) ==============================================================

Performance

python benchmark_layernorm.py
| Tokens | Hidden | Dtype | Old (µs) | New (µs) | Change (%) |
|-------:|-------:|:---------|---------:|---------:|-----------:|
| 2 | 1024 | half | 14.092 | 15.787 | −12.0% |
| 2 | 1024 | bfloat16 | 14.518 | 15.704 | −8.2% |
| 2 | 8192 | half | 15.483 | 14.775 | +4.6% |
| 2 | 8192 | bfloat16 | 15.930 | 14.078 | +11.6% |
| 4 | 1024 | half | 14.601 | 14.583 | +0.1% |
| 4 | 1024 | bfloat16 | 14.881 | 15.190 | −2.1% |
| 4 | 8192 | half | 14.475 | 14.858 | −2.6% |
| 4 | 8192 | bfloat16 | 14.041 | 14.879 | −6.0% |
| 8 | 1024 | half | 14.674 | 14.144 | +3.6% |
| 8 | 1024 | bfloat16 | 14.944 | 15.714 | −5.1% |
| 8 | 8192 | half | 15.319 | 14.383 | +6.1% |
| 8 | 8192 | bfloat16 | 16.122 | 15.605 | +3.2% |
| 16 | 1024 | half | 14.788 | 15.497 | −4.8% |
| 16 | 1024 | bfloat16 | 14.346 | 14.710 | −2.5% |
| 16 | 8192 | half | 15.171 | 14.085 | +7.2% |
| 16 | 8192 | bfloat16 | 14.729 | 14.249 | +3.3% |
| 32 | 1024 | half | 14.495 | 14.328 | +1.2% |
| 32 | 1024 | bfloat16 | 14.813 | 14.649 | +1.1% |
| 32 | 8192 | half | 14.603 | 14.368 | +1.6% |
| 32 | 8192 | bfloat16 | 14.548 | 14.749 | −1.4% |
| 64 | 1024 | half | 15.304 | 15.691 | −2.5% |
| 64 | 1024 | bfloat16 | 14.392 | 14.609 | −1.5% |
| 64 | 8192 | half | 14.185 | 15.292 | −7.8% |
| 64 | 8192 | bfloat16 | 14.863 | 14.487 | +2.5% |
| 128 | 1024 | half | 15.093 | 15.548 | −3.0% |
| 128 | 1024 | bfloat16 | 14.600 | 14.022 | +4.0% |
| 128 | 8192 | half | 14.338 | 14.435 | −0.7% |
| 128 | 8192 | bfloat16 | 15.735 | 14.616 | +7.1% |
| 256 | 1024 | half | 14.518 | 15.267 | −4.9% |
| 256 | 1024 | bfloat16 | 16.265 | 14.573 | +10.4% |
| 256 | 8192 | half | 15.462 | 14.795 | +4.3% |
| 256 | 8192 | bfloat16 | 14.432 | 16.133 | −11.8% |
| 512 | 1024 | half | 14.407 | 14.439 | −0.2% |
| 512 | 1024 | bfloat16 | 14.809 | 15.213 | −2.7% |
| 512 | 8192 | half | 13.960 | 14.962 | −7.2% |
| 512 | 8192 | bfloat16 | 14.323 | 15.148 | −5.8% |
| 1024 | 1024 | half | 14.935 | 14.009 | +6.2% |
| 1024 | 1024 | bfloat16 | 13.478 | 14.371 | −6.6% |
| 1024 | 8192 | half | 24.002 | 15.967 | +33.4% |
| 1024 | 8192 | bfloat16 | 23.922 | 16.127 | +32.6% |
| 2048 | 1024 | half | 13.805 | 14.504 | −5.1% |
| 2048 | 1024 | bfloat16 | 14.825 | 13.961 | +5.8% |
| 2048 | 8192 | half | 59.793 | 34.594 | +42.1% |
| 2048 | 8192 | bfloat16 | 59.774 | 35.895 | +39.9% |
| 4096 | 1024 | half | 23.589 | 24.318 | −3.0% |
| 4096 | 1024 | bfloat16 | 23.513 | 24.500 | −4.2% |
| 4096 | 8192 | half | 114.509 | 61.427 | +46.4% |
| 4096 | 8192 | bfloat16 | 114.506 | 63.626 | +44.5% |
| 8192 | 1024 | half | 46.788 | 48.736 | −4.1% |
| 8192 | 1024 | bfloat16 | 46.761 | 48.425 | −3.6% |
| 8192 | 8192 | half | 219.646 | 115.292 | +47.5% |
| 8192 | 8192 | bfloat16 | 219.758 | 119.771 | +45.5% |
| 16384 | 1024 | half | 99.479 | 103.454 | −4.0% |
| 16384 | 1024 | bfloat16 | 99.552 | 104.517 | −5.0% |
| 16384 | 8192 | half | 430.831 | 224.629 | +47.8% |
| 16384 | 8192 | bfloat16 | 431.892 | 231.705 | +46.3% |

@gemini-code-assist bot left a comment

Code Review

This pull request introduces vectorization for the variance calculation in RMS norm kernels, which is a great performance improvement. The changes in rms_norm_kernel and rms_norm_static_fp8_quant_kernel replace the manual loop with vectorize_read_with_alignment.

My review includes suggestions to further optimize the vectorized operation by using packed data types and conversions for half and bfloat16 inputs. This can provide an additional performance boost by leveraging more efficient hardware instructions, aligning with the goal of this PR.
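
For example, a packed path for half inputs (illustrative only; the function below is this comment's invention, not code from the PR) could convert two elements per instruction:

```cuda
#include <cuda_fp16.h>

// Illustrative packed accumulation for half input: one __half2 load yields
// two elements, and __half22float2 converts both with a single instruction.
// Assumes hidden_size is even and the row pointer is 4-byte aligned.
__device__ inline float row_sum_squares_half2(const __half2* row, int n2,
                                              int tid, int stride) {
  float acc = 0.0f;
  for (int i = tid; i < n2; i += stride) {
    const float2 xy = __half22float2(row[i]);
    acc += xy.x * xy.x + xy.y * xy.y;
  }
  return acc;
}
```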

@bbeckca (Contributor, Author) commented Oct 4, 2025

@yewentao256 Thanks for your patience. In this PR, I'm only vectorizing the RMS norm variance using `vectorize_read_with_alignment`. It seems that to unlock more perf from vectorizing the write path, we'll need to update the util to support passing the index used to access the weights. Could you check whether that sounds accurate?

```
((scalar_t)(x * s_variance)) * weight[idx];
```

cc @ProExpertProg
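
For illustration, an index-aware variant of the util could look like the sketch below; the name `vectorize_write_with_alignment` and its callback signature are hypothetical, not existing vLLM API:

```cuda
// Hypothetical: like vectorize_read_with_alignment, but the callback also
// receives the element index, so the write path can look up weight[idx].
vectorize_write_with_alignment<VEC_SIZE>(
    out_row, hidden_size, threadIdx.x, blockDim.x,
    [&](scalar_t& dst, int idx) {
      const float x = (float)input_row[idx];
      dst = ((scalar_t)(x * s_variance)) * weight[idx];
    });
```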

@ZJY0516 (Contributor) commented Oct 5, 2025

Why do we have performance regressions in many cases?

@bbeckca (Contributor, Author) commented Oct 5, 2025

> Why do we have performance regressions in many cases?

Thanks for the review! My understanding is the slowdowns come from the overhead of calling the vectorized path (indirection, packing/unpacking).

One way we could mitigate this is to gate the vectorized path to the cases where we see gains, such as large token counts and hidden sizes (T >= 1024 and H >= 8192). Further benchmarking could find the sweet spot. Open to any feedback you may have on this.
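
A rough sketch of that gating (kernel names are hypothetical; the thresholds come from the single-machine numbers above):

```cuda
// Hypothetical host-side dispatch: fall back to the scalar kernel for
// small shapes where the vectorized path's overhead dominates.
const bool use_vectorized = num_tokens >= 1024 && hidden_size >= 8192;
if (use_vectorized) {
  rms_norm_kernel_vectorized<<<grid, block, 0, stream>>>(
      out, input, weight, epsilon, num_tokens, hidden_size);
} else {
  rms_norm_kernel_scalar<<<grid, block, 0, stream>>>(
      out, input, weight, epsilon, num_tokens, hidden_size);
}
```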

@ZJY0516 (Contributor) commented Oct 5, 2025

> Why do we have performance regressions in many cases?
>
> Thanks for the review! My understanding is the slowdowns come from the overhead of calling the vectorized path (indirection, packing/unpacking).
>
> One way we could mitigate this is to gate the vectorized path to the cases where we see gains, such as large token counts and hidden sizes (T >= 1024 and H >= 8192). Further benchmarking could find the sweet spot. Open to any feedback you may have on this.

I don't think it's a good idea. Relying on data from just one machine seems risky.

@bbeckca (Contributor, Author) commented Oct 5, 2025

> @yewentao256 Thanks for your patience. In this PR, I'm only vectorizing the RMS norm variance using `vectorize_read_with_alignment`. It seems that to unlock more perf from vectorizing the write path, we'll need to update the util to support passing the index used to access the weights. Could you check whether that sounds accurate?
>
> ```
> ((scalar_t)(x * s_variance)) * weight[idx];
> ```
>
> cc @ProExpertProg

@yewentao256 Would be great to get your take on this. Since the performance gains are small and there’s pushback on heuristics, I’m leaning toward leaving it as is. I feel vectorizing writes would likely show the same pattern. Also, adding a flag or dispatcher logic feels heavy for this. Curious if there's any approach you prefer?

@yewentao256 (Member) left a comment

Thanks for the work! I think we don't want to introduce complexity here, so let's keep it as is. The performance improvement looks good to me.


```diff
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * input_stride + idx];
+  constexpr int VEC_SIZE = 8;
```
@yewentao256 (Member) commented:

Just curious, have you tried with different VEC_SIZE? This would affect perf as well.

@bbeckca (Contributor, Author) replied:

@yewentao256 Update: Ran benchmarks for VEC_SIZE ∈ {4, 8, 16}. I'm finding VEC_SIZE=8 performs best overall: stable on small shapes while providing speedups on large ones. In general, gains shrink at smaller hidden sizes.

For example, here's what I see for H=8192 using half:

- 1024 tokens → V4 17.1 µs, V8 15.9 µs, V16 16.0 µs
- 4096 tokens → V4 70.2 µs, V8 61.4 µs, V16 59.9 µs
- 8192 tokens → V4 132 µs, V8 115 µs, V16 112 µs
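
For reference, what the knob controls (an illustrative sketch of the mechanism, not the util's actual implementation; the real helper also handles misaligned heads and tails):

```cuda
// VEC_SIZE scalars are grouped into one aligned load. For half (2 bytes),
// VEC_SIZE = 8 makes each load 16 bytes, i.e. a single 128-bit transaction.
template <typename T, int VEC_SIZE>
struct alignas(sizeof(T) * VEC_SIZE) vec_t {
  T data[VEC_SIZE];
};

template <int VEC_SIZE, typename T, typename Op>
__device__ void read_vectorized(const T* in, int len, int tid, int stride,
                                Op op) {
  const int num_vec = len / VEC_SIZE;
  const auto* vin = reinterpret_cast<const vec_t<T, VEC_SIZE>*>(in);
  for (int i = tid; i < num_vec; i += stride) {
    const vec_t<T, VEC_SIZE> v = vin[i];  // one wide load
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) op(v.data[j]);
  }
  // Scalar loop for the remainder that doesn't fill a full vector.
  for (int i = num_vec * VEC_SIZE + tid; i < len; i += stride) op(in[i]);
}
```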

```diff
-    const float x = (float)input[blockIdx.x * input_stride + idx];
+    const scalar_t* input_row = input + blockIdx.x * input_stride;
+
+    constexpr int VEC_SIZE = 8;
```
@yewentao256 (Member): Same as above.

@bbeckca (Contributor, Author) commented Oct 8, 2025

> Thanks for the work! I think we don't want to introduce complexity here, so let's keep it as is. The performance improvement looks good to me.

Appreciate the review! Been low on bandwidth this week, but plan to test different configurations this weekend and follow up.

@yewentao256 (Member) left a comment

LGTM, thanks for the work!

@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 12, 2025
@yewentao256 yewentao256 self-assigned this Oct 12, 2025
@bbeckca (Contributor, Author) commented Oct 12, 2025

> LGTM, thanks for the work!

Thanks for the many reviews and guidance! Would you be open to me exploring vectorization on the write path next, or do you think there's a more valuable follow-up task to take on from here?

@ZJY0516 (Contributor) commented Oct 13, 2025

@bbeckca Could you please merge from main to fix the CI error?

mergify bot commented Oct 13, 2025

Documentation preview: https://vllm--26234.org.readthedocs.build/en/26234/

@yewentao256 (Member) commented:

> Thanks for the many reviews and guidance! Would you be open to me exploring vectorization on the write path next, or do you think there's a more valuable follow-up task to take on from here?

Feel free to explore both; as long as we get better performance and don't hurt accuracy, we can land it.

@bbeckca (Contributor, Author) commented Oct 14, 2025

@yewentao256 I'm encountering a variety of CI issues despite merging from main. These don't seem related to my changes. Does this suggest I should wait for the issues to be fixed on main before retrying, or is there some way to bypass them?

@yewentao256 (Member) commented:

No worries, we can force-merge with unrelated CI failures; let's try again.

@yewentao256 (Member) commented:

@bbeckca Could you take a look at buildkite/ci/pr/blackwell-test?
Failure in

```
=========================== FAILURES ===========================
_____ test_all_reduce_fusion_pass_replace[dtype0-16-8-8-TestAllReduceRMSNormModel] _____
```

with a pydantic validation error, perhaps related.

The version error in quantization-test should be fixed in main later

@bbeckca (Contributor, Author) commented Oct 14, 2025

> @bbeckca Could you take a look at buildkite/ci/pr/blackwell-test? Failure in
>
> ```
> =========================== FAILURES ===========================
> _____ test_all_reduce_fusion_pass_replace[dtype0-16-8-8-TestAllReduceRMSNormModel] _____
> ```
>
> with a pydantic validation error, perhaps related.
>
> The version error in quantization-test should be fixed in main later

Sounds good. Took a quick look, this seems unrelated:

```
pydantic_core._pydantic_core.ValidationError: 1 validation error for ModelConfig
Value error, Invalid repository ID or local directory specified: 'nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e'.
```

@yewentao256 yewentao256 enabled auto-merge (squash) October 15, 2025 18:43
@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Oct 15, 2025
@vllm-bot vllm-bot merged commit 1f491aa into vllm-project:main Oct 15, 2025
84 of 86 checks passed
mandy-li pushed a commit to mandy-li/vllm that referenced this pull request Oct 16, 2025
…-project#26234)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
albertoperdomo2 pushed a commit to albertoperdomo2/vllm that referenced this pull request Oct 16, 2025
…-project#26234)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>

Labels

- documentation (Improvements or additions to documentation)
- frontend
- gpt-oss (Related to GPT-OSS models)
- performance (Performance-related issues)
- ready (ONLY add when PR is ready to merge/full CI is needed)
- tool-calling
- v1

Projects

Status: Done

5 participants