Add CoreML-stable RMSNorm for llama eager paths (#19523)

Open

telgamal-1 wants to merge 1 commit into pytorch:main from telgamal-1:export-D104862210

Conversation

@telgamal-1

@telgamal-1 telgamal-1 commented May 12, 2026

Summary:

The standard `RMSNorm` formulation `x * rsqrt(mean(x²)) * weight` is numerically unstable on CoreML/ANE because the explicit FP32 cast around the mean reduction is silently stripped from the lowered graph, leaving the squared sum to overflow in FP16. The ANE PTE then diverges from the eager reference even on checkpoints fine-tuned in BF16/FP16.

This diff introduces `RMSNormCoreML` in `examples/models/llama/norm.py`. The module expresses the normalization as `x * sqrt(d) / vector_norm(x, dim=-1)`; `torch.linalg.vector_norm` keeps the reduction in a single op that survives CoreML lowering, so FP16 inference remains stable.

To avoid `0 / 0 = NaN` on zero-padded positions (chunked prefill in `StaticAttentionIOManager` pads each chunk to `input_len` with zeros), the denominator is floored with `sqrt(dim * eps)`. This matches standard RMSNorm's `rsqrt(mean(x²) + eps)` semantics on a zero input and is large enough to survive FP16 (a plain `1e-6` underflows). Real (non-zero) tokens satisfy `vector_norm(x) >> sqrt(dim * eps)`, so the floor is a no-op on real positions.
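For concreteness, a minimal sketch of that formulation, assuming a constructor shape like the existing llama `RMSNorm` (the real module lives in `examples/models/llama/norm.py` and may differ in details):

```python
import math

import torch


class RMSNormCoreML(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Single-op reduction that survives CoreML lowering without an FP32 cast.
        norm = torch.linalg.vector_norm(x, dim=-1, keepdim=True)
        # Floor the denominator so zero-padded rows produce 0 instead of NaN.
        # sqrt(dim * eps) matches rsqrt(mean(x^2) + eps) on an all-zero row and,
        # unlike a bare 1e-6, does not underflow in FP16.
        norm = torch.clamp(norm, min=math.sqrt(self.dim * self.eps))
        return x * math.sqrt(self.dim) / norm * self.weight
```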

A new `use_coreml_norm: bool = False` field on `ModelArgs` opts into the new norm without disturbing existing models. When True, every llama-side norm site constructs `RMSNormCoreML` (a wiring sketch follows the list):

- `llama_transformer.py`: `attention_norm`, `ffn_norm`, and the final `self.norm` on `Transformer`.
- `attention.py`: `q_norm_fn` / `k_norm_fn` in the affine QK-norm path, and the `else` branch of `_init_qk_norms` (the scaleless / non-affine QK-norm path that the original landing missed).
- `static_attention.py`: `q_norm` / `k_norm` in the scaleless path, propagated through `from_attention_mha` by detecting `rms_norm_class is RMSNormCoreML`.
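A hypothetical illustration of the opt-in at one norm site; `use_coreml_norm`, `RMSNorm`, and `RMSNormCoreML` come from the summary, while `norm_eps` and the surrounding constructor are assumptions:

```python
# Inside a hypothetical TransformerBlock.__init__ in llama_transformer.py:
norm_cls = RMSNormCoreML if args.use_coreml_norm else RMSNorm
self.attention_norm = norm_cls(args.dim, eps=args.norm_eps)
self.ffn_norm = norm_cls(args.dim, eps=args.norm_eps)
```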

The QNN/HTP export path is untouched and continues to use torch.nn.RMSNorm.

Differential Revision: D104862210

@telgamal-1 telgamal-1 requested a review from lucylq as a code owner May 12, 2026 21:33
@pytorch-bot

pytorch-bot Bot commented May 12, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19523

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit fc26253 with merge base d8e4ffd:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label May 12, 2026
@meta-codesync
Contributor

meta-codesync Bot commented May 12, 2026

@telgamal-1 has exported this pull request. If you are a Meta employee, you can view the originating Diff in D104862210.

@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Comment thread: examples/models/llama/norm.py (Outdated)
eps (float, optional): Stored for API compatibility; ignored in the math.

Attributes:
eps (float): Stored for API compatibility; not consumed by `_norm`.
Contributor


Can we assert eps is 0 rather than silently drop it?

Author


Added an explicit assert that eps is set to 0
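For reference, the requested guard amounts to something like the following. Placement and wording are assumptions, and per the outdated docstring quoted above, the revision under discussion ignored `eps` in the math:

```python
# In RMSNormCoreML.__init__ of that earlier revision (hypothetical wording):
assert eps == 0, "RMSNormCoreML does not consume eps; pass eps=0 explicitly"
```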

@meta-codesync meta-codesync Bot changed the title Add CoreML-stable RMSNorm for llama eager paths Add CoreML-stable RMSNorm for llama eager paths (#19523) May 14, 2026
@telgamal-1 telgamal-1 force-pushed the export-D104862210 branch from cda18f8 to b2acb39 Compare May 14, 2026 17:43
telgamal-1 added a commit to telgamal-1/executorch that referenced this pull request May 14, 2026
telgamal-1 added a commit to telgamal-1/executorch that referenced this pull request May 14, 2026
@telgamal-1 telgamal-1 force-pushed the export-D104862210 branch from b2acb39 to b5af889 Compare May 14, 2026 17:53
self.weight.requires_grad = False


class RMSNormCoreML(torch.nn.Module):
Contributor


How does this differ from:

class CoreMLRMSNorm(torch.nn.Module):

Can we consolidate? Putting it here is fine, but then import this version into examples/apple/coreml/llama/llama_transformer.py.

Author


I imported the new version into examples/apple/coreml/llama/llama_transformer.py because it was tested not to produce NaN in QAT

Comment thread: examples/models/llama/attention.py (Outdated)
)
if self.has_kv_weights:
self.k_norm_fn = RMSNorm(
if args.use_coreml_norm:
Contributor


Does this have to be integrated so far down?

Could we not leave llama_transformer/static_attention as is, introduce the new norm in norm.py, and then do a module swap from RMSNorm -> CoreMLRMSNorm at export time?

Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed; now using the replace_rms_norm_for_coreml_ module-swap strategy
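For illustration, an export-time module swap of this shape would do it. The helper name `replace_rms_norm_for_coreml_` comes from the reply above; the import path, signature, and body are assumptions:

```python
import torch

# Illustrative import; the real module path inside the repo may differ.
from examples.models.llama.norm import RMSNorm, RMSNormCoreML


def replace_rms_norm_for_coreml_(module: torch.nn.Module) -> None:
    """Recursively swap eager RMSNorm modules for RMSNormCoreML in place."""
    for name, child in module.named_children():
        if isinstance(child, RMSNorm):
            new_norm = RMSNormCoreML(child.weight.shape[-1], eps=child.eps)
            new_norm.weight = child.weight  # keep the trained affine scale
            setattr(module, name, new_norm)
        else:
            replace_rms_norm_for_coreml_(child)
```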

telgamal-1 added a commit to telgamal-1/executorch that referenced this pull request May 14, 2026
@telgamal-1 telgamal-1 force-pushed the export-D104862210 branch from b5af889 to ae1926c Compare May 14, 2026 20:30
@telgamal-1 telgamal-1 force-pushed the export-D104862210 branch from ae1926c to fc26253 Compare May 15, 2026 00:15
@metascroy
Contributor

LGTM! You need to run the lintrunner, though

metascroy previously approved these changes May 15, 2026
@meta-codesync
Contributor

meta-codesync Bot commented May 15, 2026

@telgamal-1 has imported this pull request. If you are a Meta employee, you can view this in D104862210.

@telgamal-1 telgamal-1 closed this May 15, 2026
@telgamal-1 telgamal-1 reopened this May 15, 2026
@pytorch-bot pytorch-bot Bot dismissed metascroy’s stale review May 15, 2026 00:54

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.


Labels

CLA Signed · fb-exported · meta-exported


2 participants