-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix aliasing for primtorch view meta kernels #86285
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/86285
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 4b7afaa: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
The approach looks reasonable to me. What about extending Line 174 in 17addb3
|
Is the only difference here that with import torch
a = torch.zeros(3, 2, device="meta")
b = torch._prims.TensorMeta(a, shape=a.shape, strides=a.stride())
print(b._is_view()) # False
c = a.as_strided(a.shape, a.stride())
print(c._is_view()) # True
assert a.storage().data_ptr() == b.storage().data_ptr() == c.storage().data_ptr() # It's all 0 for meta |
@IvanYashchuk since the output isn't a view of the input, then its metadata can also be incorrect. For example if the input is some strided tensor with a non-zero storage offset, then that won't get propagated to the output:
|
@mruberry Thanks for the pointer to (1) Update the test to check for the FWIW, I'm also adding a cross ref test that tests correctness of all of the python decomps and meta functions in |
Maybe Line 96 in cebf08a
Stride-testing has been disabled for the moment. I spent a good amount of time trying to emulate PyTorch's striding logic, but PyTorch is pretty inconsistent in how it handles strides itself. See #78050. @ezyang concludes that we'd like to emulate strides, but I'm not sure how we would do this.
There are several existing tests in Line 339 in cebf08a
These might be interesting to look at when developing a consistency test for decompositions. Ideally I think we'd like to see all decompositions ported to become Python references. |
@bdhirsh, why don't you just convert these into direct |
Sounds good - you're right, seems cleaner |
Fixes #86284 [ghstack-poisoned]
Fixes #86284 [ghstack-poisoned]
@mruberry Updated the PR - I had some issues trying to test the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The storage_offset argument is missing. Sounds good otherwise.
torch/_prims/__init__.py
Outdated
@@ -1173,7 +1173,7 @@ def _greater_than_reduce(acc, x): | |||
else: | |||
new_strides.append(0) | |||
|
|||
return TensorMeta(a, shape=shape, strides=new_strides) | |||
return a.as_strided(shape, new_strides) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All the functions here need a storage_offset=a.storage_offset()
argument!
Fixes #86284 [ghstack-poisoned]
Fixes #86284 [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SGTM
Fixes #86284 [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Fixes #86284
Stack from ghstack (oldest at bottom):