
Conversation

@davidberard98 (Contributor) commented on Jun 21, 2023

TL;DR: This PR is a first step in adding lowerings for torch.bucketize. It adds an initial lowering for this op - but because this implementation is not currently efficient, it registers the lowering for prims._inductor_bucketize. After we make the implementation more efficient, we'll remove prims._inductor_bucketize and add the lowering directly to torch.bucketize.

Background - torch.bucketize: torch.bucketize(values, boundaries, right=False): for an arbitrary tensor of values and a non-decreasing 1D tensor of boundaries that defines the buckets, it returns the index of the bucket that each value falls in. E.g. for values [0, 1, 2, 3, 4] and boundaries [1, 3], it returns [0, 0, 1, 1, 2].
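A runnable version of that example, in eager PyTorch, for reference:

```python
import torch

values = torch.tensor([0, 1, 2, 3, 4])
boundaries = torch.tensor([1, 3])  # two boundaries -> three buckets
print(torch.bucketize(values, boundaries))  # tensor([0, 0, 1, 1, 2])
```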

Implementation: This PR adds a new inductor op called "bucketize". In this PR it only has a triton implementation; on CPU it falls back to the eager op. The triton implementation uses a binary search in triton_helpers.py. This PR also adds a new prim, _inductor_bucketize(), for testing purposes, and adds a lowering for it.
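For readers who want the shape of that helper, below is a minimal sketch of a lockstep, branchless binary search as one might write it in Triton. The names, signature, and details here are assumptions for illustration; the PR's actual helper lives in triton_helpers.py.

```python
import triton
import triton.language as tl

# Sketch only (names and details assumed): the search-window size is a
# scalar, so every lane runs the same ceil(log2(size)) iterations in
# lockstep, with no divergent branches.
@triton.jit
def bucketize_binary_search(values, boundaries_ptr, boundaries_size, right: tl.constexpr):
    # Assumes boundaries_size >= 1 and a non-decreasing boundaries tensor.
    base = tl.zeros(values.shape, dtype=tl.int32)
    size = boundaries_size
    while size > 1:
        half = size // 2
        # Gather one boundary per lane; base + half stays in-bounds.
        bound = tl.load(boundaries_ptr + base + half)
        if right:
            pred = bound <= values  # value equal to a boundary goes right
        else:
            pred = bound < values
        base = tl.where(pred, base + half, base)
        size = size - half
    last = tl.load(boundaries_ptr + base)
    if right:
        pred = last <= values
    else:
        pred = last < values
    return base + pred.to(tl.int32)
```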

"right": The current behavior of the "right" kwarg in the inductor op is the opposite of the behavior of the torch op. "right" controls how the op treats a value that is equal to one of the boundary values. In the torch op, "right=True" means "if a value is equal to a boundary value, then put it in the bucket to the right". In the inductor op, "right=True" means "the right boundary of a bucket is closed". These are opposite. I'm open to switching the behavior of the inductor op - but I chose to implement this way because I think it makes more sense, and I think the torch.bucketize behavior may have been a mistake (it's the opposite of numpy.digitize). Switched the behavior of the inductor bucketize op to match the torch op

Performance: Benchmark script: "values" as a [16, 1024, 1024] float32 tensor and "boundaries" as a [1025] tensor (i.e. defining 1024 buckets).
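The benchmark script isn't inlined in this description; a minimal reconstruction under those shapes might look like the following. The timing method and iteration counts are assumptions, and torch.bucketize stands in for prims._inductor_bucketize for readability.

```python
import torch

values = torch.randn(16, 1024, 1024, device="cuda")
boundaries = torch.sort(torch.randn(1025, device="cuda")).values  # 1024 buckets

def fn(v, b):
    return torch.bucketize(v, b)

compiled = torch.compile(fn)

def bench_ms(f, *args, iters=100):
    for _ in range(10):  # warm-up (includes compilation for PT2)
        f(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        f(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print("Eager", bench_ms(fn, values, boundaries), "ms")
print("PT2  ", bench_ms(compiled, values, boundaries), "ms")
```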

As is:

Eager 0.30117499828338623 ms
PT2   0.9298200011253357 ms

But performance improves significantly if we add an additional pointwise autotuning config (WIP in #104456):

Eager 0.3015420138835907 ms
PT2   0.23028500378131866 ms
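For context, num_elements_per_warp is the block size divided by num_warps, so a config hitting 32 means one element per thread of a 32-thread warp. Purely illustrative, not inductor's actual autotuning code:

```python
import triton

# XBLOCK=256 with 8 warps -> 256 / 8 = 32 elements per warp.
config_32_per_warp = triton.Config({"XBLOCK": 256}, num_warps=8)
```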

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78

@pytorch-bot bot commented on Jun 21, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104007

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2cc9c94:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot commented:
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@davidberard98 davidberard98 force-pushed the bucket_index branch 6 times, most recently from 0d1b96a to cb4be47, on June 29, 2023 00:25
Comment on lines 858 to 859
Although this does indeed seem to correspond to the semantics of torch.bucketize, in the jagged tensor context it seems we'll need to subtract 1 from the result (as the values falling into the first bucket should have index 0).

@davidberard98 (Author) replied on Jun 30, 2023:
In some rough benchmarks I didn't see any measurable perf difference between subtracting the 1 and not having to do it (in fbgemm lowerings). But good point; we'll need to make sure not to forget this.
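A hedged illustration of the point being discussed: mapping flat indices of a jagged tensor back to their row via bucketize over the offsets needs the subtract-1, since indices in the first segment should map to row 0.

```python
import torch

offsets = torch.tensor([0, 3, 5, 9])  # 3 jagged rows of lengths 3, 2, 4
flat_idx = torch.arange(9)
rows = torch.bucketize(flat_idx, offsets, right=True) - 1
print(rows)  # tensor([0, 0, 0, 1, 1, 2, 2, 2, 2])
```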


@davidberard98 davidberard98 changed the title [WIP][inductor] add prims.inductor_bucket_index and inductor lowerings [WIP][inductor] Add prims._inductor_bucketize and add lowerings Jun 30, 2023

pytorchmergebot pushed a commit that referenced this pull request Jun 30, 2023
torch.bucketize takes a tensor of values and a "boundaries" tensor, which is a sorted list of values that represents buckets. It returns the bucket that each value lies in. E.g. if values = [1, 5, 3, 6] and boundaries = [0, 2, 4, 6, 8], the output will be [1, 3, 2, 3] (with the default right=False, the value 6, which equals a boundary, falls in the bucket to its left).

The current decomposition of this op doesn't work well with dynamic shapes. It performs a binary search, which bakes the number of binary-search iterations into the graph and requires recompiling (I don't completely understand why/where this happens). I'm not sure whether there's a good way to write a decomposition for this op that works with dynamic shapes.
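(To make the recompile issue concrete, here is a sketch in the same spirit; this is hypothetical, not the actual decomposition. The Python while loop runs ceil(log2(n)) times for n boundaries and gets unrolled at trace time, so a boundaries tensor of a different length produces a different graph.)

```python
import torch

def bucketize_decomp(values, boundaries, right=False):
    # Branchless binary search; assumes at least one boundary. `size` is a
    # Python int, so the loop trip count is baked in when traced.
    base = torch.zeros_like(values, dtype=torch.int64)
    size = boundaries.numel()
    while size > 1:
        half = size // 2
        bound = boundaries[base + half]
        pred = bound <= values if right else bound < values
        base = torch.where(pred, base + half, base)
        size -= half
    last = boundaries[base]
    pred = last <= values if right else last < values
    return base + pred.to(torch.int64)
```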

Use case: this op is very similar to some operations needed by jagged tensors. As a first step, I want to add a lowering for aten.bucketize and make use of opinfos. #104007
Pull Request resolved: #104396
Approved by: https://github.com/Chillee

@davidberard98 davidberard98 requested a review from eellison June 30, 2023 05:10
@davidberard98 davidberard98 marked this pull request as ready for review June 30, 2023 05:10
@davidberard98 davidberard98 changed the title [WIP][inductor] Add prims._inductor_bucketize and add lowerings [inductor] Add prims._inductor_bucketize and add lowerings Jun 30, 2023
```python
):
    assert len(boundaries.get_size()) == 1

    if input.get_device().type != "cuda" or boundaries.get_device().type != "cuda":
```
Use the is_triton() helper (which does the same thing)

```python
input_loader = input.make_loader()

index_dtype = torch.int32 if out_int32 else torch.int64
triton_dtype = "tl.int32" if out_int32 else "tl.int64"
```
Let's not expose the triton_dtype in the device-agnostic IR. This will make it harder to extend to non-triton backends.
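A hypothetical sketch of that suggestion: the IR carries a torch.dtype, and only the Triton codegen maps it to a backend string (inductor has a similar helper, but the table below is illustrative):

```python
import torch

# Codegen-side mapping; keeps "tl.*" strings out of the device-agnostic IR.
_TORCH_TO_TRITON = {torch.int32: "tl.int32", torch.int64: "tl.int64"}

def triton_type(dtype: torch.dtype) -> str:
    return _TORCH_TO_TRITON[dtype]
```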

```python
boundaries.get_name(),
ops.index_expr(boundaries_size, index_dtype),
triton_dtype,
not right,  # ops.bucketize and torch.bucketize have opposite semantics for "right"
```
My initial feeling here is that we should match the torch semantics rather than the numpy ones; that will at least make our codebase self-consistent. I don't feel that strongly about it, though.

1. Use is_triton().
2. Pass a torch.dtype instead of a triton dtype into the bucketize inductor op.
3. Switch the behavior of "right" back to torch semantics.

@davidberard98 davidberard98 requested a review from jansel June 30, 2023 20:05

@davidberard98 (Author) commented:

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Jul 3, 2023
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Jul 7, 2023
…elements_per_warp=32 (#104456)

In binary-search triton implementations (#104007), num_elements_per_warp=32 performs a lot better than larger values.

This PR adds an autotuning config option for this purpose. But since autotuning can affect compile times and this config isn't generally useful, we only try this config if bucketize is present. This is done by adding an extra field to triton_meta which is used by the pointwise autotuning.

Performance: reused https://gist.github.com/davidberard98/066fd2115f59f5889ef61e4527d1eba5.

Before:
```
Eager 0.30088499188423157 ms
PT2   0.9296960234642029 ms
```

After:
```
Eager 0.3011910021305084 ms
PT2   0.22977299988269806 ms
```

Differential Revision: [D47237103](https://our.internmc.facebook.com/intern/diff/D47237103)
Pull Request resolved: #104456
Approved by: https://github.com/eellison