[dynamo] fix dynamo + DTensor to work with 2d #108329
Conversation
This PR fixes the dynamo + DTensor integration so that the current graph-break FSDP can work with tensor parallel by moving torch.compile to after FSDP wrapping. [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108329
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 2452522 with merge base e68b3ad:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR fixes the dynamo + DTensor integration so that the current graph-break FSDP can work with tensor parallel by moving torch.compile to after FSDP wrapping. ghstack-source-id: 1922a173c9b148a8a6bfe19e820bbebd531435dd Pull Request resolved: #108329

Pair debugged with @wconstab and we found some issues in both dynamo and the TP's FSDP extension side. This PR fixes the dynamo + DTensor integration so that the current graph-break FSDP can work with tensor parallel by moving torch.compile to after FSDP wrapping. [ghstack-poisoned]

Pair debugged with @wconstab and we found some issues in both dynamo and the TP's FSDP extension side. This PR fixes the dynamo + DTensor integration so that the current graph-break FSDP can work with tensor parallel by moving torch.compile to after FSDP wrapping. ghstack-source-id: 1922a173c9b148a8a6bfe19e820bbebd531435dd Pull Request resolved: #108329

Pair debugged with @wconstab and we found some issues in both dynamo and the TP's FSDP extension side. This PR fixes the dynamo + DTensor integration so that the current graph-break FSDP can work with tensor parallel by moving torch.compile to after FSDP wrapping. cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov [ghstack-poisoned]

Pair debugged with @wconstab and we found some issues in both dynamo and the TP's FSDP extension side. This PR fixes the dynamo + DTensor integration so that the current graph-break FSDP can work with tensor parallel by moving torch.compile to after FSDP wrapping. ghstack-source-id: 4c26e2721a36a92d0988044f4c4fdc7491dc6dfd Pull Request resolved: #108329
lgtm, but definitely add the assert for kwargs; it might even be easy enough to just support kwargs too
    process_group=fsdp_pg,
    device_id=self.rank,
    use_orig_params=True,
)

# TODO: once aot autograd support is ready we can just use default backend
compiled_2d = torch.compile(fsdp_2d, backend="eager")
nit: we might also want to test the inductor backend, as it sometimes breaks more than eager due to customized logic in some parts of aot-autograd
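For context, a rough sketch of what covering both backends in the test could look like; `fsdp_2d` and `inp` are stand-ins for the 2D-parallel module and input the test builds, not code from this PR:

```python
import torch

# Hypothetical parametrization over compile backends; fsdp_2d and inp are
# placeholders for the FSDP+TP wrapped module and its input in the test.
for backend in ("eager", "inductor"):
    compiled_2d = torch.compile(fsdp_2d, backend=backend)
    out = compiled_2d(inp)
    out.sum().backward()
```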
torch/_dynamo/variables/torch.py (Outdated)
@@ -572,6 +572,7 @@ def get_state_from_generator():
elif is_from_local(self.value):
    # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
    # and rewrite args to have only proxyable args, then insert call_function
    # TODO: support cases where device_mesh + placements specified as kwargs
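For readers unfamiliar with this branch: the comment describes baking the non-proxyable constructor arguments into a small on-the-fly function so that only tensor arguments flow through the traced graph. A purely illustrative sketch of that idea, not the actual dynamo code:

```python
from torch.distributed._tensor import DTensor

def make_from_local_fn(device_mesh, placements, run_check):
    # device_mesh/placements/run_check are captured here instead of being
    # proxied, so the synthesized function only takes proxyable tensor args.
    def from_local_fn(local_tensor):
        return DTensor.from_local(
            local_tensor, device_mesh, placements, run_check=run_check
        )
    return from_local_fn
```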
maybe good to assert kwargs is empty (or that it just contains the one bool you expect, etc.)
Actually it's not that hard to support kwargs directly, so I added that.
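A hypothetical sketch of what "supporting kwargs" amounts to here: treat `device_mesh`/`placements` the same whether they arrive positionally or as keyword arguments, keeping tensors as graph inputs and capturing everything else. The helper name below is illustrative, not the PR's code:

```python
import torch

def split_proxyable(args, kwargs):
    # Tensors stay as graph inputs; non-tensor values (device_mesh, placements,
    # run_check, ...) get baked into the synthesized from_local wrapper.
    graph_inputs = [a for a in args if isinstance(a, torch.Tensor)]
    graph_inputs += [v for v in kwargs.values() if isinstance(v, torch.Tensor)]
    captured_args = [a for a in args if not isinstance(a, torch.Tensor)]
    captured_kwargs = {
        k: v for k, v in kwargs.items() if not isinstance(v, torch.Tensor)
    }
    return graph_inputs, captured_args, captured_kwargs
```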
Pair debugged with @wconstab and we found some issues in both dynamo and the TP's FSDP extension side. This PR fixes the dynamo + DTensor integration so that the current graph-break FSDP can work with tensor parallel by moving torch.compile to after FSDP wrapping. ghstack-source-id: 4f48e003224e0d48e32b3db57923859d91b50e0e Pull Request resolved: #108329
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Nice!
Stack from ghstack (oldest at bottom):
Pair debugged with @wconstab and we found some issues in both dynamo and
the TP's FSDP extension side. This PR fixes the dynamo + DTensor integration
so that the current graph-break FSDP can work with tensor parallel by moving
torch.compile to after FSDP wrapping.
cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov
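For anyone trying the 2D setup this enables, a minimal sketch of the ordering, assuming a `model`, a tensor-parallel plan `tp_plan`, a TP device mesh `tp_mesh`, and an FSDP process group `fsdp_pg` are already set up; those names are placeholders, not from this PR:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import parallelize_module

# Tensor parallelism is applied first, then FSDP wraps the TP'd module.
tp_model = parallelize_module(model, tp_mesh, tp_plan)
fsdp_2d = FSDP(tp_model, process_group=fsdp_pg, use_orig_params=True)

# torch.compile goes on *after* FSDP wrapping; the eager backend mirrors the
# test's TODO about waiting for full aot autograd support.
compiled_2d = torch.compile(fsdp_2d, backend="eager")
```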