
fix aliasing for primtorch view meta kernels #86285

Closed
wants to merge 13 commits

Conversation

@pytorch-bot

pytorch-bot bot commented Oct 5, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/86285

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4b7afaa:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@mruberry
Collaborator

mruberry commented Oct 5, 2022

The approach looks reasonable to me. What about extending test_python_ref_meta to detect this:

def test_python_ref_meta(self, device, dtype, op):

@IvanYashchuk
Collaborator

Is the only difference here that with alias=True the resulting tensor would return True for tensor._is_view()?

import torch
a = torch.zeros(3, 2, device="meta")
b = torch._prims.TensorMeta(a, shape=a.shape, strides=a.stride())
print(b._is_view()) # False
c = a.as_strided(a.shape, a.stride())
print(c._is_view()) # True
assert a.storage().data_ptr() == b.storage().data_ptr() == c.storage().data_ptr() # It's all 0 for meta

@bdhirsh
Contributor Author

bdhirsh commented Oct 6, 2022

@IvanYashchuk since the output isn't a view of the input, its metadata can also be incorrect. For example, if the input is a strided tensor with a non-zero storage offset, that offset won't be propagated to the output:

>>> a = torch.ones(2, 2, device='meta')[1]
>>> a.storage_offset()
2
>>> out_aten = a.as_strided(a.shape, a.stride())
>>> out_aten.storage_offset() # prints 2, the same as the input
2
>>> out_prim = torch._prims.TensorMeta(a, shape=a.shape, strides=a.stride())
>>> out_prim.storage_offset() # should print 2, prints 0!
0
>>>

@bdhirsh
Contributor Author

bdhirsh commented Oct 6, 2022

@mruberry Thanks for the pointer to test_python_ref_meta. It looks like that test currently checks that output shapes match, but not strides or storage offsets. The two ways I could imagine updating it are:

(1) Update the test to check for the _is_view() relationship (this seems a bit odd since it's a private API that depends on autograd, but maybe it would be fine?)
(2) Update the test to always check that storage offsets match. This is the "broken" thing that I originally noticed. The test would then always check that sizes and storage_offset match, but not strides. Does that sound reasonable to you?

FWIW, I'm also adding a cross-ref test that checks the correctness of all of the Python decomps and meta functions in torch/_decomp/decompositions.py and torch/_meta_registrations.py. Some of those decomps call into prims, which is where the problem first showed up. Those two files probably don't exercise the prims comprehensively, though, so a separate prim test sounds good to me.
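To make option (2) concrete, here is a minimal sketch of the kind of storage-offset check described above (the helper below is hypothetical and not the actual OpInfo test code): run the same viewing operation on a real tensor and on an equivalent meta tensor, then assert the offsets agree.

```python
import torch

# Hypothetical helper, not the real test_python_ref_meta code:
# storage_offset is pure metadata, so it should match even on the meta device.
def check_storage_offset(real_out, meta_out):
    assert real_out.storage_offset() == meta_out.storage_offset()

real_in = torch.ones(2, 2)[1]                 # real view, storage offset 2
meta_in = torch.ones(2, 2, device="meta")[1]  # meta view, storage offset 2

check_storage_offset(
    real_in.as_strided(real_in.shape, real_in.stride()),
    meta_in.as_strided(meta_in.shape, meta_in.stride()),
)
```

This is exactly the scenario from the example above: aten's as_strided preserves the input's storage offset, so a meta implementation that resets the offset to 0 would fail the assertion.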

@mruberry
Collaborator

mruberry commented Oct 6, 2022

> @mruberry Thanks for the pointer to test_python_ref_meta. It looks like that test currently checks that output shapes match, but not strides or storage offsets. The two ways I could imagine updating it are:

> (1) Update the test to check for the _is_view() relationship (this seems a bit odd since it's a private API that depends on autograd, but maybe it would be fine?)

Maybe TestViewOps in test_view_ops.py can provide some inspiration. See

def is_view_of(self, base, other):
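A minimal sketch of what such a view-relationship check might look like, loosely modeled on the TestViewOps helper (simplified; the real helper does additional bookkeeping):

```python
import torch

# Simplified sketch of a view-relationship check: a view reports
# _is_view() == True and records its base tensor as ._base.
def is_view_of(base, other):
    return (
        other._is_view()
        and other._base is base
        and other is not base
    )

a = torch.zeros(3, 2)
assert is_view_of(a, a.view(6))       # reshaping view of a
assert not is_view_of(a, a.clone())   # clone owns fresh storage
```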

> (2) Update the test to always check that storage offsets match. This is the "broken" thing that I originally noticed. The test would then always check that sizes and storage_offset match, but not strides. Does that sound reasonable to you?

Stride testing has been disabled for the moment. I spent a good amount of time trying to emulate PyTorch's striding logic, but PyTorch is pretty inconsistent in how it handles strides itself. See #78050. @ezyang concluded that we'd like to emulate strides, but I'm not sure how we would do this.

> FWIW, I'm also adding a cross-ref test that checks the correctness of all of the Python decomps and meta functions in torch/_decomp/decompositions.py and torch/_meta_registrations.py. Some of those decomps call into prims, which is where the problem first showed up. Those two files probably don't exercise the prims comprehensively, though, so a separate prim test sounds good to me.

There are several existing tests in test_ops.py for Python reference consistency, see

def test_python_ref(self, device, dtype, op):

These might be interesting to look at when developing a consistency test for decompositions. Ideally I think we'd like to see all decompositions ported to become Python references.

@ezyang
Contributor

ezyang commented Oct 6, 2022

@bdhirsh, why don't you just convert these into direct as_strided calls on the input tensor, rather than going through the TensorMeta constructor?

@bdhirsh
Contributor Author

bdhirsh commented Oct 6, 2022

> @bdhirsh, why don't you just convert these into direct as_strided calls on the input tensor, rather than going through the TensorMeta constructor?

Sounds good - you're right, that seems cleaner.

@bdhirsh
Contributor Author

bdhirsh commented Oct 10, 2022

@mruberry Updated the PR - I had some issues trying to test the _is_view() property. Since that's not really the property I cared about in the first place (the storage offset was incorrect), I updated the OpInfo tests to check the storage offset instead, which I confirmed exercises the problem from #86284. Let me know if you're happy with it.

Collaborator

@albanD left a comment


The storage_offset argument is missing. Sounds good otherwise.

@@ -1173,7 +1173,7 @@ def _greater_than_reduce(acc, x):
         else:
             new_strides.append(0)

-    return TensorMeta(a, shape=shape, strides=new_strides)
+    return a.as_strided(shape, new_strides)

All the functions here need a storage_offset=a.storage_offset() argument!
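The suggested fix can be sketched as follows: rebuild the view with as_strided and pass the input's storage offset through explicitly, so a non-zero offset on the input is preserved on the output.

```python
import torch

# Hedged sketch of the review suggestion: propagate the input's storage
# offset explicitly when rebuilding the view with as_strided.
a = torch.ones(2, 2, device="meta")[1]  # a view with storage offset 2
out = a.as_strided(a.shape, a.stride(), a.storage_offset())
assert out.storage_offset() == a.storage_offset() == 2
```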

Collaborator

@albanD left a comment


SGTM

@pytorch-bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Oct 11, 2022
Collaborator

@mruberry left a comment


LGTM!
