
Don't fall back for aten.index when there's a None in the middle of the indexing. #110711

Closed
Chillee opened this issue Oct 6, 2023 · 0 comments
Labels: module: inductor, oncall: pt2, triaged


Chillee commented Oct 6, 2023

🚀 The feature, motivation and pitch

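```python
import torch
torch.set_default_device('cuda')

torch.manual_seed(420)

x = torch.arange(3*4*5).view(3, 4, 5)

def f(x):
    # The `:` between the two tensor indices reaches aten.index as a None,
    # i.e. indices = [tensor([1, 2]), None, tensor([2, 3])]
    return x[torch.tensor([1, 2]), :, torch.tensor([2, 3])]

print(f(x))
print(torch.compile(f)(x))
>>>
tensor([[22, 27, 32, 37],
        [43, 48, 53, 58]], device='cuda:0')
```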

Inductor currently lowers this to a fallback kernel because of this check (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/lowering.py#L2567). Let's handle that case in the lowering instead of falling back.
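To pin down which case the title refers to, here is a hypothetical predicate over the indices list that aten.index receives (a `:` arrives there as None). It is only an illustrative sketch. It is not the check linked above, whose exact condition isn't reproduced here, and `has_none_in_middle` is a made-up name.

```python
from typing import Optional, Sequence


def has_none_in_middle(indices: Sequence[Optional[object]]) -> bool:
    """Hypothetical predicate: True when a None sits strictly between two
    non-None (tensor) indices, as in x[idx0, :, idx2] -> [idx0, None, idx2]."""
    tensor_pos = [i for i, idx in enumerate(indices) if idx is not None]
    if len(tensor_pos) < 2:
        return False
    first, last = tensor_pos[0], tensor_pos[-1]
    # Any None strictly between the first and last tensor index?
    return any(indices[i] is None for i in range(first + 1, last))


# Strings stand in for index tensors; only None vs. not-None matters here.
print(has_none_in_middle(["idx0", None, "idx2"]))   # True  (the case in this issue)
print(has_none_in_middle([None, "idx1", "idx2"]))   # False (leading None only)
print(has_none_in_middle(["idx0", "idx1", None]))   # False (trailing None only)
```

Leading and trailing Nones keep the broadcast index dimensions in place. A None in the middle is the case where they move to the front of the result.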

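In aten.index, the `:` shows up as a None entry in the indices list. A None there has the same semantics as a torch.arange over that dimension, with the other indices broadcast appropriately (here, unsqueezed so that the arange occupies the trailing dimension of the result):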

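```python
def f(x):
    # Replace the `:` with torch.arange(4) and unsqueeze the other two indices
    # to shape (2, 1), so all three broadcast to the (2, 4) result shape.
    return x[torch.tensor([1, 2]).unsqueeze(-1), torch.arange(4), torch.tensor([2, 3]).unsqueeze(-1)]
```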
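The same rewrite generalizes to any mix of tensor indices and Nones. Below is a minimal CPU sketch, assuming integer (not boolean) index tensors. `index_via_aranges` is a hypothetical helper written for illustration, not Inductor's lowering, which works with index expressions rather than materialized aranges. When the tensor indices are separated by a None, their broadcast dimensions move to the front of the result and every None dimension follows. Each tensor index therefore gets trailing singleton dims, and each None becomes an arange shaped to land in its output position. When the tensor indices are adjacent, the broadcast block stays between the leading and trailing Nones.

```python
from typing import List, Optional, Sequence

import torch


def index_via_aranges(x: torch.Tensor,
                      indices: Sequence[Optional[torch.Tensor]]) -> torch.Tensor:
    """Hypothetical helper: evaluate aten.index-style indexing (None = whole
    dimension) using only tensor indices, by turning each None into an arange.
    Assumes integer index tensors on x's device."""
    # Dimensions not covered by `indices` behave like trailing Nones.
    idx: List[Optional[torch.Tensor]] = list(indices) + [None] * (x.dim() - len(indices))
    tensor_pos = [d for d, t in enumerate(idx) if t is not None]
    none_pos = [d for d, t in enumerate(idx) if t is None]
    assert tensor_pos, "expected at least one tensor index"

    # Shape that all tensor indices broadcast to.
    bshape = torch.broadcast_shapes(*(idx[d].shape for d in tensor_pos))
    adjacent = tensor_pos == list(range(tensor_pos[0], tensor_pos[-1] + 1))
    if adjacent:
        # Result layout: [Nones before] + bshape + [Nones after].
        pre = [d for d in none_pos if d < tensor_pos[0]]
        post = [d for d in none_pos if d > tensor_pos[-1]]
    else:
        # None in the middle: bshape moves to the front, all Nones follow.
        pre, post = [], none_pos

    new_idx: List[torch.Tensor] = []
    for d, t in enumerate(idx):
        if t is not None:
            # Trailing singletons keep this index aligned with the bshape block.
            new_idx.append(t.reshape(tuple(t.shape) + (1,) * len(post)))
            continue
        if d in pre:
            # Dims to its right: later leading Nones, bshape, all trailing Nones.
            trailing = len(pre) - 1 - pre.index(d) + len(bshape) + len(post)
        else:
            # Dims to its right: only the later trailing Nones.
            trailing = len(post) - 1 - post.index(d)
        ar = torch.arange(x.shape[d], device=x.device)
        new_idx.append(ar.reshape((x.shape[d],) + (1,) * trailing))
    return x[tuple(new_idx)]


x = torch.arange(3 * 4 * 5).view(3, 4, 5)
i0, i2 = torch.tensor([1, 2]), torch.tensor([2, 3])
# Same values as x[i0, :, i2] in the example above, computed on CPU.
print(index_via_aranges(x, [i0, None, i2]))
```

For `[i0, None, i2]` this produces exactly the rewrite above: both tensor indices become shape (2, 1) and the None becomes `torch.arange(4)`.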

Alternatives

No response

Additional context

No response

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

oulgen self-assigned this Oct 6, 2023
ezyang added the triaged, oncall: pt2, and module: inductor labels Oct 6, 2023
oulgen added a commit that referenced this issue Oct 11, 2023
…r aten.index"

Fixes #110711

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
oulgen added a commit that referenced this issue Oct 11, 2023
Fixes #110711

ghstack-source-id: fa80c24abde1cdec4c126de71f8885120fb81fcf
Pull Request resolved: #111015