
[inductor] make thread order consistent with loop order #106827

Closed
shunting314 wants to merge 3 commits

Conversation

shunting314 (Contributor) commented Aug 8, 2023

Stack from ghstack (oldest at bottom):

I found that for a tiled kernel over a tensor with shape [a, b], we map 'a' to XBLOCK and 'b' to YBLOCK. However, 'a' should actually be the outer loop while 'b' corresponds to the inner loop; this is the order picked by our loop ordering algorithm. Mapping 'a' to XBLOCK effectively assigns 'a' to the inner loop instead.

For a simple 'A + B.t()' kernel, making the loop order consistent brings a 1.027x speedup (1.938ms -> 1.887ms). Here are the dumps of the kernels:

- before fix: https://gist.github.com/shunting314/4dacf73cf495cdd7e84dede7c3e0872d
- after fix (this one is done manually): https://gist.github.com/shunting314/441e8839d24e1878c313e539b1ebd551

I tried this on DistillGPT2 and found perf is neutral, but that is because DistillGPT2 has only a single tiled pointwise kernel in its backward graph. I will check the dashboard.
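
For illustration, here is a hand-written sketch of such a tiled 'A + B.t()' kernel with the consistent mapping, i.e. the x dimension walks the inner (contiguous) loop of A and the output while y walks the outer loop. This is not the Inductor-generated code; the kernel name, block sizes, and grid below are made up for illustration:

import torch
import triton
import triton.language as tl

@triton.jit
def add_t_kernel(A, B, Out, M, N, XBLOCK: tl.constexpr, YBLOCK: tl.constexpr):
    # x walks the contiguous (inner) dimension of A/Out; y walks the outer one.
    xoff = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)  # columns, inner loop
    yoff = tl.program_id(1) * YBLOCK + tl.arange(0, YBLOCK)  # rows, outer loop
    mask = (yoff[:, None] < M) & (xoff[None, :] < N)
    a = tl.load(A + yoff[:, None] * N + xoff[None, :], mask=mask)  # coalesced along x
    b = tl.load(B + xoff[None, :] * M + yoff[:, None], mask=mask)  # B is read transposed
    tl.store(Out + yoff[:, None] * N + xoff[None, :], a + b, mask=mask)

M = N = 4096
A = torch.randn(M, N, device="cuda")
B = torch.randn(N, M, device="cuda")
Out = torch.empty_like(A)
grid = (triton.cdiv(N, 64), triton.cdiv(M, 64))
add_t_kernel[grid](A, B, Out, M, N, XBLOCK=64, YBLOCK=64)
torch.testing.assert_close(Out, A + B.t())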

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov

pytorch-bot bot commented Aug 8, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106827

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 88a4ea1:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

shunting314 added a commit that referenced this pull request Aug 8, 2023
ghstack-source-id: 92cbccde98c5e74b363fef0cdd9eec2bdaa1a78a
Pull Request resolved: #106827
shunting314 (Contributor, Author) commented Aug 9, 2023

Perf run: link

So it looks like the PR helps some models but also slows down others. My guess is that for some kernels the loop order makes the behavior without the fix run faster. I'm also thinking we could update the existing algorithm that picks the loop order a bit. The current algorithm does the following:

  1. If all strides for dim-a are smaller than all strides for dim-b, we put dim-a later in the order (this describes the order after reversing the return value of scheduler.pick_loop_order, since that's what we usually do with the return value).
  2. If for some indices the stride of dim-a is smaller than that of dim-b, but for other indices it's the opposite, we put the dimension with the smaller index first in the order.

Maybe we can change heuristic 2 to consider how many indices give dim-a the smaller stride: define X as the number of indices where dim-a has a smaller stride than dim-b, and Y as the number of indices where dim-b has the smaller stride. If X - Y is larger than a threshold, we put dim-a later in the order.
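
To make that concrete, here is a rough Python sketch of the tweaked heuristic. This is not the actual scheduler.pick_loop_order code; the function name, the strides_by_index input, and the default threshold are made up for illustration:

def compare_dims(dim_a, dim_b, strides_by_index, threshold=1):
    # strides_by_index: for each index expression, a mapping from dim -> stride.
    # A negative return value puts dim_a later (more inner) in the loop order,
    # a positive one puts dim_b later, and 0 falls back to putting the smaller
    # dim index first.
    x = sum(1 for s in strides_by_index if s[dim_a] < s[dim_b])  # indices preferring dim_a inner
    y = sum(1 for s in strides_by_index if s[dim_b] < s[dim_a])  # indices preferring dim_b inner
    if x == len(strides_by_index):  # heuristic 1: dim_a uniformly has the smaller stride
        return -1
    if y == len(strides_by_index):
        return 1
    # proposed change to heuristic 2: a clear majority of indices is enough
    if x - y > threshold:
        return -1
    if y - x > threshold:
        return 1
    return 0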

shunting314 added a commit that referenced this pull request Aug 9, 2023
ghstack-source-id: 0fbf66f10b6bb008779ec51b26313be25ea1a3e8
Pull Request resolved: #106827
shunting314 added a commit that referenced this pull request Aug 11, 2023
ghstack-source-id: 4e777e2ee69fec9f9004a47de714c8e1b5a588a5
Pull Request resolved: #106827
shunting314 (Contributor, Author) commented:

Here is the latest perf run.

  • TIMM: 1.65x -> 1.69x
  • HF: neutral
  • TB: 1.75x -> 1.74x

So overall I think this PR is helpful, especially on TIMM models.

shunting314 (Contributor, Author) commented:

@pytorchbot merge -f "keyboard interrupt and test failure on functorch are unrelated"

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

shunting314 (Contributor, Author) commented:

I just want to share more about why this PR helps with perf. Hopefully the deeper understanding can help me find future optimization opportunities.

Here is a micro-benchmark:

import torch
from torch._inductor import config
from torch._inductor.utils import do_bench

# Make Inductor emit a benchmark harness and give kernels readable names.
config.benchmark_kernel = True
config.triton.unique_kernel_names = True

def f(*inp):
    # Accumulate the transposed inputs onto inp[0], then do an outer reduction.
    x = inp[0]
    for a in inp[1:]:
        x = x + a.permute(0, 2, 1)

    return x.sum(dim=0)


K = 4
inp = [torch.randn(512, 512, 512).cuda() for _ in range(K)]
opt_f = torch.compile(f)
opt_f(*inp)  # warm up and trigger compilation
print(do_bench(lambda: opt_f(*inp)))

The kernel does an outer reduction. Inside the reduction loop, all but one of the memory accesses prefer dim-1 as the innermost pointwise loop. The PR just makes sure we pick dim-1 rather than dim-2 as the innermost pointwise loop in this case. This turns more memory accesses into coalesced accesses (with the help of tiling).
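
A quick eager-mode check of the strides shows why dim-1 wins here: every permuted input has stride 1 along dim-1, while only inp[0] has stride 1 along dim-2.

x = torch.randn(512, 512, 512)
print(x.stride())                   # (262144, 512, 1): stride 1 along dim-2
print(x.permute(0, 2, 1).stride())  # (262144, 1, 512): stride 1 along dim-1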

Metric and generated wrapper:

cc @jansel @Chillee @eellison

facebook-github-bot deleted the gh/shunting314/72/head branch August 15, 2023 14:17
jackiexu1992 added a commit to jackiexu1992/pytorch that referenced this pull request Aug 23, 2023
…torch#106827)"

Summary: D48295371 causes a batch fusion failure

Test Plan: Without revert, f469732293. With revert diff f472266199.

Differential Revision: D48593029

fbshipit-source-id: 1298c765954082bbfc8478678be6a03f950db781
jackiexu1992 added a commit to jackiexu1992/pytorch that referenced this pull request Aug 23, 2023
…torch#106827)"

Summary: D48295371 causes a batch fusion failure, which will block our mc proposals on all mc models.

Test Plan: Without revert, f469732293. With revert diff f472266199.

Reviewed By: yanboliang

Differential Revision: D48593029

fbshipit-source-id: 751a3f6b20e51b728044852a4a5fd3a376529cce
jackiexu1992 added a commit to jackiexu1992/pytorch that referenced this pull request Aug 23, 2023
…torch#106827)"

Summary:
D48295371 causes a batch fusion failure, which will block mc proposals on all mc models.
e.g. cmf f470938179

Test Plan: Without revert, f469732293. With revert diff f472266199.

Differential Revision: D48610062

fbshipit-source-id: 5f0a1fbd2b5fcf6f0d21a42712691ba51906d20e
pytorchmergebot pushed a commit that referenced this pull request Aug 23, 2023
…06827)" (#107796)


Pull Request resolved: #107796
Approved by: https://github.com/yanboliang
shunting314 added a commit that referenced this pull request Aug 24, 2023
…order"


This PR relands #106827, which got reverted because it caused a compilation error for some ads models.

Yanbo provided a repro in one of the 14k models (`pytest ./generated/test_KaiyangZhou_deep_person_reid.py -k test_044`). This is also the model I used to confirm the fix and come up with a unit test. In this model, we call `triton_heuristics.triton_config` with size_hints [2048, 2]. Previously this would result in a triton config with XBLOCK=2048 and YBLOCK=2. But since we changed the mapping between size_hints and the XYZ dimensions, we now generate a triton config with XBLOCK=2 and YBLOCK=2048. This fails compilation since we set the max YBLOCK to 1024.

My fix is to make sure we never generate a triton config that exceeds the maximum block size.


[ghstack-poisoned]
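
For illustration, here is a rough sketch of the clamping idea. This is not the actual torch._inductor triton_heuristics code; the per-dimension limits and the function name below are assumptions (only the YBLOCK max of 1024 is stated above):

# Assumed per-dimension block-size caps; only the Y cap of 1024 is stated above.
MAX_BLOCK = {"X": 2048, "Y": 1024, "Z": 1024}

def clamped_block_sizes(size_hints):
    # Map size hints onto X/Y/Z block sizes, never exceeding each dimension's cap.
    names = ["X", "Y", "Z"]
    return {
        f"{name}BLOCK": min(hint, MAX_BLOCK[name])
        for name, hint in zip(names, size_hints)
    }

# With the new size_hints-to-XYZ mapping, [2048, 2] would have asked for
# YBLOCK=2048; clamping keeps the config within the limit.
print(clamped_block_sizes([2, 2048]))  # {'XBLOCK': 2, 'YBLOCK': 1024}
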
shunting314 added a commit that referenced this pull request Aug 25, 2023
…order"


pytorchmergebot pushed a commit that referenced this pull request Aug 26, 2023
Pull Request resolved: #107902
Approved by: https://github.com/jansel
voznesenskym pushed a commit that referenced this pull request Aug 27, 2023