Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use expect_true to make split with unbacked sizes work. #106788

Closed
wants to merge 8 commits into from

Conversation

ezyang
Copy link
Contributor

@ezyang ezyang commented Aug 8, 2023

Stack from ghstack (oldest at bottom):

This pattern shows up in torchrec KeyedJaggedTensor. Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang ezyang@meta.com

cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @aakhundov

This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Aug 8, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106788

Note: Links to docs will display an error until the docs builds have been completed.

✅ 1 Unrelated Failure

As of commit 0526cb2:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ezyang added a commit that referenced this pull request Aug 8, 2023
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

ghstack-source-id: dcdfbb704ed64b9a95b02157d41fb55bd383c1f4
Pull Request resolved: #106788
@github-actions github-actions bot requested a review from albanD August 8, 2023 14:38
start = maybe_wrap_dim(start, cur_size);
TORCH_SYM_CHECK((-cur_size).sym_le(start).sym_and(start.sym_le(cur_size)), "narrow(): start must be within [-cur_size, cur_size]");
if (start < 0) {
start = start + cur_size;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The start != cur_size test cannot be sym'ified since one branch is not an error condition. What I did here was inlined maybe_wrap_dim here, which conventionally tests -cur_size <= start < cur_size; so you can make it work without extra branching by just changing this condition to -cur_size <= start <= cur_size.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can split out the inlining into its own PR if people would prefer.

@@ -948,7 +948,7 @@ def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callab
if not isinstance(cond, (builtins.bool, torch.SymBool)):
raise TypeError(f'cond must be a bool, but got {type(cond)}')

if cond:
if torch.fx.experimental.symbolic_shapes.expect_true(cond):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This turns on expect_true for all our decomps/meta functions, pretty nice. (I couldn't do the same trick in C++; TORCH_CHECK is a very delicate macro and it was hard to insert expect_true without breaking some sites, and there's also the problem that operator< and friends actually return bool not SymBool).

Copy link
Collaborator

@lezcano lezcano Aug 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nb. You can do a horrible, horrible thing, and locally overwrite the meaning of < inside TORCH_CHECK via a macro so that you automagically go from using operator< to lt.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, this line is fantastic. Gotta love having just one entrypoint for a given thing.

start_val = sizes[dim]

if end_val < start_val:
end_val = start_val
elif end_val >= sizes[dim]:
elif end_val > sizes[dim]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two changes don't actually change semantics, but they're pretty important: frequently, we will know that end_val is <= sizes[dim], but we don't know if it == sizes[dim] or not. By branching only if it is truly out of bounds, we can statically determine which branch we go down. This is enough for split.

@albanD albanD removed their request for review August 8, 2023 14:48
@ezyang ezyang requested a review from ysiraichi August 8, 2023 15:00
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Aug 8, 2023
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

ghstack-source-id: c93177d89e1a1d6c15f2b0765c77c9f6282ec891
Pull Request resolved: #106788
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Aug 8, 2023
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

ghstack-source-id: ee86135ddd20bcf671682f5c5bdb728421dc701e
Pull Request resolved: #106788
@ezyang ezyang added ciflow/trunk Trigger trunk jobs on your pull request topic: not user facing topic category labels Aug 9, 2023
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Aug 9, 2023
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

ghstack-source-id: 11d5692b882a2825312335b046cf1c6d8e798e95
Pull Request resolved: #106788
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Aug 14, 2023
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

ghstack-source-id: ff5b6867ba795ec6ba84d921fa4caf98f0325bc8
Pull Request resolved: #106788
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Aug 14, 2023
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

ghstack-source-id: 41bea703b17b260d7bee0d19359d9866f5c2bde5
Pull Request resolved: #106788
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Aug 15, 2023
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

ghstack-source-id: bf62bf051e6a82a35c44830eb8d1b9aa98fe9e99
Pull Request resolved: #106788
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Aug 15, 2023
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

ghstack-source-id: b1ba0de493b3141f54a844f683cdf94620d721ab
Pull Request resolved: #106788
@ezyang
Copy link
Contributor Author

ezyang commented Aug 15, 2023

@pytorchbot merge -f "unstable only"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

summerdo pushed a commit to summerdo/pytorch that referenced this pull request Aug 17, 2023
This pattern shows up in torchrec KeyedJaggedTensor.  Most
of the change in this PR is mechanical: whenever we failed
an unbacked symint test due to just error checking, replace the
conditional with something that calls expect_true (e.g.,
torch._check or TORCH_SYM_CHECK).

Some of the changes are a bit more nuanced, I've commented on the PR
accordingly.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: pytorch#106788
Approved by: https://github.com/lezcano
ghstack dependencies: pytorch#106720
@facebook-github-bot facebook-github-bot deleted the gh/ezyang/2289/head branch August 19, 2023 14:16
ezyang added a commit that referenced this pull request Aug 23, 2023
See #106788 for context.

I think I don't actually need this for anything real, but this is pretty
mild so might as well.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

ghstack-source-id: b905ba0dbcce50af5a2411cedc89c051add745ea
Pull Request resolved: #107785
ezyang added a commit that referenced this pull request Aug 23, 2023
See #106788 for context.

I think I don't actually need this for anything real, but this is pretty
mild so might as well.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Aug 23, 2023
…7785)

See #106788 for context.

I think I don't actually need this for anything real, but this is pretty
mild so might as well.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #107785
Approved by: https://github.com/albanD
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants