[inductor] Generalize pointless_cumsum_replacement pattern #108373


Closed
wants to merge 1 commit

Conversation

peterbell10
Collaborator

@peterbell10 peterbell10 commented Aug 31, 2023

Stack from ghstack (oldest at bottom):

The current pattern transforms:
```
ones([x, y]).cumsum(1) -> arange(1, 1 + y).expand([x, y])
```
This PR generalizes it to:
```
full(shape, fill_value).cumsum(d) ->
    (fill_value * arange(1, 1 + shape[d])).view([1..., shape[d], 1...]).expand(shape)
```

So we now handle any fill value, any number of dimensions, and broadcasting along any dimension.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov
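The rewrite above relies on the fact that a cumsum over a constant tensor is an arithmetic ramp along the scanned dimension, broadcast across all other dimensions. Here is a minimal NumPy sketch of that equivalence (the function name `generalized_pattern` is hypothetical, chosen for illustration; it is not part of Inductor):

```python
import numpy as np

def generalized_pattern(shape, fill_value, d):
    # Replacement expression: fill_value * arange(1, 1 + shape[d]),
    # reshaped so that shape[d] sits at dimension d (all other dims
    # are size 1), then broadcast back up to the full shape.
    view_shape = [1] * len(shape)
    view_shape[d] = shape[d]
    ramp = fill_value * np.arange(1, 1 + shape[d])
    return np.broadcast_to(ramp.reshape(view_shape), shape)

# The expression the pattern replaces: cumsum over a constant tensor.
expected = np.full((2, 3, 4), 5.0).cumsum(axis=1)
result = generalized_pattern((2, 3, 4), 5.0, 1)
assert np.array_equal(result, expected)
```

The replacement avoids materializing and scanning the constant tensor entirely: only a 1-D ramp of length `shape[d]` is computed, and the `expand` is a zero-copy broadcast.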


pytorch-bot bot commented Aug 31, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108373

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 948fddb with merge base 0e4752b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

peterbell10 added a commit that referenced this pull request Aug 31, 2023

ghstack-source-id: 0a4bc2a
Pull Request resolved: #108373
@peterbell10 peterbell10 marked this pull request as ready for review September 1, 2023 13:48
@peterbell10 peterbell10 requested a review from lezcano September 1, 2023 13:48
@peterbell10 peterbell10 added the topic: not user facing topic category label Sep 1, 2023
@peterbell10
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 1, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@facebook-github-bot facebook-github-bot deleted the gh/peterbell10/607/head branch September 5, 2023 14:22
4 participants