[inductor] Add prims._inductor_bucketize and add lowerings #104007

Conversation
Dr. CI: ✅ No failures as of commit 2cc9c94. See artifacts and rendered test results at hud.pytorch.org/pr/104007.
Force-pushed from 5c59c46 to b36bfbb.
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed from 0d1b96a to cb4be47.
torch/_inductor/codegen/common.py (Outdated)

Although this does indeed seem to correspond to the semantics of torch.bucketize, in the jagged-tensor context it seems we'll need to subtract 1 from the result (values falling into the first bucket should have index 0).
In some rough benchmarks I didn't see any measurable perf difference between subtracting the 1 and not having to do so (in fbgemm lowerings). But good point, we'll need to make sure not to forget this.
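To make the subtract-1 point concrete, a small runnable sketch (the offsets here are made-up example values, not from the PR):

```python
import torch

# Jagged-tensor style offsets: [0, 2, 5] delimits two rows, of lengths 2 and 3.
offsets = torch.tensor([0, 2, 5])
flat_idx = torch.arange(5)  # flat element indices 0..4

# torch.bucketize(right=True) counts boundaries <= idx, so the first row's
# elements get bucket index 1; subtracting 1 yields 0-based row ids.
rows = torch.bucketize(flat_idx, offsets, right=True) - 1
print(rows)  # tensor([0, 0, 1, 1, 1])
```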
Force-pushed from 95d13d7 to 8abfc45.
torch.bucketize takes a tensor of values and a "boundaries" tensor, which is a sorted list of values that define buckets. It returns the bucket that each value lies in. E.g. if values = [1, 5, 3, 6] and boundaries = [0, 2, 4, 6, 8], the output will be [1, 3, 2, 3] (with the default right=False, the value 6 equals the boundary 6 and therefore falls in bucket 3).

The current decomposition of this op doesn't work well with dynamic shapes: it performs a binary search, which bakes the number of binary-search iterations into the compiled code and requires recompiling (I don't completely understand why/where this happens). I'm not sure whether there's a good way to write a decomposition for this op that will work with dynamic shapes.

Use case: this op is very similar to some operations needed by jagged tensors. As a first step, I want to add a lowering for aten.bucketize and make use of OpInfos. #104007

Pull Request resolved: #104396
Approved by: https://github.com/Chillee
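A minimal sketch of a binary-search style decomposition in plain PyTorch ops (an illustration, not the actual decomposition in core) showing where the problem comes from: the loop trip count depends on boundaries.numel(), so it gets fixed at trace time, and a different boundaries size means a recompile:

```python
import torch

def bucketize_decomp_sketch(values: torch.Tensor, boundaries: torch.Tensor) -> torch.Tensor:
    # Sketch only; assumes boundaries is non-empty and sorted.
    n = boundaries.numel()
    low = torch.zeros_like(values, dtype=torch.int64)
    high = torch.full_like(values, n, dtype=torch.int64)
    # ~log2(n) iterations; this count gets baked into the traced graph.
    for _ in range(max(n, 1).bit_length()):
        mid = ((low + high) // 2).clamp(max=n - 1)
        go_right = boundaries[mid] < values  # per-element comparison
        low = torch.where(go_right, mid + 1, low)
        high = torch.where(go_right, high, mid)
    return low  # count of boundaries strictly below each value (right=False)

values = torch.tensor([1, 5, 3, 6])
boundaries = torch.tensor([0, 2, 4, 6, 8])
assert torch.equal(bucketize_decomp_sketch(values, boundaries),
                   torch.bucketize(values, boundaries))  # tensor([1, 3, 2, 3])
```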
Force-pushed from 4d9fc32 to 1b993ec.
torch/_inductor/lowering.py (Outdated)

```python
):
    assert len(boundaries.get_size()) == 1

    if input.get_device().type != "cuda" or boundaries.get_device().type != "cuda":
```

Use the is_triton() helper (which does the same thing).
torch/_inductor/lowering.py (Outdated)

```python
    input_loader = input.make_loader()

    index_dtype = torch.int32 if out_int32 else torch.int64
    triton_dtype = "tl.int32" if out_int32 else "tl.int64"
```

Let's not expose the triton_dtype in the device-agnostic IR. This will make it harder to extend to non-triton backends.
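A minimal sketch of the suggested direction, with assumed (hypothetical) helper names: the device-agnostic IR carries a torch.dtype, and only the Triton codegen backend translates it into a type string:

```python
import torch

# Hypothetical mapping that would live in the Triton codegen backend only.
_TORCH_TO_TRITON_DTYPE = {
    torch.int32: "tl.int32",
    torch.int64: "tl.int64",
}

def triton_type_str(dtype: torch.dtype) -> str:
    return _TORCH_TO_TRITON_DTYPE[dtype]

# The lowering stays device-agnostic:
out_int32 = True
index_dtype = torch.int32 if out_int32 else torch.int64  # passed through the IR
print(triton_type_str(index_dtype))  # "tl.int32", resolved only at codegen time
```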
torch/_inductor/lowering.py
Outdated
boundaries.get_name(), | ||
ops.index_expr(boundaries_size, index_dtype), | ||
triton_dtype, | ||
not right, # ops.bucketize and torch.bucketize have opposite semantics for "right" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My initial feelings here are we should match the torch semantics rather than the numpy one. That will at least make our codebase self-consistent. I don't feel that strongly about that though.
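To make the two conventions concrete (plain torch/numpy calls, nothing PR-specific):

```python
import torch
import numpy as np

values = [1, 2, 3]
boundaries = [1, 3]

# torch: right=True sends a value equal to a boundary into the bucket to its right.
print(torch.bucketize(torch.tensor(values), torch.tensor(boundaries), right=True))
# tensor([1, 1, 2])

# numpy: right=True closes the *right* edge of each bin instead.
print(np.digitize(values, boundaries, right=True))
# [0 1 1]

# i.e. torch's right=False matches numpy's right=True:
print(torch.bucketize(torch.tensor(values), torch.tensor(boundaries)))
# tensor([0, 1, 1])
```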
1. Use is_triton().
2. Pass a torch.dtype instead of a triton dtype into the bucketize inductor op.
3. Switch the behavior of "right" back to torch semantics.
@pytorchbot merge

Merge started: the change will be merged once all checks pass (ETA 0-4 hours).
… where num_elements_per_warp=32" In binary search triton implementations, (#104007) num_elements_per_warp=32 performs a lot better than larger values. This PR adds an autotuning config option for this purpose. I benchmarked #104007 with and without this change on a 16MB pointwise kernel. This change reduces the latency from 1ms to 0.35ms. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 [ghstack-poisoned]
…nts_per_warp=32" In binary search triton implementations, (#104007) num_elements_per_warp=32 performs a lot better than larger values. This PR adds an autotuning config option for this purpose. I benchmarked #104007 with and without this change on a 16MB pointwise kernel. This change reduces the latency from 1ms to 0.35ms. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 [ghstack-poisoned]
… try using config with num_elements_per_warp=32" In binary search triton implementations, (#104007) num_elements_per_warp=32 performs a lot better than larger values. This PR adds an autotuning config option for this purpose. But since autotuning can affect compile times and this config isn't generally useful, we only try this config if bucketize is present. This is done by adding an extra field to triton_meta which is used by the pointwise autotuning Performance: reused https://gist.github.com/davidberard98/066fd2115f59f5889ef61e4527d1eba5. Before: ``` Eager 0.30088499188423157 ms PT2 0.9296960234642029 ms ``` After: ``` Eager 0.3011910021305084 ms PT2 0.22977299988269806 ms ``` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 [ghstack-poisoned]
…g with num_elements_per_warp=32" In binary search triton implementations, (#104007) num_elements_per_warp=32 performs a lot better than larger values. This PR adds an autotuning config option for this purpose. But since autotuning can affect compile times and this config isn't generally useful, we only try this config if bucketize is present. This is done by adding an extra field to triton_meta which is used by the pointwise autotuning Performance: reused https://gist.github.com/davidberard98/066fd2115f59f5889ef61e4527d1eba5. Before: ``` Eager 0.30088499188423157 ms PT2 0.9296960234642029 ms ``` After: ``` Eager 0.3011910021305084 ms PT2 0.22977299988269806 ms ``` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 [ghstack-poisoned]
… try using config with num_elements_per_warp=32" In binary search triton implementations, (#104007) num_elements_per_warp=32 performs a lot better than larger values. This PR adds an autotuning config option for this purpose. But since autotuning can affect compile times and this config isn't generally useful, we only try this config if bucketize is present. This is done by adding an extra field to triton_meta which is used by the pointwise autotuning Performance: reused https://gist.github.com/davidberard98/066fd2115f59f5889ef61e4527d1eba5. Before: ``` Eager 0.30088499188423157 ms PT2 0.9296960234642029 ms ``` After: ``` Eager 0.3011910021305084 ms PT2 0.22977299988269806 ms ``` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 Differential Revision: [D47237103](https://our.internmc.facebook.com/intern/diff/D47237103) [ghstack-poisoned]
…g with num_elements_per_warp=32" In binary search triton implementations, (#104007) num_elements_per_warp=32 performs a lot better than larger values. This PR adds an autotuning config option for this purpose. But since autotuning can affect compile times and this config isn't generally useful, we only try this config if bucketize is present. This is done by adding an extra field to triton_meta which is used by the pointwise autotuning Performance: reused https://gist.github.com/davidberard98/066fd2115f59f5889ef61e4527d1eba5. Before: ``` Eager 0.30088499188423157 ms PT2 0.9296960234642029 ms ``` After: ``` Eager 0.3011910021305084 ms PT2 0.22977299988269806 ms ``` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 Differential Revision: [D47237103](https://our.internmc.facebook.com/intern/diff/D47237103) [ghstack-poisoned]
… try using config with num_elements_per_warp=32 (#104456)

In binary search triton implementations (#104007), num_elements_per_warp=32 performs a lot better than larger values. This PR adds an autotuning config option for this purpose. But since autotuning can affect compile times and this config isn't generally useful, we only try this config if bucketize is present. This is done by adding an extra field to triton_meta which is used by the pointwise autotuning.

Performance: reused https://gist.github.com/davidberard98/066fd2115f59f5889ef61e4527d1eba5.

Before:
```
Eager 0.30088499188423157 ms
PT2   0.9296960234642029 ms
```

After:
```
Eager 0.3011910021305084 ms
PT2   0.22977299988269806 ms
```

Differential Revision: [D47237103](https://our.internmc.facebook.com/intern/diff/D47237103)
Pull Request resolved: #104456
Approved by: https://github.com/eellison
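For illustration, assuming the usual relationship between block size and warp count, the extra config tried for bucketize kernels differs from a typical pointwise config roughly like this (the 256-elements-per-warp default is an assumption for the sketch, not taken from the PR):

```python
import triton

XBLOCK = 1024
# Typical pointwise config: ~256 elements per warp -> 4 warps for this block.
default_cfg = triton.Config({"XBLOCK": XBLOCK}, num_warps=max(XBLOCK // 256, 1))
# Extra config tried when a bucketize op is present: 32 elements per warp
# -> 32 warps here, so each lane's binary search covers far fewer elements.
bucketize_cfg = triton.Config({"XBLOCK": XBLOCK}, num_warps=min(XBLOCK // 32, 32))
```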
TL;DR: This PR is a first step in adding lowerings for torch.bucketize. It adds an initial lowering for this op - but because this implementation is not currently efficient, it registers the lowering for prims._inductor_bucketize. After we make the implementation more efficient, we'll remove prims._inductor_bucketize and add the lowering directly to torch.bucketize.
Background - torch.bucketize: torch.bucketize(values, boundaries, right=False): for an arbitrary tensor of values and a non-decreasing 1D tensor of boundaries that define buckets, it returns the index of the bucket that each value falls in. E.g. for values [0, 1, 2, 3, 4] and boundaries [1, 3], it will return [0, 0, 1, 1, 2].
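As runnable calls (the right=True line is included here to set up the kwarg discussion below):

```python
import torch

values = torch.tensor([0, 1, 2, 3, 4])
boundaries = torch.tensor([1, 3])
print(torch.bucketize(values, boundaries))              # tensor([0, 0, 1, 1, 2])
print(torch.bucketize(values, boundaries, right=True))  # tensor([0, 1, 1, 2, 2])
```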
Implementation: This PR adds a new inductor op called "bucketize". In this PR it only has a triton implementation; for CPU it is a fallback. The triton implementation uses a binary search in triton_helpers.py. This PR also adds a new prim, prims._inductor_bucketize(), for testing purposes, and adds a lowering for this op.

"right": The current behavior of the "right" kwarg in the inductor op is the opposite of the behavior of the torch op. "right" controls how the op treats a value that is equal to one of the boundary values. In the torch op, right=True means "if a value is equal to a boundary value, put it in the bucket to the right". In the inductor op, right=True means "the right boundary of a bucket is closed". These are opposite. I'm open to switching the behavior of the inductor op, but I chose to implement it this way because I think it makes more sense, and I think the torch.bucketize behavior may have been a mistake; it's the opposite of numpy.digitize (see "bucketize() is wrong, and is contradicting itself", #91580). Update: switched the behavior of the inductor bucketize op to match the torch op.
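For reference, a sketch of the vectorized binary search in the spirit of the helper this PR adds to triton_helpers.py (the name, signature, and the handling of right/indexing dtype here are simplified assumptions, not the exact helper):

```python
import triton
import triton.language as tl

@triton.jit
def bucketize_binary_search_sketch(
    values,                          # block of values to bucketize
    boundaries_ptr,                  # the sorted 1D boundaries tensor
    BOUNDARIES_SIZE: tl.constexpr,
    BLOCK_SHAPE: tl.constexpr,
):
    # Each lane binary-searches the shared boundaries array for its own value;
    # `low` converges to the number of boundaries below the value, i.e. the
    # torch.bucketize(right=False) result.
    low = tl.zeros(BLOCK_SHAPE, dtype=tl.int32)
    high = tl.full(BLOCK_SHAPE, BOUNDARIES_SIZE, dtype=tl.int32)

    # BOUNDARIES_SIZE is a compile-time constant, so this Python-level loop
    # unrolls at compile time: the trip count is baked into the kernel.
    full_range = BOUNDARIES_SIZE + 1
    while full_range > 1:
        mid = (high + low) // 2
        mask = mid < BOUNDARIES_SIZE
        boundary = tl.load(boundaries_ptr + mid, mask=mask)
        is_above = (values > boundary) & mask  # use >= for closed right edges
        low = tl.where(is_above, mid + 1, low)
        high = tl.where(is_above, high, mid)
        full_range = (full_range + 1) // 2

    return low
```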
Performance: benchmark script with "values" as a [16, 1024, 1024] float32 tensor and "boundaries" as a [1025] tensor (i.e. defining 1024 buckets).

As is, PT2 is slower than eager; per the #104456 measurements of this same benchmark above, roughly 0.93 ms for PT2 vs. 0.30 ms for eager.

But performance improves significantly if we add an additional pointwise autotuning config (WIP in #104456): PT2 drops to roughly 0.23 ms, faster than eager.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78