Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dynamo][moco] Disallow_in_graph distributed APIs #100071

Closed
wants to merge 6 commits into from

Conversation

@pytorch-bot
Copy link

pytorch-bot bot commented Apr 26, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/100071

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 3 New Failures

As of commit 16687a2:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

anijain2305 added a commit that referenced this pull request Apr 26, 2023
ghstack-source-id: 66e2cebe7268d9f475a3b1a4ce5389c63dc8980d
Pull Request resolved: #100071
@anijain2305 anijain2305 added the topic: not user facing topic category label Apr 26, 2023
@anijain2305
Copy link
Contributor Author

@wconstab Wondering if you know of a better way here, instead of just graph breaking on them?

@wconstab
Copy link
Contributor

Wondering if you know of a better way here, instead of just graph breaking on them?

No, we have to graph-break on them.

The missing piece is how to ensure we graph-break on the full set of them. I thought I stamped a PR (maybe from @yanboliang) a while back that skipped a large chunk of torch.distributed- looks like that did not land or else I can't find remnants of it.

We should disable for all of the collectives besides the new 'traceable/functional' ones (in _functional_collectives.py).

@yanboliang
Copy link
Contributor

The missing piece is how to ensure we graph-break on the full set of them. I thought I stamped a PR (maybe from @yanboliang) a while back that skipped a large chunk of torch.distributed- looks like that did not land or else I can't find remnants of it.

Yes, that PR has been reverted multiple times, I finally use regex matching to skip torchrec.distribued(make it support torch.package), but only be used internally behind is_fbcode. I think we should skip these files right now, not only because they are collectives ops, but also they have memory leak issue due to their implementation.

anijain2305 added a commit that referenced this pull request Apr 26, 2023
ghstack-source-id: 66e2cebe7268d9f475a3b1a4ce5389c63dc8980d
Pull Request resolved: #100071
cc soumith voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Apr 26, 2023
ghstack-source-id: 25b947f582f7477f52b6d7d47e8b1d7785ba4d80
Pull Request resolved: #100071
# There is no clear definition of torch.distributed ops. This helper set allows
# TorchDynamo to selectively disallow all the distributed ops from the Fx
# graphs.
distributed_c10d_ops = set()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wconstab Is this too hacky? Problem is there is no easy way to tell what's a torch "op" and whats not.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe cc @H-Huang @kwen2501 -- do we expect a function to use the exception_handler decorator if and only if it is a c10d op?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hah yea its kinda hacky.

But otoh i think it's reasonable to have a set of dynamo_unsupported_distributed_c10d_ops and explicitly add them all. We should just get folks like @H-Huang to agree on the set of ops to initially tag.

anijain2305 added a commit that referenced this pull request Apr 27, 2023
ghstack-source-id: 25b947f582f7477f52b6d7d47e8b1d7785ba4d80
Pull Request resolved: #100071
cc soumith voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Apr 27, 2023
ghstack-source-id: d0c52747c5f5b618733411021aa8245a5cd7bbde
Pull Request resolved: #100071
cc soumith voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Apr 27, 2023
ghstack-source-id: 0fe6edef674a88c22742b9fa57f453c3614c88ef
Pull Request resolved: #100071
cc soumith voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Apr 27, 2023
ghstack-source-id: 0fe6edef674a88c22742b9fa57f453c3614c88ef
Pull Request resolved: #100071

# This ops are not friently to TorchDynamo. So, we decide to disallow these ops
# in FX graph, allowing them to run them on eager, with torch.compile.
dynamo_unsupported_distributed_c10d_ops = [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am missing the context here. Is the idea just to exclude all the operations defined in distributed_c10d.py?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or specifically, the operators that do not work well with TorchDynamo tracing and we want to graph break (i.e. fallback to eager) on them

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@H-Huang ping

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, then would it be possible to just get all operations of the module through something like:
https://stackoverflow.com/questions/139180/how-to-list-all-functions-in-a-module

Then we don't have to create this list in distributed_c10d and worry about keeping it up to date

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@H-Huang Yes, so we tried dir(module) and __dict__ etc but then it gives a large number of functions and many of them not really "operators".

Another way I tried earlier in a commit - 8e6dee3
was to record the ops in exception_handler. The assumption was that anyone adding a new op will decorate it withe exception handler. But, that looked little hacky as well.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Thanks for clarifying. This looks like the route to go then. Should init_process_group be included in this list? This is not a conventional "operator"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing that out. Removed it from the list.

cc soumith voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
Copy link
Member

@H-Huang H-Huang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@anijain2305
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 2, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@anijain2305
Copy link
Contributor Author

@pytorchbot merge -f "unrelated CI error"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants