Fix issues with generalized_scatter and setitem allocated unbacked symbols. #164341
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164341
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 95f2e2d with merge base 219fb6a.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…ted unbacked symbols." Three fixes:
1. When doing t[x] += 1 with x unbacked, indexing t[x] can allocate a new unbacked symbol, but the output size/stride does not depend on it in this case. It is self-consumed during meta tracing, so we ignore it.
2. When we trace through generalized_scatter, applying the views can allocate unbacked symints, but those do not affect the final output, so we ignore them as well.
3. Before accessing strides in lowering, we materialize.
Address #162110 and #114293
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben Lucaskabela [ghstack-poisoned]
@ezyang during meta_select we allocate a new unbacked symint to represent the storage offset in the output of the select op (because we do not know whether u0 >= 0 or u0 < 0). Updated the summary.
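To make the point about the storage offset concrete, here is a minimal pure-Python sketch of how selecting one index along a dimension determines the offset of the resulting view. The function name `select_storage_offset` and its parameters are hypothetical and only illustrate the arithmetic; this is not the meta_select implementation.

```python
def select_storage_offset(base_offset: int, size: int, stride: int, index: int) -> int:
    """Storage offset of the view obtained by selecting `index` along one dimension.

    Hypothetical helper for illustration only; not PyTorch's meta_select.
    """
    if not -size <= index < size:
        raise IndexError(f"index {index} out of range for dimension of size {size}")
    # Negative indices wrap around, so the formula branches on the sign of the index.
    wrapped = index if index >= 0 else index + size
    return base_offset + wrapped * stride


# Contiguous 1-D tensor with 10 elements: size=10, stride=1, base offset 0.
print(select_storage_offset(0, 10, 1, 3))   # 3
print(select_storage_offset(0, 10, 1, -3))  # 7
```

Because the result branches on the sign of the index, an index whose sign is unknown at trace time does not yield a single closed-form offset, which is consistent with a fresh unbacked symbol being allocated for it.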
    if fake_mode and fake_mode.shape_env
    else nullcontext()
):
    tmp = view.target(tmp, *fake_args, **fake_kwargs)
I'm also confused here. You are going to have view.target inside of the graph now right? So tmp DOES have the unbacked symbol, and you should resolve unbacked symbols here (instead of suppressing them)
This is not the function that does the decomposition; it is the function used to trace _generalized_scatter.
For the function that does the decomposition we do need to do what you said, which I am doing in the PR right on top of this one in the stack: #164948.
In this function the views are self-consumed, and we return inp on line 98, which does not reference the new symbols in its example input.
vc me if you don't agree
We discussed this offline. For generalized_scatter the current approach is fine; @bobrenjc93 is looking at this problem more generally for higher-order ops.
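The "self-consumed" argument in this thread can be sketched with a toy model. Everything below (`ToyShapeEnv`, `ToyTensor`, `toy_select`, `trace_scatter`, `pending_bindings`) is a hypothetical stand-in written for illustration, not PyTorch's ShapeEnv or the actual tracing function; it only shows why symbols allocated on intermediate views need no binding when the returned value never references them.

```python
import itertools
from dataclasses import dataclass, field


@dataclass
class ToyShapeEnv:
    """Hypothetical stand-in for a shape environment that hands out symbols."""
    _counter: itertools.count = field(default_factory=itertools.count)
    fresh: list = field(default_factory=list)

    def new_unbacked(self) -> str:
        sym = f"u{next(self._counter)}"
        self.fresh.append(sym)
        return sym


@dataclass(frozen=True)
class ToyTensor:
    size: tuple
    storage_offset: object = 0  # an int, or a symbol name such as "u0"

    def free_symbols(self) -> set:
        values = (*self.size, self.storage_offset)
        return {v for v in values if isinstance(v, str)}


def toy_select(env: ToyShapeEnv, t: ToyTensor) -> ToyTensor:
    # The sign of the index is unknown, so the offset becomes a fresh symbol.
    return ToyTensor(size=t.size[1:], storage_offset=env.new_unbacked())


def trace_scatter(env: ToyShapeEnv, inp: ToyTensor, num_views: int) -> ToyTensor:
    tmp = inp
    for _ in range(num_views):
        tmp = toy_select(env, tmp)  # may allocate fresh symbols
    # The views are consumed here; the traced result is the original input.
    return inp


def pending_bindings(env: ToyShapeEnv, out: ToyTensor) -> set:
    # Only fresh symbols that the output actually references need a binding.
    return set(env.fresh) & out.free_symbols()


env = ToyShapeEnv()
inp = ToyTensor(size=(4, 8))
out = trace_scatter(env, inp, num_views=2)
print(env.fresh)                   # ['u0', 'u1']
print(pending_bindings(env, out))  # set()
```

In this toy, returning `tmp` instead of `inp` would leave a symbol in the output, which corresponds to the decomposition case where the reviewer's point about resolving the symbols applies.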
torch/_dynamo/variables/tensor.py
Outdated
with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():

    # Ignore fresh unbacked symbols that could arise from the internal indexing (selection),
    # that happen in code like t[idx] += 1 when idx is unabacked. Namely the selection
s/unabacked/unbacked
# Additionally, we want to avoid accidental unbacked unsqueeze semantics. To resolve this,
# we use as_strided instead.
# Removing this branch will cause test_unbacked_select_index_with_check to fail.
x.realize()
add a comment?
assert len(unbacked_bindings) == 1, unbacked_bindings
unbacked_offset_sym, _ = next(iter(unbacked_bindings.items()))

x.realize()
ditto
Starting merge as part of PR stack under #164948
address comments and rebase
@pytorchmergebot merge
This PR has pending changes requested. Please address the comments and update the PR before merging.
Starting merge as part of PR stack under #164948
…g generalize_scatter decomp (#164948)
Two fixes:
1. In the reinplace pass, set unbacked bindings for newly created nodes.
2. In inductor, ComputeBuffer used to miss some used symbols; that is fixed.
Pull Request resolved: #164948
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #164341
Stack from ghstack (oldest at bottom):
Three fixes:
1. When doing t[u0] += 1 with u0 unbacked, indexing t[u0] can allocate a new unbacked symbol (when we fake trace setitem), because meta_select allocates a new unbacked symbol for the storage offset when we do not know whether u0 >= 0 or u0 < 0. The output size/stride of setitem() does not depend on that new symbol; it is self-consumed in setitem, so we ignore it.
2. When we trace through generalized_scatter, applying the views can allocate unbacked symints, but those do not affect the final output, so we ignore them as well.
3. Before accessing strides in lowering, we materialize.
Address #114293 and #131911
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela