[ONNX] RMS Norm #159377
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159377
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 340dc41 with merge base f4bfac1:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Where do we add the tests for op23 symbolic functions?
I added tests in small models e2e. I think we can extend the opinfo tests later.
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 1 check: pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge, unstable). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
In test: export with opset 23
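A minimal sketch of what such an export-with-opset-23 test might look like. The model and input shapes are taken from the decomposition shown later in this thread; it assumes a PyTorch build where torch.onnx.export accepts dynamo=True and opset_version=23, so treat it as illustrative rather than the test added in this PR:

    import torch

    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.rms_norm(x, [7, 3])

    # Illustrative only: export with the dynamo-based exporter targeting
    # opset 23, which provides the RMSNormalization op this PR lowers to.
    onnx_program = torch.onnx.export(
        Model(),
        (torch.rand(2, 5, 7, 3),),
        dynamo=True,
        opset_version=23,
    )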
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
test_rms_norm needs a tolerance of 5e-5
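A hedged sketch of how that tolerance might be passed through the assert_onnx_program helper this PR extends, reusing the onnx_program from the export sketch above. The import path and the atol/rtol keywords are assumptions; verify against the utility's actual signature:

    from torch.onnx._internal.exporter import _testing as onnx_testing

    # Assumed keywords; 5e-5 echoes the tolerance mentioned in this thread.
    onnx_testing.assert_onnx_program(onnx_program, atol=5e-5, rtol=1e-5)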
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 0 checks. Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
The default eps 1e-5 we use is too large. Reading the decomposed program, we see the eps PyTorch actually uses:

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[2, 5, 7, 3]"):
                # File: /Users/justinc/Documents/GitHub/pytorch/test/onnx/exporter/test_small_models_e2e.py:772 in forward, code: return torch.nn.functional.rms_norm(x, [7, 3])
                pow_1: "f32[2, 5, 7, 3]" = torch.ops.aten.pow.Tensor_Scalar(x, 2)
                mean: "f32[2, 5, 1, 1]" = torch.ops.aten.mean.dim(pow_1, [3, 2], True); pow_1 = None
                add: "f32[2, 5, 1, 1]" = torch.ops.aten.add.Scalar(mean, 1.1920928955078125e-07); mean = None
                rsqrt: "f32[2, 5, 1, 1]" = torch.ops.aten.rsqrt.default(add); add = None
                mul: "f32[2, 5, 7, 3]" = torch.ops.aten.mul.Tensor(x, rsqrt); rsqrt = None
                type_as: "f32[2, 5, 7, 3]" = torch.ops.aten.type_as.default(mul, x); mul = x = None
                return (type_as,)

    Graph signature:
        # inputs
        x: USER_INPUT

        # outputs
        type_as: USER_OUTPUT

    Range constraints: {}
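A quick check confirming the reasoning: the constant 1.1920928955078125e-07 in the add node is the float32 machine epsilon (2**-23), i.e. the dtype-dependent fallback PyTorch uses rather than the 1e-5 default:

    import torch

    # The eps in the decomposed graph matches the float32 machine epsilon.
    assert torch.finfo(torch.float32).eps == 2**-23 == 1.1920928955078125e-07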
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 1 check: pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge, unstable). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
- Implement rms norm using onnx RMSNormalization-23
- Use the correct eps for float32: https://github.com/pytorch/pytorch/blob/eaadd1282c8e66f37acf54f95668529831c95df7/aten/src/ATen/native/cuda/layer_norm_kernel.cu#L1844-L1866
- Created facility to run tests with the reference runtime by extending ONNXProgram and assert_onnx_program.

Fix #159257

Pull Request resolved: #159377
Approved by: https://github.com/titaiwangms
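For reference, a minimal eager-mode sketch of the computation being mapped to RMSNormalization-23, mirroring the decomposition shown earlier in this thread. This is an illustrative reproduction rather than the exporter's actual implementation; the dtype-dependent eps fallback is exactly the detail this PR fixes, and the 5e-5 tolerance echoes the comment above:

    import torch

    def rms_norm_reference(x, normalized_shape, weight=None, eps=None):
        # Mean of squares over the normalized dims, add eps, rsqrt, then
        # rescale -- the same sequence as the decomposed graph above.
        if eps is None:
            # Dtype-dependent fallback; for float32 this is 2**-23.
            eps = torch.finfo(x.dtype).eps
        dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
        out = x * torch.rsqrt(x.pow(2).mean(dim=dims, keepdim=True) + eps)
        return out if weight is None else out * weight

    # Shapes taken from the test case shown earlier.
    x = torch.rand(2, 5, 7, 3)
    torch.testing.assert_close(
        rms_norm_reference(x, [7, 3]),
        torch.nn.functional.rms_norm(x, [7, 3]),
        rtol=1e-5,
        atol=5e-5,
    )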