Skip to content

Conversation

eellison
Copy link
Contributor

@eellison eellison commented Feb 2, 2024

Stack from ghstack (oldest at bottom):

Previously Symint Deduping was done during proxy tracing which made it more difficult to reason about. This refactors the deduping to a separate pass.

We only dedupe symints which are resolvable from input symint nodes so as to avoid inducing a dependency on the backward in the forward.

potential fix for : #118224

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

Copy link

pytorch-bot bot commented Feb 2, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/118938

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 4db8b4e with merge base d9d8c2b (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
@eellison eellison requested a review from ezyang February 2, 2024 01:10
Previously Symint Deduping was done during proxy tracing which made it more difficult to reason about. This refactors the deduping to a separate pass. 

We only dedupe symints which are resolvable from input symint nodes so as to avoid inducing a dependency on the backward in the forward.

potential fix for : #118224


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
@eellison eellison removed the release notes: fx release notes category label Feb 2, 2024
Previously Symint Deduping was done during proxy tracing which made it more difficult to reason about. This refactors the deduping to a separate pass. 

We only dedupe symints which are resolvable from input symint nodes so as to avoid inducing a dependency on the backward in the forward.

potential fix for : #118224


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
) -> GraphModule:
graph = tracer.trace(root, concrete_args)
from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints
dedupe_symints(graph)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh you want to unconditionally call it here? I guess we can...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, that was the previous state

"""

sym_dict = _SymHashingDict()
resolvable_from_input_symints = set()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't things that are resolvable from inputs guaranteed to be symbols, rather than expressions? (I suppose in a rare occasion we could have created an input for a variable s0 which later got replaced into s1 + 3, but in this case it doesn't seem like a big deal to fail to dedupe to this input...)

graph.erase_node(node)
elif all(n in resolvable_from_input_symints for n in node.all_input_nodes):
sym_dict[val] = node
resolvable_from_input_symints.add(node)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Except this seems to imply that you're propagating this recursively? So what is even the point--the algorithm right now seems to only exclude nodes from being resolved if the SymInt has never been seen before in the graph.

sym_dict[val] = node
elif existing_node := sym_dict.get(val):
node.replace_all_uses_with(existing_node)
graph.erase_node(node)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I am concerned that you may need to replay the computation, and just looking for the SymInt may not be enough. Consider the following graph:

def f(x):
  y = torch.cat((x, x))
  z = y.size(0)
  ... something involving z ...

Given that x: f32[s0], the val for z is s0 * 2. This expression has never appeared, except from y. So you will fail to actually generate a size expression that doesn't keep a tensor live, since you're going to hang it off of torch.cat. Actually, I guess you are calling this a SymInt deduping pass, so I suppose it isn't obvious the pass is supposed to deal with this, but I could swear part of the reason you were doing was to eliminate false dependencies on tensors that weren't actually real?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, i don't think this is comprehensive in deduping all symints. It would need further changes. Will leave that for a future PR as this is mostly supposed to be a refactoring.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling back from Jun 2024, I need to do this version of the pass :P

Copy link
Contributor

@ezyang ezyang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm approving because this is a strict improvement over the existing situation, but I do have questions on the algo.

Previously Symint Deduping was done during proxy tracing which made it more difficult to reason about. This refactors the deduping to a separate pass. 

We only dedupe symints which are resolvable from input symint nodes so as to avoid inducing a dependency on the backward in the forward.

potential fix for : #118224


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
@eellison eellison added the topic: not user facing topic category label Feb 6, 2024
@eellison
Copy link
Contributor Author

eellison commented Feb 6, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 6, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorch-bot bot pushed a commit that referenced this pull request Feb 8, 2024
Previously Symint Deduping was done during proxy tracing which made it more difficult to reason about. This refactors the deduping to a separate pass.

We only dedupe symints which are resolvable from input symint nodes so as to avoid inducing a dependency on the backward in the forward.

potential fix for : #118224

Pull Request resolved: #118938
Approved by: https://github.com/ezyang
clee2000 pushed a commit that referenced this pull request Feb 14, 2024
Previously Symint Deduping was done during proxy tracing which made it more difficult to reason about. This refactors the deduping to a separate pass.

We only dedupe symints which are resolvable from input symint nodes so as to avoid inducing a dependency on the backward in the forward.

potential fix for : #118224

Pull Request resolved: #118938
Approved by: https://github.com/ezyang
@github-actions github-actions bot deleted the gh/eellison/589/head branch March 8, 2024 01:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants