[inductor] generate fused rms/layer norm bwd #165370
Conversation
This is awesome! Can we have some more benchmarking results against Quack for H100 and especially B200 norm backwards?
RMS/Layer norm backward generates two kinds of reductions:

- the reduction computing dx, which reduces across the hidden dimension (in the context of a transformer);
- the reduction computing dw/db, which reduces across the BxT (batch size, sequence length) dimension.

These two sets of reductions share common input buffers, but Inductor cannot fuse them because of their different loop orders.

There are multiple sources of custom kernels that implement a fused version of such kernels (Liger-Kernel, Quack, Paul Zhang's internal post). This PR enables Inductor to generate such kernels automatically. The generated kernel is very similar to https://github.com/linkedin/Liger-Kernel/blob/33924d20b6fa5dfc9391e4630ae82cc810114a30/src/liger_kernel/ops/rms_norm.py#L114 .

To keep the implementation simple and performant, we enable this fusion only if the inner reduction (computing dx) is a persistent reduction. This should be true for representative inputs. The persistent reduction is critical for performance here: it ensures a loaded tensor does not need to be reloaded.

To make the inner reduction (computing dx) and the outer reductions (computing dw/db) fusible, the PR does the following:

1. Convert the outer reductions to pointwise by replacing the 'reduction' & 'store_reduction' nodes with a new type of node, 'partial_accumulate'. The new node collects the reduction type, buffer name, input of the reduction, etc., which is essential for proper codegen.
2. Do loop reordering (relying on the earlier loop-ordering-after-fusion work) to reorder the loops of the converted pointwise so it can be fused with the inner reduction.
3. Add the epilogues that are needed at the end. E.g. the outer reduction may be followed by a division for a mean, or by a down-cast if dw/db is in low precision (fp16/bf16).

Some early benchmarking on H100 shows about a 2x speedup for both RMSNorm and LayerNorm backward for the shape (1152 * 500, 384) used in an internal model. Note that I manually disabled split reduction in this benchmarking, since otherwise the fusion would currently be skipped. The next PR will make the mix-order reduction compose better with split reduction.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben
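To make the two loop orders concrete, here is a minimal pure-PyTorch reference of the RMSNorm backward described above (not code from this PR; the shapes, `eps`, and names are illustrative). The same inputs feed two reductions with transposed reduction axes, which is the pattern the fused kernel handles in one pass:

```python
import torch

def rms_norm_bwd_reference(x, w, dy, eps=1e-6):
    """Reference RMSNorm backward showing the two reductions discussed above.

    x, dy: (M, H) where M = batch * sequence length, H = hidden size.
    w:     (H,)
    """
    H = x.shape[-1]
    r = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)  # per-row 1/rms, shape (M, 1)

    # Inner reduction: reduces across the hidden dimension H -> dx has shape (M, H).
    dyw = dy * w
    dx = r * dyw - x * r.pow(3) * (dyw * x).sum(dim=-1, keepdim=True) / H

    # Outer reduction: reduces across the M = B*T dimension -> dw has shape (H,).
    # It reads the same buffers (x, dy, r) as dx, but with the opposite loop order,
    # which is why Inductor could not fuse the two before this PR.
    dw = (dy * x * r).sum(dim=0)
    return dx, dw
```

`dx` reduces over `dim=-1` (the hidden dimension) while `dw` reduces over `dim=0` (the BxT dimension); fusing them lets one kernel load each row of `x` and `dy` once, accumulating partial `dw` results while computing `dx`, similar to the Liger-Kernel implementation linked above. (For LayerNorm there is additionally `db = dy.sum(dim=0)`, another outer reduction.)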
```python
# decide the split size
nrow, ncol = node1.group[1]
split_size = 64  # TODO need add heuristics
```
Is it possible to autotune over this?
Yes, this is critical to perf. The current status is: if we do split reduction, the split size is already decided by how we pick the number of splits.
A few steps I'm trying:
- Make the split size independent of the num-splits decided by split reduction. The next PR is a first step toward making sure we can still fuse mix-order reductions even when the outer reduction is split.
- Check whether I can come up with a good enough heuristic first. If no heuristic works well for all important shapes, then do autotuning before codegening the kernel (a similar implementation to how we benchmark fusions); see the sketch below.
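For reference, a hypothetical sketch of what autotuning the split size before codegen could look like. `run_fused_kernel` is a stand-in for launching the generated mix-order-reduction kernel with a candidate split size; it is not an existing Inductor API:

```python
import triton.testing

def pick_split_size(run_fused_kernel, candidates=(32, 64, 128, 256)):
    # Benchmark each candidate split size and keep the fastest one.
    # `run_fused_kernel(split_size)` is a hypothetical callable that launches
    # the generated kernel with that split size.
    timings = {
        split: triton.testing.do_bench(lambda s=split: run_fused_kernel(s))
        for split in candidates
    }
    return min(timings, key=timings.get)
```

This mirrors how Inductor already benchmarks candidate fusions, at the cost of some extra compile time.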
```python
enable_pdl = False

mix_order_reduction = (
    os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION", "0") == "1"
)
```
Off by default in this PR; will re-enable in follow-up PRs.
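For anyone wanting to try it, a usage sketch under the assumption that only the environment variable from the diff above is needed to opt in. The RMSNorm module and the (1152 * 500, 384) shape are taken from the example in the PR description; whether the fusion actually fires still depends on the persistent-reduction and split-reduction conditions described there:

```python
import os

# Opt in before Inductor's config is read; the fusion is off by default in this PR.
os.environ["TORCHINDUCTOR_MIX_ORDER_REDUCTION"] = "1"

import torch

class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        r = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * r * self.weight

model = torch.compile(RMSNorm(384).cuda())
x = torch.randn(1152 * 500, 384, device="cuda", requires_grad=True)

# The fused mix-order reduction targets the backward graph.
model(x).sum().backward()
```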
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.