
unsqueeze and expand metas produce wrong strides #90260

Closed
ngimel opened this issue Dec 6, 2022 · 1 comment
Labels
high priority, module: correctness (silent), module: primTorch, module: pt2-dispatcher, oncall: pt2, triaged

Comments

@ngimel
Collaborator

ngimel commented Dec 6, 2022

```
In [2]: a=torch.randn(1,10)

In [3]: torch._refs.unsqueeze(a, 1).stride()
Out[3]: (10, 10, 1)

In [4]: a=torch.randn(1,10, device="meta")

In [5]: torch._refs.unsqueeze(a, 1).stride()
Out[5]: (10, 0, 1)
```
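For comparison, eager `Tensor.unsqueeze` gives the inserted size-1 dimension the stride `size[dim] * stride[dim]` of the dimension it is inserted in front of, or 1 when it is appended at the end; this reproduces the non-meta result `(10, 10, 1)` above. A minimal pure-Python sketch of that rule, assuming it mirrors ATen's `inferUnsqueezeGeometry`; the helper `unsqueeze_strides` is hypothetical and not a PyTorch API:

```python
def unsqueeze_strides(sizes, strides, dim):
    """Hypothetical helper: sizes/strides after inserting a size-1 dim at `dim`,
    following the eager rule (assumed to match ATen's inferUnsqueezeGeometry)."""
    ndim = len(sizes)
    if dim < 0:
        dim += ndim + 1  # unsqueeze accepts dims in [-ndim - 1, ndim]
    if not 0 <= dim <= ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-d tensor")
    # New stride: contiguous with the dimension currently at `dim`,
    # or 1 when appending a trailing dimension.
    new_stride = 1 if dim == ndim else sizes[dim] * strides[dim]
    new_sizes = tuple(sizes[:dim]) + (1,) + tuple(sizes[dim:])
    new_strides = tuple(strides[:dim]) + (new_stride,) + tuple(strides[dim:])
    return new_sizes, new_strides


# The case from the report: a contiguous (1, 10) tensor unsqueezed at dim 1.
print(unsqueeze_strides((1, 10), (10, 1), 1))  # → ((1, 1, 10), (10, 10, 1))
```

The meta result `(10, 0, 1)` differs only in the stride of a size-1 dimension, which is always indexed at 0, so it addresses the same memory; it just no longer matches eager.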

The same happens with `expand`:

```
In [10]: a.expand(1, 1, 1, 10).stride()
Out[10]: (10, 10, 10, 1)

In [11]: torch._refs.expand(a, 1, 1, 1, 10).stride()
Out[11]: (10, 10, 10, 1)

In [12]: a=torch.randn(1, 1, 10, device="meta")

In [13]: torch._refs.expand(a, 1, 1, 1, 10).stride()
Out[13]: (0, 10, 10, 1)
```
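Eager `expand` follows a similar convention: a dimension whose size already matches the target keeps its stride, a size-1 dimension broadcast to a larger size gets stride 0, and a new leading dimension of size 1 gets `size * stride` of the dimension to its right. A minimal pure-Python sketch of that rule, assuming it matches ATen's `inferExpandGeometry` for inputs with at least one dimension (0-dim inputs are not modeled here); `expand_strides` is a hypothetical helper, not a PyTorch API:

```python
def expand_strides(sizes, strides, target):
    """Hypothetical helper: sizes/strides of `expand(*target)` on a tensor with the
    given sizes/strides, following the eager rule (assumed to match ATen's
    inferExpandGeometry). Only inputs with at least one dimension are handled."""
    tensor_dim, ndim = len(sizes), len(target)
    if tensor_dim == 0:
        raise NotImplementedError("0-dim inputs are not modeled in this sketch")
    if ndim < tensor_dim:
        raise ValueError("target must have at least as many dims as the input")
    out_sizes = [0] * ndim
    out_strides = [0] * ndim
    # Walk from the innermost dimension outward, aligning dims from the right.
    for i in range(ndim - 1, -1, -1):
        dim = tensor_dim - 1 - (ndim - 1 - i)
        if dim >= 0:
            size, stride = sizes[dim], strides[dim]
        else:
            # New leading dim: contiguous with the dim to its right
            # (i < ndim - 1 here because tensor_dim >= 1).
            size, stride = 1, out_sizes[i + 1] * out_strides[i + 1]
        target_size = target[i]
        if target_size == -1:
            if dim < 0:
                raise ValueError("-1 is not allowed for a new leading dimension")
            target_size = size
        if size != target_size:
            if size != 1:
                raise ValueError(f"cannot expand size {size} to {target_size} at dim {i}")
            size, stride = target_size, 0  # broadcast: every index reuses one element
        out_sizes[i], out_strides[i] = size, stride
    return tuple(out_sizes), tuple(out_strides)


# The case from the report: a contiguous (1, 1, 10) tensor expanded to (1, 1, 1, 10).
print(expand_strides((1, 1, 10), (10, 10, 1), (1, 1, 1, 10)))
# → ((1, 1, 1, 10), (10, 10, 10, 1))
```

Compared with this rule, the meta result in `Out[13]` gives the new size-1 leading dimension stride 0 instead of 10, which is the same kind of mismatch as in the `unsqueeze` case.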

Also, `expand_dims` should not be in `torch/_prims`; it's not a prim.

cc @ezyang @gchanan @zou3519 @kadeng @mruberry @lezcano @peterbell10 @msaroufim @bdhirsh @anijain2305 @chauhang @fdrocha @soumith

@ngimel changed the title from "unsqueeze meta produces wrong strides" to "unsqueeze and expand metas produce wrong strides" Dec 6, 2022
@jbschlosser added the high priority, triaged, module: correctness (silent), and oncall: pt2 labels Dec 6, 2022
pytorchmergebot pushed a commit that referenced this issue Dec 7, 2022
Fixes pytorch/torchdynamo#1959, #90260
However, I wasn't able to make the existing stride tests fail before the fix, even though I'm comparing all strides, not just the significant ones.
Running the refs by themselves on meta tensors produces wrong strides, as shown in #90260; however, the meta tests seem to compute meta info some other way. I've been running
```
pytest -s -v test/test_meta.py -k test_meta_outplace_expand_cuda_float64
```
and verified that it has a sample input that should fail and that it indeed compares all the strides, but the produced `meta_rs` results somehow still had correct strides.
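The distinction between comparing all strides and comparing only "significant" ones matters for this bug. A minimal sketch, assuming a stride counts as significant when the tensor is non-empty and its dimension has size greater than 1 (a size-1 dimension is always indexed at 0, so its stride never changes which element is addressed); both helpers are hypothetical and are not the checks used in PyTorch's test suite:

```python
def significant_strides_match(shape, strides_a, strides_b):
    """Hypothetical helper: compare only strides of dims with size > 1.
    Size-1 dims are always indexed at 0, so their strides never change
    which memory locations are addressed; empty tensors address nothing."""
    if 0 in shape:
        return True
    return all(sa == sb for n, sa, sb in zip(shape, strides_a, strides_b) if n > 1)


def all_strides_match(strides_a, strides_b):
    """Stricter check: every stride must agree, including those of size-1 dims."""
    return tuple(strides_a) == tuple(strides_b)


# Eager vs. meta strides for the unsqueeze case in #90260.
shape, eager, meta = (1, 1, 10), (10, 10, 1), (10, 0, 1)
print(significant_strides_match(shape, eager, meta))  # → True
print(all_strides_match(eager, meta))  # → False
```

Under this definition, a check restricted to significant strides cannot detect the mismatches reported here, because they only affect size-1 dimensions.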

Edit: @SherlockNoMad helped me figure out how to make the tests fail, and I've now set the correct ops for checking. `expand` still fails for some test inputs because it special-cases 0-dim inputs; modeling that correctly in prims would require a lot of changes, so I'm skipping it for now.

Pull Request resolved: #90341
Approved by: https://github.com/SherlockNoMad
@soumith
Member

soumith commented Dec 8, 2022

fixed via #90341

@soumith closed this as completed Dec 8, 2022
kulinseth pushed a commit to kulinseth/pytorch that referenced this issue Dec 10, 2022
@zou3519 added the module: pt2-dispatcher label Mar 18, 2024