[ET-VK] Add fused HuggingFace RoPE operator (apply_rotary_emb_hf) #18599

Merged
SS-JIA merged 1 commit into main from gh/SS-JIA/514/orig
Mar 31, 2026

@pytorchbot
Collaborator

This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #18592 by @SS-JIA
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/514/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/514/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/514/orig
Differential Revision: D98741178
@diff-train-skip-merge

Pull Request resolved: #18592

Add a fused rotary positional embedding operator for the HuggingFace RoPE
convention used by Qwen3, Phi-4-mini, and other HF-based models.

The existing `et_vk.apply_rotary_emb` only matches the stock Meta/Llama RoPE
pattern (interleaved pairs via reshape+unbind+stack+flatten). HF models use a
different convention (split-half via slice+neg+cat), so the existing pattern
does not match and Qwen3's RoPE decomposes into ~560 GPU dispatches per decode
step instead of 16 fused dispatches (~1,295 µs/decode, 7% of total).
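To make the difference between the two conventions concrete, here is a minimal
pure-Python sketch for a single head vector. The function names
(`rope_interleaved`, `rope_hf`, `rotate_half`, `deinterleave`) are illustrative
only and are not taken from this PR; the sketch assumes the usual HF layout in
which cos/sin are the half-length tables concatenated with themselves.

```python
import math

def rope_interleaved(x, cos, sin):
    """Meta/Llama style: rotate adjacent pairs (x[2i], x[2i+1]).

    cos and sin hold len(x) // 2 entries, one per pair.
    """
    out = [0.0] * len(x)
    for i in range(len(x) // 2):
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * cos[i] - b * sin[i]
        out[2 * i + 1] = a * sin[i] + b * cos[i]
    return out

def rotate_half(x):
    """HF helper: split in half, negate the second half, swap (slice+neg+cat)."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]

def rope_hf(x, cos, sin):
    """HF style: x * cos + rotate_half(x) * sin, with full-length cos and sin."""
    rot = rotate_half(x)
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rot, sin)]

def deinterleave(v):
    """Reorder [a0, b0, a1, b1, ...] into [a0, a1, ..., b0, b1, ...]."""
    return v[0::2] + v[1::2]

# Same angles, two layouts: the results agree up to the de-interleave permutation.
theta = [0.1, 0.2]
cos_half = [math.cos(t) for t in theta]
sin_half = [math.sin(t) for t in theta]
x = [1.0, 2.0, 3.0, 4.0]
out_interleaved = rope_interleaved(x, cos_half, sin_half)
out_hf = rope_hf(deinterleave(x), cos_half + cos_half, sin_half + sin_half)
```

The two functions apply the same rotations but pair different elements, which
is why a graph pattern written for one layout does not match the other.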

This commit adds `et_vk.apply_rotary_emb_hf` with:
- Pattern matching: `HfRotaryEmbeddingPattern` in `patterns/rope_hf.py` uses
  `SubgraphMatcher` to detect the HF RoPE graph and replace it with the fused
  op. It supports both full rotation (freqs_dim == head_dim) and partial
  rotation (freqs_dim < head_dim, e.g. Phi-4-mini with
  partial_rotary_factor=0.75) by registering two pattern variants in
  `get_hf_rope_graphs()`.
- GLSL shader: `rotary_embedding_hf.glsl` pairs elements at distance D/2
  (half-apart) instead of adjacent pairs, and computes half_dim from the
  metadata UBO for dynamic shape support.
- C++ dispatch: `add_rotary_embedding_hf_node` with a corrected assertion
  (head_dim == freqs_dim, not freqs_dim*2), since HF freqs are full-dim.
- Custom op registration in both xplat and fbcode.
- Op tests covering multiple configurations and dynamic prefill→decode resize.
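
The half-apart pairing and the full versus partial rotation cases listed above
can be sketched as a scalar reference in Python. This is not the shader or the
op's actual implementation: `apply_rotary_emb_hf_reference` and `hf_freqs` are
hypothetical names, the frequency base of 10000 is an assumed default, and the
sketch assumes that with partial rotation the elements beyond freqs_dim pass
through unchanged.

```python
import math

def hf_freqs(position, freqs_dim, base=10000.0):
    """Hypothetical helper: full-length cos/sin tables in the HF layout.

    The half-length angle table is concatenated with itself, so entry i and
    entry i + freqs_dim // 2 share the same angle.
    """
    half = freqs_dim // 2
    angles = [position * base ** (-2.0 * i / freqs_dim) for i in range(half)]
    angles = angles + angles
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]

def apply_rotary_emb_hf_reference(x, cos, sin):
    """Reference for one head vector x of length head_dim.

    Pairs element i with element i + half_dim, where half_dim is derived from
    the length of the frequency tables rather than hard-coded. Elements at
    index >= freqs_dim are copied through (assumed partial-rotation behavior).
    """
    head_dim, freqs_dim = len(x), len(cos)
    assert len(sin) == freqs_dim
    assert freqs_dim % 2 == 0 and freqs_dim <= head_dim
    half_dim = freqs_dim // 2
    out = list(x)  # unrotated tail is kept as-is
    for i in range(half_dim):
        lo, hi = x[i], x[i + half_dim]
        out[i] = lo * cos[i] - hi * sin[i]
        out[i + half_dim] = hi * cos[i + half_dim] + lo * sin[i + half_dim]
    return out

# Full rotation: freqs_dim == head_dim.
x_full = [float(v) for v in range(1, 9)]
cos_f, sin_f = hf_freqs(position=3, freqs_dim=8)
y_full = apply_rotary_emb_hf_reference(x_full, cos_f, sin_f)

# Partial rotation: freqs_dim = 0.75 * head_dim, last two elements untouched.
cos_p, sin_p = hf_freqs(position=3, freqs_dim=6)
y_partial = apply_rotary_emb_hf_reference(x_full, cos_p, sin_p)
```

Deriving `half_dim` from the table length at call time mirrors, in spirit, the
idea of reading it from metadata so the same code handles resized inputs.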

Also adds a `convert_phi4_mini_weights` binary target to the phi_4_mini TARGETS
file to enable converting HF checkpoint weights to Meta format.

Authored with Claude.
ghstack-source-id: 359963407
@exported-using-ghexport

Differential Revision: [D98741178](https://our.internmc.facebook.com/intern/diff/D98741178/)
@pytorchbot pytorchbot requested a review from SS-JIA as a code owner March 31, 2026 01:45
@pytorch-bot

pytorch-bot Bot commented Mar 31, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18599

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 2 Unrelated Failures

As of commit 48db8ce with merge base d7cc5d7 (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Mar 31, 2026
@github-actions

This PR needs a `release notes:` label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with `release notes:`. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@SS-JIA SS-JIA merged commit b93a21a into main Mar 31, 2026
162 of 168 checks passed
@SS-JIA SS-JIA deleted the gh/SS-JIA/514/orig branch March 31, 2026 02:53