Make numel equal test size oblivious in reshape_symint #124611
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124611
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit a72c0bd with merge base 7cd7a7a. This comment was automatically generated by Dr. CI and updates every 15 minutes.
aten/src/ATen/InferSize.h
(Outdated)

@@ -37,7 +37,7 @@ inline void infer_size_impl(
    }
  }

-  if (numel == newsize || (infer_dim && newsize > 0 && numel % newsize == 0)) {
+  if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, newsize)) || (infer_dim && newsize > 0 && numel % newsize == 0)) {
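For intuition, here is a hypothetical Python-level sketch of the kind of unbacked-SymInt reshape this guard participates in. The names and shapes are illustrative, and this is not the exact repro from #124581:

import torch
import torch._dynamo

# Capturing .item() as an unbacked SymInt requires this config flag.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def f(x, n):
    u = n.item()                      # u is an unbacked SymInt at trace time
    torch._check(u * 2 == x.numel())  # runtime assert tying u to x's numel
    # reshape must decide `numel == newsize` without knowing u's value;
    # a size-oblivious comparison avoids guarding on the 0/1 special cases.
    return x.reshape(u, 2)

print(f(torch.randn(8), torch.tensor(4)).shape)  # torch.Size([4, 2])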
Hmm, one question about size-oblivious guards:

So far I've thought of their main use case as being when we have some PyTorch code that diverges into two cases (one for the 0/1 case and one for the general case), and we've decided that the general case is safe to use all the time when our size variable is unbacked. The example in my head is that squeeze() will repeatedly remove consecutive dims of size 1, which we skip doing when our dims are unbacked (a small sketch follows this comment).

Here though, every single size check in this function is effectively an error check (if the sizes don't match, then our reshape should error). So size-oblivious guards don't seem wrong here (the "correct" path is probably not to error). But do you think that TORCH_SYM_CHECK is more the right tool for this case (which basically says "always take the non-erroring path, but add runtime asserts for it")? Or maybe I'm holding this wrong.
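To make the squeeze() example above concrete, here is a minimal sketch of the assumed semantics; guard_size_oblivious is the Python-side counterpart of TORCH_GUARD_SIZE_OBLIVIOUS:

from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

def squeezed_shape(shape):
    # Size-oblivious reading of `s == 1`: for a backed (known) size this is an
    # ordinary comparison; for an unbacked size it is answered as if s >= 2,
    # so the dim is kept instead of guarding on whether it might be 1.
    return [s for s in shape if not guard_size_oblivious(s == 1)]

print(squeezed_shape([3, 1, 4, 1]))  # [3, 4]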
You are right that a deferred runtime assert might be even better. But it's nontrivial to convert this into a torch._check because of the disjunction. It may be possible; I just went for the simplest thing in this PR.
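For illustration, a hypothetical sketch of what the torch._check conversion would run into; this is not what the PR does. The subtlety is that the two arms of the disjunction cannot simply be split into separate asserts:

import torch

def check_inferred_size(numel, newsize, infer_dim):
    # The C++ guard is a disjunction:
    #   numel == newsize  OR  (infer_dim && newsize > 0 && numel % newsize == 0)
    # A naive split into branch-local asserts is tempting:
    if infer_dim is not None:
        torch._check(newsize > 0,
                     lambda: f"cannot infer a dim when the known sizes multiply to {newsize}")
        torch._check(numel % newsize == 0,
                     lambda: f"size {numel} is not divisible by {newsize}")
    else:
        torch._check(numel == newsize,
                     lambda: f"shape is invalid for input of size {numel}")
    # ...but it is not equivalent: with infer_dim set and numel == newsize == 0,
    # the original disjunction passes via its first arm while this split errors.
    # Asserting the disjunction itself needs a symbolic `or` that does not
    # eagerly guard on either operand -- hence "nontrivial" above.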
@pytorchbot merge -i
Merge failed
Reason: This PR needs a `release notes:` label. If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`. If not, please add the `topic: not user facing` label.
To add a label, you can comment to pytorchbot, for example: @pytorchbot label "topic: not user facing"
For more information, see the PyTorch AutoLabel Bot wiki page.
Details for Dev Infra team: Raised by workflow job
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes pytorch#124581
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: pytorch#124611
Approved by: https://github.com/bdhirsh
ghstack dependencies: pytorch#124139
Stack from ghstack (oldest at bottom):
Fixes #124581
Signed-off-by: Edward Z. Yang <ezyang@meta.com>