Refactor Symint Deduping to separate pass #118938
Conversation
Previously, SymInt deduping was done during proxy tracing, which made it more difficult to reason about. This refactors the deduping into a separate pass. We only dedupe SymInts that are resolvable from input SymInt nodes, so as to avoid inducing a dependency on the backward in the forward.

Potential fix for #118224.
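For orientation, here is a minimal sketch of such a pass, assembled from the diff hunks quoted in the review threads below. It is not the actual torch._inductor.fx_passes.dedupe_symint_uses implementation: a plain dict keyed on the SymInt's underlying sympy expression stands in for the PR's `_SymHashingDict`, and the placeholder handling is an assumption.

```python
import torch.fx as fx
from torch import SymInt


def dedupe_symints_sketch(graph: fx.Graph) -> None:
    sym_dict = {}                          # sympy expr -> canonical producer node
    resolvable_from_input_symints = set()

    for node in list(graph.nodes):
        val = node.meta.get("val")
        if not isinstance(val, SymInt):
            continue
        key = val.node.expr                # underlying sympy expression

        if node.op == "placeholder":
            # Input SymInts are the roots of resolvability.
            resolvable_from_input_symints.add(node)
            sym_dict[key] = node
        elif existing_node := sym_dict.get(key):
            # The same symbolic value already has a canonical,
            # input-resolvable producer: redirect uses and drop this node.
            node.replace_all_uses_with(existing_node)
            graph.erase_node(node)
        elif all(n in resolvable_from_input_symints for n in node.all_input_nodes):
            # Derived purely from input SymInts, so it can serve as a
            # canonical producer without keeping any tensor alive.
            sym_dict[key] = node
            resolvable_from_input_symints.add(node)
```

Because FX graphs iterate in topological order, an input SymInt is always registered before any later duplicate that could be deduped against it.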
The first review thread is about the call site, where the pass now runs unconditionally after tracing:

```python
) -> GraphModule:
    graph = tracer.trace(root, concrete_args)
    from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints
    dedupe_symints(graph)
```
Reviewer: Oh, you want to unconditionally call it here? I guess we can...
Author: Right, that was the previous state.
""" | ||
|
||
sym_dict = _SymHashingDict() | ||
resolvable_from_input_symints = set() |
Reviewer: Aren't things that are resolvable from inputs guaranteed to be symbols, rather than expressions? (I suppose on a rare occasion we could have created an input for a variable s0 which later got replaced into s1 + 3, but in that case it doesn't seem like a big deal to fail to dedupe to this input...)
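To unpack the symbol-versus-expression distinction in plain sympy (an analogy, not the actual ShapeEnv machinery): an input SymInt's underlying expr is a bare symbol, while a derived SymInt carries an expression over such symbols.

```python
import sympy

s0 = sympy.Symbol("s0")

input_expr = s0        # what an input SymInt node typically carries
derived_expr = s0 * 2  # what e.g. torch.cat((x, x)).size(0) would carry

print(input_expr.is_Symbol)    # True  -- directly resolvable from inputs
print(derived_expr.is_Symbol)  # False -- an expression over input symbols
```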
The following thread is about how resolvability propagates:

```python
        graph.erase_node(node)
    elif all(n in resolvable_from_input_symints for n in node.all_input_nodes):
        sym_dict[val] = node
        resolvable_from_input_symints.add(node)
```
Reviewer: Except this seems to imply that you're propagating this recursively? So what is even the point? The algorithm right now seems to only exclude nodes from being resolved if the SymInt has never been seen before in the graph.
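A toy model of the propagation being questioned here (assumed semantics, with hypothetical node names): resolvability flows transitively through chains of sym ops, so a node is excluded only when some input of it is not itself resolvable.

```python
# Nodes listed in topological order with the names of their inputs.
resolvable = {"s0", "s1"}             # placeholder SymInt nodes
graph_in_order = [
    ("add_1", ["s0", "s1"]),          # s0 + s1
    ("mul_1", ["add_1"]),             # (s0 + s1) * 2 -- transitively resolvable
    ("sym_size_1", ["cat_1"]),        # reads a tensor node -> not resolvable
]
for name, inputs in graph_in_order:
    if all(i in resolvable for i in inputs):
        resolvable.add(name)

print(sorted(resolvable))  # ['add_1', 'mul_1', 's0', 's1'] -- sym_size_1 excluded
```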
The last thread is about the dedupe step itself:

```python
        sym_dict[val] = node
    elif existing_node := sym_dict.get(val):
        node.replace_all_uses_with(existing_node)
        graph.erase_node(node)
```
Reviewer: So I am concerned that you may need to replay the computation, and just looking for the SymInt may not be enough. Consider the following graph:

```python
def f(x):
    y = torch.cat((x, x))
    z = y.size(0)
    # ... something involving z ...
```

Given that x: f32[s0], the val for z is s0 * 2. This expression has never appeared, except from y. So you will fail to actually generate a size expression that doesn't keep a tensor live, since you're going to hang it off of torch.cat. Actually, I guess you are calling this a SymInt deduping pass, so I suppose it isn't obvious the pass is supposed to deal with this, but I could swear part of the reason you were doing this was to eliminate false dependencies on tensors that weren't actually real?
Author: Right, I don't think this is comprehensive in deduping all SymInts. It would need further changes. Will leave that for a future PR, as this is mostly supposed to be a refactoring.
Follow-up: Calling back from Jun 2024, I need to do this version of the pass :P
Reviewer: I'm approving because this is a strict improvement over the existing situation, but I do have questions on the algo.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Pull Request resolved: #118938
Approved by: https://github.com/ezyang