Skip to content

Qualcomm AI Engine Direct - Resolved RMSNorm issue without weight#18219

Merged
abhinaykukkadapu merged 1 commit intopytorch:mainfrom
CodeLinaro:dev1/hutton/fixed_rmsnorm_wo_weight
Mar 17, 2026
Merged

Qualcomm AI Engine Direct - Resolved RMSNorm issue without weight#18219
abhinaykukkadapu merged 1 commit intopytorch:mainfrom
CodeLinaro:dev1/hutton/fixed_rmsnorm_wo_weight

Conversation

@shewu-quic
Copy link
Copy Markdown
Collaborator

Summary

  • Added a test case for torch.nn.RMSNorm([4], elementwise_affine=False).
  • Verified whether weight exists in the op builder.
    • Used an all-ones weight tensor for identity scaling if no weight is present.
  • Fix htp_rules.py RmsNorm annotator to skip weight annotation when weight is absent.

Test plan

  • FloatingPoint RMSNorm unit test
python3 backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_rms_norm  -b build-android/ -H ${HOST} -s ${SERIAL}  -m SM8750 -r /path/to/executorch -a /path/to/artifacts
  • Quantized RMSNorm unit test
python3 backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_rms_norm  -b build-android/ -H ${HOST} -s ${SERIAL}  -m SM8750 -r /path/to/executorch -a /path/to/artifacts

- Added a test case for torch.nn.RMSNorm([4], elementwise_affine=False).
- Verified whether weight exists in the op builder.
  - Used an all-ones weight tensor for identity scaling if no weight is present.
- Fix htp_rules.py RmsNorm annotator to skip weight annotation when weight is absent.
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented Mar 17, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18219

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 4dac95e with merge base a81ef44 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 17, 2026
@github-actions
Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@shewu-quic
Copy link
Copy Markdown
Collaborator Author

Hi @cccclai, @abhinaykukkadapu,

This PR addresses the issue regarding RMSNorm without weight reported by a Discord user.
Could you please review it?

Thanks,
Hutton

Copy link
Copy Markdown
Contributor

@abhinaykukkadapu abhinaykukkadapu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks for fixing this.

@abhinaykukkadapu abhinaykukkadapu merged commit 75c85e7 into pytorch:main Mar 17, 2026
165 of 168 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants