[Traceable FSDP2] Add partial-graph (graph-break) unit tests #131747
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131747
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 6a35e66 with merge base 89bdd9c.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
*self._create_transformer_factory_fns(), "aot_eager", fullgraph=True
)
def test_transformer_backend_aot_eager(self):
    for fullgraph in [True, False]:
nit: `@parametrize("graph_break", [True, False])` might make debugging the test easier if e.g. the graph-break path fails at some point (example: https://github.com/pytorch/pytorch/blob/main/test/autograd/test_functional.py#L684). See the sketch below.
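For illustration, a minimal self-contained sketch of that parametrized pattern (the class name and test body here are placeholders, not the PR's actual FSDP2 test logic):

```python
# Sketch only: the test body is a stand-in for the real compile-and-check logic.
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class ExampleTransformerBackendTest(TestCase):
    @parametrize("graph_break", [True, False])
    def test_transformer_backend_aot_eager(self, graph_break):
        # Each value of `graph_break` becomes its own test entry
        # (e.g. ..._graph_break_True), so a failure in the graph-break
        # path is reported on its own instead of aborting a shared loop.
        fullgraph = not graph_break
        self.assertIn(fullgraph, (True, False))


instantiate_parametrized_tests(ExampleTransformerBackendTest)

if __name__ == "__main__":
    run_tests()
```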
The issue with distributed tests is that using @parametrize directly would initialize an NCCL process group for every parametrized test, which is slow. With what @yf225 has, the same NCCL PG is reused for all subtests.
I think @kwen2501 added MultiProcContinousTest to maybe address this, but I am not sure of its exact limitations.
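For context, a rough sketch of the reuse pattern being described, i.e. initializing one process group and iterating over the flag inside a single test. The backend, init method, and placeholder body are assumptions for illustration, not the PR's code (the real suite uses NCCL and its own multi-process test harness):

```python
# Sketch only: gloo + a fixed TCP port keep the example runnable on one host.
import torch.distributed as dist
from torch.testing._internal.common_utils import TestCase, run_tests


class ExamplePGReuseTest(TestCase):
    def test_transformer_backend_aot_eager(self):
        dist.init_process_group(
            "gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
        )
        try:
            # One process-group init is shared by every subtest below,
            # instead of paying it once per parametrized test case.
            for fullgraph in [True, False]:
                with self.subTest(fullgraph=fullgraph):
                    pass  # placeholder for the compile-and-check logic
        finally:
            dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()
```

Using self.subTest at least labels which flag value failed, though it is not as granular as separate parametrized test entries.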
ahh
torch._dynamo.graph_break()
return orig_fn(*args, **kwargs)

def _mock_sdpa(self, fullgraph):
nit: maybe the name should imply that you're just conditionally adding a graph break to sdpa, e.g. _maybe_add_graph_break_to_sdpa(...); a rough sketch is below.
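For reference, a sketch of what a helper under that name could look like, conditionally patching SDPA to force a graph break. The context-manager shape and placement are assumptions; the PR's actual mock may be structured differently:

```python
import contextlib
from unittest import mock

import torch
import torch.nn.functional as F


@contextlib.contextmanager
def _maybe_add_graph_break_to_sdpa(fullgraph):
    if fullgraph:
        # fullgraph=True: leave SDPA untouched so the model can compile
        # into a single graph.
        yield
        return

    orig_fn = F.scaled_dot_product_attention

    def sdpa_with_graph_break(*args, **kwargs):
        # Force Dynamo to split the graph right before the real SDPA call.
        torch._dynamo.graph_break()
        return orig_fn(*args, **kwargs)

    with mock.patch.object(
        F, "scaled_dot_product_attention", sdpa_with_graph_break
    ):
        yield
```

A test would then wrap its compiled forward/backward in `with _maybe_add_graph_break_to_sdpa(fullgraph): ...` (or the equivalent method on the test class).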
).run(code)
else:
    self.assertTrue(
        len(triton_codes) >= 3,
Do you think it would be useful to be a bit stricter about what we assert? If I understand properly, it is something like:
3 graphs total (2 fw graphs due to the graph break, 1 unified bw graph due to compiled autograd)
fw 1: the only comm is an all_gather_out, plus one set_() (performs the weight all-gather, but does not free the weight)
fw 2: no comms, only contains a set_() (just finishes the fw compute and frees the weight)
bw: contains all_gather_out, reduce_scatter_out, and 2 set_() ops (fully gathers the weight, does the bw compute, and frees it)
(Or at least, why is the assert >= 3 and not just == 3?) A sketch of one possible stricter check is below.
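As one possible shape for a stricter check, assuming it sits where `triton_codes` and `self` are already in scope in the existing test; the substrings searched for and the exact per-graph op counts are assumptions about what the generated code contains:

```python
from torch.testing import FileCheck

# Hypothetical stricter version of the current `len(triton_codes) >= 3` check.
self.assertEqual(len(triton_codes), 3)
fw1_code, fw2_code, bw_code = triton_codes

# fw 1: exactly one all-gather and one set_(), no reduce-scatter.
FileCheck().check_count("all_gather_into_tensor", 1, exactly=True).run(fw1_code)
FileCheck().check_count("aten.set_", 1, exactly=True).run(fw1_code)
FileCheck().check_not("reduce_scatter_tensor").run(fw1_code)

# fw 2: no comms, just the set_() that frees the gathered weight.
FileCheck().check_not("all_gather_into_tensor").run(fw2_code)
FileCheck().check_count("aten.set_", 1, exactly=True).run(fw2_code)

# bw: one all-gather, one reduce-scatter, and two set_() calls.
FileCheck().check_count("all_gather_into_tensor", 1, exactly=True).run(bw_code)
FileCheck().check_count("reduce_scatter_tensor", 1, exactly=True).run(bw_code)
FileCheck().check_count("aten.set_", 2, exactly=True).run(bw_code)
```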
Yes, I think we want to be more strict here. There is some recompile happening when there is a graph break; I'll look into it with @anijain2305 and then add stricter checks.
very cool to see graph breaks don't error :)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: New commits were pushed while merging. Please rerun the merge command.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: New commits were pushed while merging. Please rerun the merge command.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 68 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #131747
Approved by: https://github.com/bdhirsh
(cherry picked from commit 236d055)
ghstack-source-id: ee7a75f
Pull Request resolved: pytorch/pytorch#131747
Stack from ghstack (oldest at bottom):
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @ezyang @chauhang @penguinwu