Skip to content

Conversation

@shunting314
Copy link
Contributor

@shunting314 shunting314 commented Oct 13, 2025

Stack from ghstack (oldest at bottom):

RMS/Layer norm backward would generated 2 kind of reductions:

  • the reduction computing dx which reduce across the hidden dimension (in the context of transformer)
  • the reduction computing dw/db which reduce across the BxT (batch size , sequence length) dimension.

These 2 set of reductions have common input buffers but inductor can not fuse them because of different loop orders.

There are multiple sources of custom kernels that implement fused version of such kernel (Liger-Kernel, quack, Paul Zhang's internal post). This PR enable Inductor to generate such kernels automatically.

The generated kernel is very similar to https://github.com/linkedin/Liger-Kernel/blob/33924d20b6fa5dfc9391e4630ae82cc810114a30/src/liger_kernel/ops/rms_norm.py#L114 .

To make the implementation simple and performing, we enable such fusion only if the inner reduction (computing dx) is a persistent reduction. This should be true for representative inputs. Persistent reduction is critical for the perf here to make sure a loaded tensor does not need to be reload.

To make sure the inner reduction (computing dx) and outer reductions (computing dw/db) being fusible, the PR does the following:

  1. convert the outer reductions to pointwise by replacing 'reduction' & 'store_reduction' node with a new type of node 'parital_accumulate'. The new node will collect the reduction type, buffer name, input of reduction etc, which is essential for proper codegening.
  2. do loop reordering (rely on the earlier loop ordering after fusion work) to reorder the loops of the converted pointwise so it can be fused with the inner reduction
  3. there can be epilogues that need to be added in the end. E.g. the outer reduction may be followed by a division for mean , or followed by a down cast if dw/db is in low precision (fp16/bf16).

Some early benchmarking on H100 shows about 2X speedup for both RMSNorm and LayerNorm backward for shape (1152 * 500, 384 ) used in some internal model. Note that, I manually disable split reduction in this benchmarking since otherwise the fusion will be skipped right now. The next PR will make the mix-order-reduction compose better with split reduction

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 13, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165370

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

⏳ No Failures, 1 Pending

As of commit 1ac94b7 with merge base 4e6afa8 (image):
💚 Looks good so far! There are no failures yet. 💚

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

shunting314 added a commit that referenced this pull request Oct 13, 2025
ghstack-source-id: c1adef9
Pull Request resolved: #165370
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
shunting314 added a commit that referenced this pull request Oct 13, 2025
ghstack-source-id: 0920256
Pull Request resolved: #165370
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
shunting314 added a commit that referenced this pull request Oct 14, 2025
ghstack-source-id: b4d719c
Pull Request resolved: #165370
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
shunting314 added a commit that referenced this pull request Oct 14, 2025
ghstack-source-id: eb43da8
Pull Request resolved: #165370
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
shunting314 added a commit that referenced this pull request Oct 15, 2025
ghstack-source-id: 9ae0d8d
Pull Request resolved: #165370
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
shunting314 added a commit that referenced this pull request Oct 15, 2025
ghstack-source-id: 3a8c374
Pull Request resolved: #165370
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
shunting314 added a commit that referenced this pull request Oct 15, 2025
ghstack-source-id: bc8c506
Pull Request resolved: #165370
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
shunting314 added a commit that referenced this pull request Oct 15, 2025
ghstack-source-id: 8942f07
Pull Request resolved: #165370
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
shunting314 added a commit that referenced this pull request Oct 15, 2025
ghstack-source-id: 9e687ab
Pull Request resolved: #165370
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
shunting314 added a commit that referenced this pull request Oct 15, 2025
ghstack-source-id: 869f0ac
Pull Request resolved: #165370
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
shunting314 added a commit that referenced this pull request Oct 15, 2025
ghstack-source-id: 2f0453d
Pull Request resolved: #165370
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
shunting314 added a commit that referenced this pull request Oct 15, 2025
ghstack-source-id: 5931165
Pull Request resolved: #165370
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
shunting314 added a commit that referenced this pull request Oct 15, 2025
ghstack-source-id: e312754
Pull Request resolved: #165370
@PaulZhang12
Copy link
Contributor

This is awesome! Can we have some more benchmarking results against Quack for H100 and especially B200 norm backwards?

RMS/Layer norm backward would generated 2 kind of reductions:
- the reduction computing dx which reduce across the hidden dimension (in the context of transformer)
- the reduction computing dw/db which reduce across the BxT (batch size , sequence length) dimension.

These 2 set of reductions have common input buffers but inductor can not fuse them because of different loop orders.

There are multiple sources of custom kernels that implement fused version of such kernel (Liger-Kernel, quack, Paul Zhang's internal post). This PR enable Inductor to generate such kernels automatically.

The generated kernel is very similar to https://github.com/linkedin/Liger-Kernel/blob/33924d20b6fa5dfc9391e4630ae82cc810114a30/src/liger_kernel/ops/rms_norm.py#L114 .

To make the implementation simple and performing, we enable such fusion only if the inner reduction (computing dx) is a persistent reduction. This should be true for representative inputs. Persistent reduction is critical for the perf here to make sure a loaded tensor does not need to be reload.

To make sure the inner reduction (computing dx) and outer reductions (computing dw/db) being fusible, the PR does the following:
1. convert the outer reductions to pointwise by replacing 'reduction' & 'store_reduction' node with a new type of node 'parital_accumulate'. The new node will collect the reduction type, buffer name, input of reduction etc, which is essential for proper codegening.
2. do loop reordering (rely on the earlier loop ordering after fusion work) to reorder the loops of the converted pointwise so it can be fused with the inner reduction
3. there can be epilogues that need to be added in the end. E.g. the outer reduction may be followed by a division for mean , or followed by a down cast if dw/db is in low precision (fp16/bf16).

Some early benchmarking on H100 shows about 2X speedup for both RMSNorm and LayerNorm backward for shape (1152 * 500, 384 ) used in some internal model. Note that, I manually disable split reduction in this benchmarking since otherwise the fusion will be skipped right now. The next PR will make the mix-order-reduction compose better with split reduction


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben 


[ghstack-poisoned]

# decide the split size
nrow, ncol = node1.group[1]
split_size = 64 # TODO need add heuristics
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to autotune over this?

Copy link
Contributor Author

@shunting314 shunting314 Oct 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, this is critical to perf. The current status is, if we do split reduction, the split size is already decided by how we pick 'num-split'.

A few steps I'm trying

  1. make the split size independent to num-splits decided by split reduction . The next PR is a first step to make sure we can still fuse mix order reductions even when the outer reduction is split
  2. I'll first check if I can come up with a good enough heuristics. If there is no heuristics that works well for all important shapes, then do autotuning before codegening the kernel. (Similar implementation as how we benchmark fusion)

RMS/Layer norm backward would generated 2 kind of reductions:
- the reduction computing dx which reduce across the hidden dimension (in the context of transformer)
- the reduction computing dw/db which reduce across the BxT (batch size , sequence length) dimension.

These 2 set of reductions have common input buffers but inductor can not fuse them because of different loop orders.

There are multiple sources of custom kernels that implement fused version of such kernel (Liger-Kernel, quack, Paul Zhang's internal post). This PR enable Inductor to generate such kernels automatically.

The generated kernel is very similar to https://github.com/linkedin/Liger-Kernel/blob/33924d20b6fa5dfc9391e4630ae82cc810114a30/src/liger_kernel/ops/rms_norm.py#L114 .

To make the implementation simple and performing, we enable such fusion only if the inner reduction (computing dx) is a persistent reduction. This should be true for representative inputs. Persistent reduction is critical for the perf here to make sure a loaded tensor does not need to be reload.

To make sure the inner reduction (computing dx) and outer reductions (computing dw/db) being fusible, the PR does the following:
1. convert the outer reductions to pointwise by replacing 'reduction' & 'store_reduction' node with a new type of node 'parital_accumulate'. The new node will collect the reduction type, buffer name, input of reduction etc, which is essential for proper codegening.
2. do loop reordering (rely on the earlier loop ordering after fusion work) to reorder the loops of the converted pointwise so it can be fused with the inner reduction
3. there can be epilogues that need to be added in the end. E.g. the outer reduction may be followed by a division for mean , or followed by a down cast if dw/db is in low precision (fp16/bf16).

Some early benchmarking on H100 shows about 2X speedup for both RMSNorm and LayerNorm backward for shape (1152 * 500, 384 ) used in some internal model. Note that, I manually disable split reduction in this benchmarking since otherwise the fusion will be skipped right now. The next PR will make the mix-order-reduction compose better with split reduction


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben 


[ghstack-poisoned]
RMS/Layer norm backward would generated 2 kind of reductions:
- the reduction computing dx which reduce across the hidden dimension (in the context of transformer)
- the reduction computing dw/db which reduce across the BxT (batch size , sequence length) dimension.

These 2 set of reductions have common input buffers but inductor can not fuse them because of different loop orders.

There are multiple sources of custom kernels that implement fused version of such kernel (Liger-Kernel, quack, Paul Zhang's internal post). This PR enable Inductor to generate such kernels automatically.

The generated kernel is very similar to https://github.com/linkedin/Liger-Kernel/blob/33924d20b6fa5dfc9391e4630ae82cc810114a30/src/liger_kernel/ops/rms_norm.py#L114 .

To make the implementation simple and performing, we enable such fusion only if the inner reduction (computing dx) is a persistent reduction. This should be true for representative inputs. Persistent reduction is critical for the perf here to make sure a loaded tensor does not need to be reload.

To make sure the inner reduction (computing dx) and outer reductions (computing dw/db) being fusible, the PR does the following:
1. convert the outer reductions to pointwise by replacing 'reduction' & 'store_reduction' node with a new type of node 'parital_accumulate'. The new node will collect the reduction type, buffer name, input of reduction etc, which is essential for proper codegening.
2. do loop reordering (rely on the earlier loop ordering after fusion work) to reorder the loops of the converted pointwise so it can be fused with the inner reduction
3. there can be epilogues that need to be added in the end. E.g. the outer reduction may be followed by a division for mean , or followed by a down cast if dw/db is in low precision (fp16/bf16).

Some early benchmarking on H100 shows about 2X speedup for both RMSNorm and LayerNorm backward for shape (1152 * 500, 384 ) used in some internal model. Note that, I manually disable split reduction in this benchmarking since otherwise the fusion will be skipped right now. The next PR will make the mix-order-reduction compose better with split reduction


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben 


[ghstack-poisoned]
RMS/Layer norm backward would generated 2 kind of reductions:
- the reduction computing dx which reduce across the hidden dimension (in the context of transformer)
- the reduction computing dw/db which reduce across the BxT (batch size , sequence length) dimension.

These 2 set of reductions have common input buffers but inductor can not fuse them because of different loop orders.

There are multiple sources of custom kernels that implement fused version of such kernel (Liger-Kernel, quack, Paul Zhang's internal post). This PR enable Inductor to generate such kernels automatically.

The generated kernel is very similar to https://github.com/linkedin/Liger-Kernel/blob/33924d20b6fa5dfc9391e4630ae82cc810114a30/src/liger_kernel/ops/rms_norm.py#L114 .

To make the implementation simple and performing, we enable such fusion only if the inner reduction (computing dx) is a persistent reduction. This should be true for representative inputs. Persistent reduction is critical for the perf here to make sure a loaded tensor does not need to be reload.

To make sure the inner reduction (computing dx) and outer reductions (computing dw/db) being fusible, the PR does the following:
1. convert the outer reductions to pointwise by replacing 'reduction' & 'store_reduction' node with a new type of node 'parital_accumulate'. The new node will collect the reduction type, buffer name, input of reduction etc, which is essential for proper codegening.
2. do loop reordering (rely on the earlier loop ordering after fusion work) to reorder the loops of the converted pointwise so it can be fused with the inner reduction
3. there can be epilogues that need to be added in the end. E.g. the outer reduction may be followed by a division for mean , or followed by a down cast if dw/db is in low precision (fp16/bf16).

Some early benchmarking on H100 shows about 2X speedup for both RMSNorm and LayerNorm backward for shape (1152 * 500, 384 ) used in some internal model. Note that, I manually disable split reduction in this benchmarking since otherwise the fusion will be skipped right now. The next PR will make the mix-order-reduction compose better with split reduction


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben 


[ghstack-poisoned]
RMS/Layer norm backward would generated 2 kind of reductions:
- the reduction computing dx which reduce across the hidden dimension (in the context of transformer)
- the reduction computing dw/db which reduce across the BxT (batch size , sequence length) dimension.

These 2 set of reductions have common input buffers but inductor can not fuse them because of different loop orders.

There are multiple sources of custom kernels that implement fused version of such kernel (Liger-Kernel, quack, Paul Zhang's internal post). This PR enable Inductor to generate such kernels automatically.

The generated kernel is very similar to https://github.com/linkedin/Liger-Kernel/blob/33924d20b6fa5dfc9391e4630ae82cc810114a30/src/liger_kernel/ops/rms_norm.py#L114 .

To make the implementation simple and performing, we enable such fusion only if the inner reduction (computing dx) is a persistent reduction. This should be true for representative inputs. Persistent reduction is critical for the perf here to make sure a loaded tensor does not need to be reload.

To make sure the inner reduction (computing dx) and outer reductions (computing dw/db) being fusible, the PR does the following:
1. convert the outer reductions to pointwise by replacing 'reduction' & 'store_reduction' node with a new type of node 'parital_accumulate'. The new node will collect the reduction type, buffer name, input of reduction etc, which is essential for proper codegening.
2. do loop reordering (rely on the earlier loop ordering after fusion work) to reorder the loops of the converted pointwise so it can be fused with the inner reduction
3. there can be epilogues that need to be added in the end. E.g. the outer reduction may be followed by a division for mean , or followed by a down cast if dw/db is in low precision (fp16/bf16).

Some early benchmarking on H100 shows about 2X speedup for both RMSNorm and LayerNorm backward for shape (1152 * 500, 384 ) used in some internal model. Note that, I manually disable split reduction in this benchmarking since otherwise the fusion will be skipped right now. The next PR will make the mix-order-reduction compose better with split reduction


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben 


[ghstack-poisoned]
@shunting314 shunting314 mentioned this pull request Oct 27, 2025
RMS/Layer norm backward would generated 2 kind of reductions:
- the reduction computing dx which reduce across the hidden dimension (in the context of transformer)
- the reduction computing dw/db which reduce across the BxT (batch size , sequence length) dimension.

These 2 set of reductions have common input buffers but inductor can not fuse them because of different loop orders.

There are multiple sources of custom kernels that implement fused version of such kernel (Liger-Kernel, quack, Paul Zhang's internal post). This PR enable Inductor to generate such kernels automatically.

The generated kernel is very similar to https://github.com/linkedin/Liger-Kernel/blob/33924d20b6fa5dfc9391e4630ae82cc810114a30/src/liger_kernel/ops/rms_norm.py#L114 .

To make the implementation simple and performing, we enable such fusion only if the inner reduction (computing dx) is a persistent reduction. This should be true for representative inputs. Persistent reduction is critical for the perf here to make sure a loaded tensor does not need to be reload.

To make sure the inner reduction (computing dx) and outer reductions (computing dw/db) being fusible, the PR does the following:
1. convert the outer reductions to pointwise by replacing 'reduction' & 'store_reduction' node with a new type of node 'parital_accumulate'. The new node will collect the reduction type, buffer name, input of reduction etc, which is essential for proper codegening.
2. do loop reordering (rely on the earlier loop ordering after fusion work) to reorder the loops of the converted pointwise so it can be fused with the inner reduction
3. there can be epilogues that need to be added in the end. E.g. the outer reduction may be followed by a division for mean , or followed by a down cast if dw/db is in low precision (fp16/bf16).

Some early benchmarking on H100 shows about 2X speedup for both RMSNorm and LayerNorm backward for shape (1152 * 500, 384 ) used in some internal model. Note that, I manually disable split reduction in this benchmarking since otherwise the fusion will be skipped right now. The next PR will make the mix-order-reduction compose better with split reduction


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben 


[ghstack-poisoned]
enable_pdl = False

mix_order_reduction = (
os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION", "0") == "1"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Off by default in this PR. Will reenable in following PRs

@shunting314
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 28, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/mps Run MPS tests (subset of trunk) ciflow/trunk Trigger trunk jobs on your pull request Merged module: inductor topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants