Conversation

@justinchuby justinchuby commented Jul 29, 2025

  • Implement rms norm using onnx RMSNormalization-23

  • Use the correct eps for float32

    auto acc_type = at::toAccumulateType(input.scalar_type(), /*is_cuda=*/true);
    double eps_val;
    if (acc_type == at::ScalarType::Float) {
      eps_val = eps.value_or(std::numeric_limits<float>::epsilon());
    } else {
      eps_val = eps.value_or(std::numeric_limits<double>::epsilon());
    }
    Tensor Y = at::native::empty_like(
        *X,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    Tensor rstd = at::empty({M}, X->options().dtype(acc_type));
    if (M > 0) {
      RmsNormKernelImpl(*X, *gamma, M, N, eps_val, &Y, &rstd);
    }
    const auto input_shape = input.sizes();
    const size_t axis = input.dim() - normalized_shape.size();


  • Created facility to run tests with the reference runtime by extending ONNXProgram and assert_onnx_program.

Fix #159257
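The eps fallback in the ATen snippet above (use the accumulate type's machine epsilon when no explicit eps is given) can be sketched in plain Python. This is an illustrative mirror of the logic, not exporter code; the function name and the boolean flag are hypothetical:

```python
import sys

# Machine epsilon of IEEE-754 single precision, i.e. what
# std::numeric_limits<float>::epsilon() returns: 2**-23.
FLOAT32_EPS = 2.0 ** -23  # 1.1920928955078125e-07
# Machine epsilon of double precision (std::numeric_limits<double>::epsilon()).
FLOAT64_EPS = sys.float_info.epsilon  # 2**-52

def default_rms_norm_eps(acc_is_float32, eps=None):
    """Mirror the ATen fallback: pick the accumulate type's machine
    epsilon when the caller does not pass an explicit eps."""
    if eps is not None:
        return eps
    return FLOAT32_EPS if acc_is_float32 else FLOAT64_EPS
```

Note that 2**-23 is exactly the 1.1920928955078125e-07 constant that appears in the decomposed program later in this thread.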

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Jul 29, 2025
pytorch-bot bot commented Jul 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159377

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 340dc41 with merge base f4bfac1:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@titaiwangms
Collaborator

Where do we add the tests for op23 symbolic functions?

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
@justinchuby justinchuby added module: onnx Related to torch.onnx topic: improvements topic category labels Jul 29, 2025
@justinchuby
Collaborator Author

Where do we add the tests for op23 symbolic functions?

I added tests in small models e2e. I think we can extend the opinfo tests later.

@justinchuby
Collaborator Author

@pytorchbot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 30, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge, unstable)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@justinchuby
Collaborator Author

In test: export with opset 23

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information, see the pytorch-bot wiki.

Signed-off-by: Justin Chu <justinchu@microsoft.com>
@justinchuby
Collaborator Author

test_rms_norm needs a tolerance of 5e-5
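A sketch of what the loosened tolerance means: an element-wise absolute comparison at 5e-5, with hypothetical numbers standing in for the eager and exported-ONNX outputs (the real test compares tensors, not Python lists):

```python
import math

def allclose(a, b, atol=5e-5):
    # Element-wise absolute-tolerance comparison.
    return all(math.isclose(x, y, rel_tol=0.0, abs_tol=atol) for x, y in zip(a, b))

# Hypothetical float32-level discrepancy between two backends:
eager = [0.848528, 1.131371]
onnx_out = [0.848533, 1.131352]
assert allclose(eager, onnx_out, atol=5e-5)
```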

Signed-off-by: Justin Chu <justinchu@microsoft.com>
@justinchuby
Collaborator Author

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 0 checks:

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information, see the pytorch-bot wiki.

Signed-off-by: Justin Chu <justinchu@microsoft.com>
@justinchuby
Collaborator Author

The default eps of 1e-5 that we use is too large. Reading the decomposed program, we can see the eps PyTorch actually uses:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[2, 5, 7, 3]"):
             # File: /Users/justinc/Documents/GitHub/pytorch/test/onnx/exporter/test_small_models_e2e.py:772 in forward, code: return torch.nn.functional.rms_norm(x, [7, 3])
            pow_1: "f32[2, 5, 7, 3]" = torch.ops.aten.pow.Tensor_Scalar(x, 2)
            mean: "f32[2, 5, 1, 1]" = torch.ops.aten.mean.dim(pow_1, [3, 2], True);  pow_1 = None
            add: "f32[2, 5, 1, 1]" = torch.ops.aten.add.Scalar(mean, 1.1920928955078125e-07);  mean = None
            rsqrt: "f32[2, 5, 1, 1]" = torch.ops.aten.rsqrt.default(add);  add = None
            mul: "f32[2, 5, 7, 3]" = torch.ops.aten.mul.Tensor(x, rsqrt);  rsqrt = None
            type_as: "f32[2, 5, 7, 3]" = torch.ops.aten.type_as.default(mul, x);  mul = x = None
            return (type_as,)
            
Graph signature: 
    # inputs
    x: USER_INPUT
    
    # outputs
    type_as: USER_OUTPUT
    
Range constraints: {}
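The decomposition above (square, mean over the normalized dims, add eps, rsqrt, scale) can be reproduced for a single normalized dimension in plain Python. A minimal sketch for intuition only; it flattens the normalized shape to one list and omits the optional weight:

```python
import math

def rms_norm_1d(x, eps=2.0 ** -23):
    # Mean of squares over the normalized dimension (aten::pow + aten::mean).
    mean_sq = sum(v * v for v in x) / len(x)
    # 1 / sqrt(mean + eps)  (aten::add + aten::rsqrt).
    rstd = 1.0 / math.sqrt(mean_sq + eps)
    # Scale the input (aten::mul).
    return [v * rstd for v in x]
```

For example, `rms_norm_1d([3.0, 4.0])` divides each element by sqrt(12.5 + eps), so the output has a root-mean-square of almost exactly 1.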

@justinchuby
Collaborator Author

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge, unstable)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@justinchuby justinchuby deleted the justinchu/rms branch July 30, 2025 20:50
@justinchuby justinchuby restored the justinchu/rms branch July 30, 2025 20:50
@justinchuby justinchuby deleted the justinchu/rms branch July 30, 2025 20:51
yangw-dev pushed a commit that referenced this pull request Aug 1, 2025
- Implement rms norm using onnx RMSNormalization-23
- Use the correct eps for float32
  https://github.com/pytorch/pytorch/blob/eaadd1282c8e66f37acf54f95668529831c95df7/aten/src/ATen/native/cuda/layer_norm_kernel.cu#L1844-L1866

- Created facility to run tests with the reference runtime by extending ONNXProgram and assert_onnx_program.

Fix #159257
Pull Request resolved: #159377
Approved by: https://github.com/titaiwangms