
Optimize reduction + amax fusion #111122

Closed

Conversation


@ipiszy ipiszy commented Oct 12, 2023

This PR optimizes fusion for cases like layer_norm + fp8 quant (which includes an amax reduction and the fp8 quantization) when amax is split into multiple reduction kernels.

Benchmark:
```
python test/inductor/test_fp8.py -k test_layernorm_fp8_quant_benchmark

Before this PR:
Config: float8_dtype=torch.float8_e5m2, shape=(4, 2048, 4096).
Benchmark results: Inductor: 0.13262102689486555ms, Eager: 0.8211962616822429ms, LN only Inductor: 0.09606276150627614ms.

After this PR:
Config: float8_dtype=torch.float8_e5m2, shape=(4, 2048, 4096).
Benchmark results: Inductor: 0.08281274131274131ms, Eager: 0.8217452830188678ms, LN only Inductor: 0.09586902286902287ms.
```

LN + fp8 quant is even faster than LN alone. The reason could be that LN + fp8 writes fp8 outputs while LN alone writes fp16.
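For reference, the pattern being fused looks roughly like this (a sketch based on the description above; the helper name and scale handling are assumptions, not the exact test code):

```python
import torch
import torch.nn.functional as F

def ln_amax_fp8(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
                scale: torch.Tensor, float8_dtype=torch.float8_e5m2):
    # layer_norm reduces over the hidden dimension (4096 in the benchmark shape).
    y = F.layer_norm(x, x.shape[-1:], weight, bias)
    # amax over the whole tensor: Inductor splits this into multiple reduction kernels.
    amax = torch.amax(torch.abs(y))
    # fp8 quantization: scale and cast, so the fused kernel writes fp8 instead of fp16.
    y_fp8 = (y * scale).to(float8_dtype)
    return y_fp8, amax
```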

From the Inductor nightly benchmark test:
There are perf differences in the cuda_graph / cuda_graph_dynamic / default runs, but no difference in inductor_max_autotune, so the perf differences are most likely just fluctuations.

![Screenshot 2023-10-18 at 4 58 55 PM](https://github.com/pytorch/pytorch/assets/10527447/6640474a-1e1d-4d33-97e9-0a60d0bc9f1f)

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

@pytorch-bot

pytorch-bot bot commented Oct 12, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111122

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit 2da5e75 with merge base 547a116:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ipiszy added a commit that referenced this pull request Oct 12, 2023
ghstack-source-id: ade61469e2eebcb6494ba6bae88e474bc94f87cf
Pull Request resolved: #111122
```python
# Input node has already been realized. Return its size and reduction_size.
return input_node.get_size(), input_node.get_reduction_size()

# This is one issue: what if there are permutations between the input node and its dependent realized nodes?
```
Contributor Author

@jansel I wonder if you have any suggestions for this?

Contributor

In addition to permutations, there are views which change the number of dimensions.

Is it ok if this function is approximate? Or are there correctness issues if it is wrong?

Contributor Author

Using reduction_sizes from dependent nodes gives us a better chance to fuse these nodes.

E.g., the current case is:

```python
x1 = layer_norm(x0)
x2 = amax(x1)
x3 = to_fp8(x1)
```

Inductor generates these nodes:

```
n0=WelfordReduction()
n1=WelfordReduction()
n2=WelfordReduction()
n3=Pointwise()
n4=Reduction()
n5=Pointwise()
```

Currently n0, n1, n2, n3, n5 are fused together, and n3, n4 are fused together.
I'd like to make the first-level reduction ranges of n4 the same as those of n0 / n1 / n2, so that n0, n1, n2, n3, the first level of n4, and n5 can all be fused together.

So it seems to me that we cannot use approximate values here for n4's reduction sizes.
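To make the split concrete: a multi-level amax over a (batch, seq, hidden) tensor can be viewed as a first-level reduction over the hidden dimension (the same range the layer_norm Welford reductions use) followed by a second-level reduction over the partial results. This is an illustrative sketch, not Inductor code:

```python
import torch

x1 = torch.randn(4, 2048, 4096)  # assumed shape from the benchmark config

# First level: reduce over the hidden dimension, matching the layer_norm
# reduction ranges (so n0/n1/n2 and the first level of n4 can share a loop).
partial_amax = torch.abs(x1).amax(dim=-1)   # shape (4, 2048)

# Second level: reduce the partial results down to a single scalar.
x2 = partial_amax.amax()                    # scalar, equals torch.abs(x1).amax()
```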

@ipiszy ipiszy requested review from jansel and drisspg October 12, 2023 06:32
```python
batch_size, sequence_length, hidden_size = shape

def amax_fp8(x: Tensor, scale: Tensor):
    y = torch.max(torch.abs(x))
```
Contributor

Should this use torch.amax instead of the older torch.max? If max is not intentional, I think using amax to mean "return the values without indices" is clearer.
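For illustration (my own example, not from the PR): both calls give the same result for a full reduction, but along a dimension torch.max returns a (values, indices) pair while torch.amax returns values only:

```python
import torch

x = torch.tensor([[1.0, -3.0], [2.0, 0.5]])

torch.max(torch.abs(x))     # tensor(3.) -- full reduction, values only
torch.amax(torch.abs(x))    # tensor(3.) -- same result, clearer intent

values, indices = torch.max(torch.abs(x), dim=1)   # per-row max plus argmax indices
values_only = torch.amax(torch.abs(x), dim=1)      # per-row max, no indices
```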

ipiszy added a commit that referenced this pull request Oct 13, 2023
ghstack-source-id: b7c11cfb4c03156c3eb9e0f1198e34321f080910
Pull Request resolved: #111122
ipiszy added a commit that referenced this pull request Oct 15, 2023
ghstack-source-id: bfc21e5da21b4fce616d80c9ce195a4673d27e79
Pull Request resolved: #111122
ipiszy added a commit that referenced this pull request Oct 16, 2023
ghstack-source-id: f9abe7a11dafa5dd3095cd2d46029c029271424d
Pull Request resolved: #111122
ipiszy added a commit that referenced this pull request Oct 17, 2023
ghstack-source-id: 8ea42b1fe47bc07024af40b92d6062b31b2b3834
Pull Request resolved: #111122

```python
from .ir import ComputedBuffer, Loops

if not isinstance(input_node.data.data, Loops):
```
Contributor

I think we need some checks to ensure .data and .data.data exist. There are some cases like views that result in different nesting.
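A minimal sketch of the kind of defensive access being suggested (a hypothetical helper; the exact attribute nesting and import path are assumptions, not the actual Inductor change):

```python
from torch._inductor.ir import Loops  # assumed import path

def get_realized_loops(input_node):
    # Views and other wrappers can nest differently, so .data or .data.data
    # may be missing; return None instead of raising AttributeError.
    data = getattr(input_node, "data", None)
    inner = getattr(data, "data", None)
    if isinstance(inner, Loops):
        return inner
    return None
```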

Contributor Author

Yeah sure. I added some checks at the call site; let me also add checks here for safety.

Comment on lines +376 to +378

```python
if hasattr(input_node, "get_size") and hasattr(
    input_node, "get_reduction_size"
):
```
Contributor

Adding a method would be cleaner than these hasattr checks.
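A rough sketch of the suggested alternative, with a hypothetical method name (not the actual Inductor API):

```python
class IRNode:
    def supports_reduction_size(self) -> bool:
        # Most nodes don't expose reduction sizes.
        return False

class ReductionLikeNode(IRNode):
    def supports_reduction_size(self) -> bool:
        return True

# The call site then reads as an explicit capability query instead of two
# hasattr checks:
#   if input_node.supports_reduction_size():
#       sizes = input_node.get_size()
#       reduction_sizes = input_node.get_reduction_size()
```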

ipiszy added a commit that referenced this pull request Oct 19, 2023
ghstack-source-id: 8f2ec448b4f1ff768402b93b15e87e7140fe9d22
Pull Request resolved: #111122
@ipiszy ipiszy left a comment

Thanks @jansel!



@ipiszy
Contributor Author

ipiszy commented Oct 19, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Oct 19, 2023
@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team: Raised by workflow job

@ipiszy
Contributor Author

ipiszy commented Oct 19, 2023

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing (topic category) label Oct 19, 2023
@ipiszy
Contributor Author

ipiszy commented Oct 19, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot facebook-github-bot deleted the gh/ipiszy@gmail.com/11/head branch October 23, 2023 14:24
ipiszy added a commit that referenced this pull request Oct 23, 2023
In #111122, an optimization is introduced for reduction + pointwise + multi-level reduction. In this case, we make the multi-level reduction's first-level reduction ranges the same as the previous reduction's ranges, so that Inductor has a better chance of fusing the first reduction and the first-level reduction of the multi-level reduction kernel together.

There is a corner case where the multi-level reduction kernel has `keepdim=True`. In this case, the ranges of the multi-level reduction kernel are not empty, and the dim info needs to be used to create the inner loader of the first-level reduction kernel. To keep the logic simple, for now we simply disable the optimization when `keepdim=True`.
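For context, a small illustration of the keepdim difference (my own example, not from the patch): with `keepdim=True` the reduced dims survive with size 1, so the reduction output still has non-empty ranges that must be threaded into the first-level loader.

```python
import torch

x = torch.randn(4, 2048, 4096)

a = torch.amax(torch.abs(x))                               # shape: ()        -> empty ranges
b = torch.amax(torch.abs(x), dim=(0, 1, 2), keepdim=True)  # shape: (1, 1, 1) -> non-empty ranges
```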


Differential Revision: [D50544876](https://our.internmc.facebook.com/intern/diff/D50544876)

[ghstack-poisoned]
facebook-github-bot pushed a commit that referenced this pull request Oct 23, 2023
ipiszy added a commit that referenced this pull request Oct 23, 2023
@mlazos
Contributor

mlazos commented Oct 24, 2023

@ipiszy This PR caused a significant regression in TIMM dm_nfnet_f0

repro command:

`python timm_models.py --training --amp --performance --only=dm_nfnet_f0 --inductor`

Can you take a look?

cc @eellison

pytorchmergebot pushed a commit that referenced this pull request Oct 24, 2023
Pull Request resolved: #111781
Approved by: https://github.com/malfet, https://github.com/jansel
andreigh pushed a commit to andreigh/pytorch that referenced this pull request Oct 26, 2023
pytorchmergebot pushed a commit that referenced this pull request Oct 31, 2023
In #111122, an optimization is introduced for reduction + pointwise + multi-level reduction fusion. The main idea of this optimization is to have the first-level reduction of the multi-level reduction reuse the reduction sizes of the first reduction kernel, so that there is a better chance that the first reduction kernel and the first-level reduction of the multi-level reduction kernel can be fused. However, it introduces a bug for the pointwise + multi-level reduction pattern, where the first-level reduction kernel wrongly reuses the reduction ranges (which are []) from the previous pointwise kernel. This PR fixes that issue.
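Roughly, the problematic pattern (my own sketch, not the failing model code) is a pointwise op followed by a split reduction; the pointwise kernel has no reduction ranges ([]), so reusing them for the reduction's first level is wrong:

```python
import torch

x = torch.randn(4, 2048, 4096)
y = x * 2.0                      # pointwise kernel: reduction ranges are []
amax = torch.abs(y).amax()       # multi-level (split) reduction over all elements
```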

Test plan:
`python timm_models.py --training --amp --performance --only=dm_nfnet_f0 --inductor`
Results before this PR: 0.869x
Results after this PR: 1.232x

Benchmark results:
![Screenshot 2023-10-30 at 2 30 10 PM](https://github.com/pytorch/pytorch/assets/10527447/c7b241c0-92a4-49ff-96fb-2805c8fcc45a)

<img width="1491" alt="Screenshot 2023-10-30 at 3 10 06 PM" src="https://github.com/pytorch/pytorch/assets/10527447/608d26ea-dcc5-4f2a-8700-4a928701392b">

Pull Request resolved: #112297
Approved by: https://github.com/jansel
@ipiszy
Contributor Author

ipiszy commented Nov 1, 2023

> @ipiszy This PR caused a significant regression in TIMM dm_nfnet_f0
>
> repro command:
>
> `python timm_models.py --training --amp --performance --only=dm_nfnet_f0 --inductor`
>
> Can you take a look?
>
> cc @eellison

FYI this is fixed by #112297.

xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Pull Request resolved: pytorch#111122
Approved by: https://github.com/jansel
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
andreigh pushed a commit to andreigh/pytorch that referenced this pull request Nov 19, 2023