New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use expect_true to make split with unbacked sizes work. #106788
Conversation
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyang@meta.com> [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106788
Note: Links to docs will display an error until the docs builds have been completed. ✅ 1 Unrelated FailureAs of commit 0526cb2: UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyangmeta.com> ghstack-source-id: dcdfbb704ed64b9a95b02157d41fb55bd383c1f4 Pull Request resolved: #106788
start = maybe_wrap_dim(start, cur_size); | ||
TORCH_SYM_CHECK((-cur_size).sym_le(start).sym_and(start.sym_le(cur_size)), "narrow(): start must be within [-cur_size, cur_size]"); | ||
if (start < 0) { | ||
start = start + cur_size; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The start != cur_size test cannot be sym'ified since one branch is not an error condition. What I did here was inlined maybe_wrap_dim
here, which conventionally tests -cur_size <= start < cur_size
; so you can make it work without extra branching by just changing this condition to -cur_size <= start <= cur_size
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can split out the inlining into its own PR if people would prefer.
@@ -948,7 +948,7 @@ def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callab | |||
if not isinstance(cond, (builtins.bool, torch.SymBool)): | |||
raise TypeError(f'cond must be a bool, but got {type(cond)}') | |||
|
|||
if cond: | |||
if torch.fx.experimental.symbolic_shapes.expect_true(cond): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This turns on expect_true for all our decomps/meta functions, pretty nice. (I couldn't do the same trick in C++; TORCH_CHECK is a very delicate macro and it was hard to insert expect_true without breaking some sites, and there's also the problem that operator< and friends actually return bool not SymBool).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nb. You can do a horrible, horrible thing, and locally overwrite the meaning of <
inside TORCH_CHECK
via a macro so that you automagically go from using operator<
to lt
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, this line is fantastic. Gotta love having just one entrypoint for a given thing.
start_val = sizes[dim] | ||
|
||
if end_val < start_val: | ||
end_val = start_val | ||
elif end_val >= sizes[dim]: | ||
elif end_val > sizes[dim]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These two changes don't actually change semantics, but they're pretty important: frequently, we will know that end_val is <= sizes[dim], but we don't know if it == sizes[dim] or not. By branching only if it is truly out of bounds, we can statically determine which branch we go down. This is enough for split.
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyangmeta.com> [ghstack-poisoned]
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyangmeta.com> ghstack-source-id: c93177d89e1a1d6c15f2b0765c77c9f6282ec891 Pull Request resolved: #106788
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyangmeta.com> [ghstack-poisoned]
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyangmeta.com> ghstack-source-id: ee86135ddd20bcf671682f5c5bdb728421dc701e Pull Request resolved: #106788
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyangmeta.com> [ghstack-poisoned]
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyangmeta.com> ghstack-source-id: 11d5692b882a2825312335b046cf1c6d8e798e95 Pull Request resolved: #106788
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyangmeta.com> cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov [ghstack-poisoned]
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyangmeta.com> ghstack-source-id: ff5b6867ba795ec6ba84d921fa4caf98f0325bc8 Pull Request resolved: #106788
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyangmeta.com> cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov [ghstack-poisoned]
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyangmeta.com> ghstack-source-id: 41bea703b17b260d7bee0d19359d9866f5c2bde5 Pull Request resolved: #106788
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyangmeta.com> cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov [ghstack-poisoned]
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyangmeta.com> ghstack-source-id: bf62bf051e6a82a35c44830eb8d1b9aa98fe9e99 Pull Request resolved: #106788
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyangmeta.com> cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov [ghstack-poisoned]
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyangmeta.com> ghstack-source-id: b1ba0de493b3141f54a844f683cdf94620d721ab Pull Request resolved: #106788
@pytorchbot merge -f "unstable only" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This pattern shows up in torchrec KeyedJaggedTensor. Most of the change in this PR is mechanical: whenever we failed an unbacked symint test due to just error checking, replace the conditional with something that calls expect_true (e.g., torch._check or TORCH_SYM_CHECK). Some of the changes are a bit more nuanced, I've commented on the PR accordingly. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: pytorch#106788 Approved by: https://github.com/lezcano ghstack dependencies: pytorch#106720
See #106788 for context. I think I don't actually need this for anything real, but this is pretty mild so might as well. Signed-off-by: Edward Z. Yang <ezyang@meta.com> [ghstack-poisoned]
…7785) See #106788 for context. I think I don't actually need this for anything real, but this is pretty mild so might as well. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: #107785 Approved by: https://github.com/albanD
Stack from ghstack (oldest at bottom):
This pattern shows up in torchrec KeyedJaggedTensor. Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).
Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.
Signed-off-by: Edward Z. Yang ezyang@meta.com
cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @aakhundov