[ET-VK] Fused RMSNorm operator to fix fp16 overflow#18772
meta-codesync[bot] merged 3 commits into gh/SS-JIA/518/base
Conversation
Fused RMSNorm operator that performs squaring, mean, rsqrt, and weight scaling in a single shader dispatch. All accumulation is done in fp32 regardless of input dtype, preventing fp16 overflow when residual stream values exceed sqrt(65504) ≈ 256.

The Python reference impl (`rms_norm_impl`) must preserve the input dtype: PyTorch type promotion would otherwise produce fp32 output from fp16 inputs, and the FusePatternsPass re-trace would propagate that incorrect dtype through the graph.

Authored by Claude.

Differential Revision: [D99841211](https://our.internmc.facebook.com/intern/diff/D99841211/)
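The dtype-preserving, fp32-accumulating semantics described above can be mirrored in a few lines of NumPy. This is an illustrative sketch, not the PR's actual `rms_norm_impl` or its Vulkan shader; the function name, signature, and `eps` default are assumptions.

```python
import numpy as np

def rms_norm_ref(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Sketch of an RMSNorm reference that accumulates in fp32 but
    returns the input dtype (hypothetical, for illustration)."""
    # Upcast before squaring: fp16 overflows once |x| exceeds sqrt(65504) ~ 256.
    x32 = x.astype(np.float32)
    inv_rms = 1.0 / np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)
    # Scale in fp32, then cast back so the output dtype matches the input;
    # leaving the result in fp32 is exactly the type-promotion leak the
    # description warns about.
    return ((x32 * inv_rms) * weight.astype(np.float32)).astype(x.dtype)

# Residual-stream values of 300 would overflow a naive fp16 x*x (300**2 = 90000 > 65504),
# but the fp32 accumulation keeps the computation finite.
x = np.full((2, 8), 300.0, dtype=np.float16)
w = np.ones(8, dtype=np.float16)
out = rms_norm_ref(x, w)
```

The final `.astype(x.dtype)` is the dtype-preservation step: without it, downstream tracing would see fp32 outputs from fp16 inputs.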
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18772

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 unrelated failures.) As of commit 9864405 with merge base 4afd7f9. (BROKEN TRUNK: the listed jobs failed but were already failing on the merge base.) 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Merged 76a5a52 into gh/SS-JIA/518/base
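As a quick sanity check on the sqrt(65504) ≈ 256 threshold cited in the description, the boundary can be demonstrated with NumPy fp16 scalars (this exercises IEEE half-precision arithmetic generally, not the Vulkan shader itself):

```python
import numpy as np

# fp16's largest finite value is 65504, so x*x overflows once |x| > sqrt(65504) ~ 255.97.
assert np.finfo(np.float16).max == 65504.0

# 255 * 255 = 65025 still rounds to a finite fp16 value.
assert np.isfinite(np.float16(255) * np.float16(255))

with np.errstate(over="ignore"):
    # 256 * 256 = 65536 already rounds to inf in fp16.
    assert np.isinf(np.float16(256) * np.float16(256))
```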