-
Notifications
You must be signed in to change notification settings - Fork 25.5k
[Replicate][Pipeline Parallelism] integration of new replicate function with pipeline parallelism #164031
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…on with pipeline parallelism [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164031
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit b33e7fc with merge base bec6541 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…cate function with pipeline parallelism" cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci [ghstack-poisoned]
…cate function with pipeline parallelism" **Summary:** In order to test numerics for replicate + pp, stage.py needs to be able to call replicate's backward manually as pipeline parallelism doesn't have this feature. **Test Case** 1. pytest test/distributed/_composable/test_composability/test_pp_composability.py -k test_replicate_pp cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci [ghstack-poisoned]
…cate function with pipeline parallelism" **Summary:** In order to test numerics for replicate + pp, stage.py needs to be able to call replicate's backward manually as pipeline parallelism doesn't have this feature. **Test Case** 1. pytest test/distributed/_composable/test_composability/test_pp_composability.py -k test_replicate_pp cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci [ghstack-poisoned]
…cate function with pipeline parallelism" **Summary:** In order to test numerics for replicate + pp, stage.py needs to be able to call replicate's backward manually as pipeline parallelism doesn't have this feature. **Test Case** 1. pytest test/distributed/_composable/test_composability/test_pp_composability.py -k test_replicate_pp cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci [ghstack-poisoned]
result = perform_backward(backward_type)() | ||
|
||
# If submod is a Replicate module | ||
elif isinstance(self.submod, ReplicateModule): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I compared ReplicateModule and FSDPModule line by line, but seems they are almost identical. what's the core diffrerence?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The only difference is we need a way to distinguish modules that have been replicated vs fully-sharded so we know whether to get replicate state or fully shard state
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, agree here. The only difference i see is from fully_shard.state()
and replicate.state()
. To reduce code duplication, i think you could define a helper method like _handle_fsdp_module_backward
that takes in a lamba function.
I'm also not sure why ReplicateModule
and FSDPModule
are different modules if it seems like they support the same things. We should look into that first
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
got you. is it possible to do this?
I don't mind unit test to be verbose. But I really hope the core code stay dry
elif isinstance(self.submod, FSDPModule): # ReplicateModule is also a FSDPModule
...
if isinstance(self.submod, ReplicateModule):
state = replicate.state(self.submod)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh yes, if ReplicateModule is also an FSDP Module, then I like what @weifengpy proposed we should do that!
test/distributed/_composable/test_composability/test_pp_composability.py
Show resolved
Hide resolved
ref_pipeline_schedule.step(target=labels, losses=ref_losses) | ||
|
||
for loss, ref_loss in zip(losses, ref_losses): | ||
self.assertEqual(loss, ref_loss) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In a lot of our tests we compare single-device model grads with pipeline parallel model grads. This test just compares the loss of two pipelines. Might be out of scope to add this though, but maybe can add a TODO
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case, my only goals with this were to confirm that replicate + pipeline parallelism actually runs and that adding replicate to a pipeline doesn't affect it. I would prefer to make this a TODO that I can work on while running MAST tests in a few days
…cate function with pipeline parallelism" **Summary:** In order to test numerics for replicate + pp, stage.py needs to be able to call replicate's backward manually as pipeline parallelism doesn't have this feature. **Test Case** 1. pytest test/distributed/_composable/test_composability/test_pp_composability.py -k test_replicate_pp cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci [ghstack-poisoned]
…cate function with pipeline parallelism" **Summary:** In order to test numerics for replicate + pp, stage.py needs to be able to call replicate's backward manually as pipeline parallelism doesn't have this feature. **Test Case** 1. pytest test/distributed/_composable/test_composability/test_pp_composability.py -k test_replicate_pp cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looking good on the fsdp part
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Summary: In order to test numerics for replicate + pp, stage.py needs to be able to call replicate's backward manually as pipeline parallelism doesn't have this feature.
Test Case
Stack from ghstack (oldest at bottom):
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci