Composable API: replicate and DistributedState
#87649
Conversation
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/87649.
Note: links to docs will display an error until the docs builds have completed.
⏳ No failures, 2 pending as of commit d10b289. This comment was automatically generated by Dr. CI and updates every 15 minutes.
    from typing import List, Tuple


    class DistributedState:
nit: @mrshenli could we move 'DistributedState' to contract.py?
Sure, I can move it after @mrshenli lands contract.py.
will this be cleaned up?
    def replicate(
        *modules: nn.Module, dist_state: ReplicateState = _default_state
nit: do we need to expose 'dist_state' to users?
It's open for discussion. Maybe we don't need it and can let all modules share the same dist_state.
    def replicate(
        *modules: nn.Module, dist_state: ReplicateState = _default_state
What is this * for? I thought this could just be replicate(module) instead. In other words, do we expect users to pass in a list/tuple of modules? My impression is that a module is a tree-like structure and there is always a root of it.
This allows the user to pass in a single module or multiple modules, for example replicate(m) or replicate(m1, m2). In this case, m1 and m2 can be different modules in the tree.
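A minimal sketch of how such a variadic signature behaves; the function name and body here are illustrative placeholders, not this PR's implementation:

```python
import torch.nn as nn

def replicate_sketch(*modules: nn.Module) -> None:
    # *modules packs the positional arguments into a tuple, so both
    # replicate_sketch(m) and replicate_sketch(m1, m2) are valid calls.
    for module in modules:
        # placeholder: the real API would register each module with a ReplicateState
        print(f"marking {type(module).__name__} for replication")

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
replicate_sketch(model)               # a single module
replicate_sketch(model[0], model[2])  # several modules from the same tree
```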
    class ReplicateState(DistributedState):
        def __init__(self) -> None:
            self.modules: List[nn.Module] = []
            self.parameters: List[nn.Parameter] = []
What's the purpose of this parameters field? Is it all the parameters of the modules, and how do these entries correlate to module.parameters() for the modules in self.modules?
The replicate() API can be called multiple times for different modules in a model. This parameters field keeps a reference to all parameters of those modules so that they can be managed together, for example bucketed for DDP's all-reduce.
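A rough sketch of that idea, with illustrative names only (not the code in this PR):

```python
from typing import List

import torch.nn as nn

class ReplicateStateSketch:
    """Illustrative only: accumulates parameters across multiple replicate() calls."""

    def __init__(self) -> None:
        self.modules: List[nn.Module] = []
        self.parameters: List[nn.Parameter] = []

    def add_modules(self, *modules: nn.Module) -> None:
        for module in modules:
            self.modules.append(module)
            # Keep references to every trainable parameter so they can later be
            # managed together, e.g. bucketed for a DDP-style all-reduce.
            self.parameters.extend(
                p for p in module.parameters() if p.requires_grad
            )

state = ReplicateStateSketch()
state.add_modules(nn.Linear(4, 4))
state.add_modules(nn.Linear(4, 2), nn.Linear(2, 2))
print(len(state.parameters))  # 6: a weight and a bias from each of the three Linears
```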
        def forward_pre_hook(
            self, module: nn.Module, input: Tuple[torch.Tensor]
        ) -> None:
            if not self.has_initialized:
Will this forward_pre_hook be installed on all modules inside self.modules, or does it just look up the module inside self.modules and apply the pre-hook to that module?
Do you mean hooking self.modules only, or recursively hooking all sub-modules of self.modules? My current idea is the former, self.modules only; otherwise it may hurt performance. But this may change as we work out more implementation details.
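A small sketch of the former option, assuming hooks are registered only on the directly marked modules (names are illustrative, not the actual implementation):

```python
import torch
import torch.nn as nn

def install_pre_hooks_sketch(marked_modules):
    # Hooks go only on the directly marked modules, not recursively on every
    # submodule, to keep per-forward hook overhead low.
    def pre_hook(module, inputs):
        # the real hook would lazily initialize the replication state here
        return None

    for module in marked_modules:
        module.register_forward_pre_hook(pre_hook)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
install_pre_hooks_sketch([model])  # only the root Sequential gets the hook
model(torch.randn(1, 4))           # pre_hook fires once, for the root only
```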
            self.parameters: List[nn.Parameter] = []
            self.has_initialized: bool = False

        def add_modules(self, *modules: nn.Module) -> None:
What's the workflow of add_modules and forward_pre_hook? Should they happen in a specific order? Could you add a test to demonstrate how the workflow should work?
Sure, will do
            self._param_list.extend(
                param for param in module.parameters() if param.requires_grad
            )
Same here: if it is a top-down search, the implementation seems to collect duplicate parameters?
            self._param_list.extend(
                param for param in module.parameters() if param.requires_grad
            )
module.parameters() will include the children's parameters. For a case like the one below, where b and d are marked to be replicated and e is marked to be sharded, what will self._param_list be?

          root
         /  |  \
        a   b   c
           / \
          d   e
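To make the concern concrete, a small self-contained sketch of that tree (the module classes are illustrative), showing that module.parameters() already yields the children's parameters:

```python
import torch.nn as nn

class B(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.d = nn.Linear(2, 2)
        self.e = nn.Linear(2, 2)

class Root(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Linear(2, 2)
        self.b = B()
        self.c = nn.Linear(2, 2)

root = Root()
# b.parameters() includes d's and e's parameters, so extending _param_list from
# both b and d would collect d's parameters twice, and it would also pick up
# e's parameters even though e is meant to be sharded rather than replicated.
print(len(list(root.b.parameters())))    # 4: d.weight, d.bias, e.weight, e.bias
print(len(list(root.b.d.parameters())))  # 2
```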
        def _recursive_collect_params(self, module: nn.Module) -> None:
            if (
                getattr(module, "_distributed_state", None) is not None
It also seems module._distributed_state is not None at this point for modules that are marked as replicated, so this always returns early and never fills in self._param_list?
Updated
            self._param_list.extend(
                param for param in module.parameters() if param.requires_grad
            )
I guess the best way is to write some unit tests to verify the parameters are collected as expected
please add test coverage
        def forward(self, *inputs, **kwargs):
            self.pre_forward(*inputs, **kwargs)
            with torch.autograd.profiler.record_function(
This is not the same as the existing DDP? Previously, record_function wrapped pre_forward and post_forward here as well.
That's true, it's not exactly the same. Since the logic is now split into two other functions, I added record_function("DistributedDataParallel.pre_forward") and record_function("DistributedDataParallel.post_forward") in each of them. Does this make sense?
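Roughly what that split looks like; a sketch only, not the exact code in this PR:

```python
import torch
from torch.autograd.profiler import record_function

def pre_forward() -> None:
    with record_function("DistributedDataParallel.pre_forward"):
        # the real implementation would e.g. rebuild buckets and sync buffers here
        pass

def post_forward(output: torch.Tensor) -> torch.Tensor:
    with record_function("DistributedDataParallel.post_forward"):
        # the real implementation would e.g. prepare gradient hooks for backward here
        return output
```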
It would be great if we could clean up _ddp.py soon to remove the unnecessary features; we usually don't leave duplicated code in PyTorch. I understand the intention is to move fast, but even so, let's only allow that for a short time window.
            )
            return output

        def forward(self, *inputs, **kwargs):
If this _ddp.py file is just for the composable API, do we even need this forward method?
Will clean it up soon
This PR adds the first version of the `replicate()` composable API. For this prototype version, I try to reuse as much code from the existing `DistributedDataParallel` as possible and iterate on it in later changes. The basic idea of this prototype is:
- create a `ReplicateState` object; it internally uses a `ParameterList` module to hold all parameters of the modules marked by the `replicate()` API
- create an internal `_ddp` object, which reuses the existing `DistributedDataParallel` implementation and wraps the `ParameterList` object
- install pre-forward and after-forward hooks on the root module, which call methods of `_ddp` to run initialization and forward
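Based on that description, end-to-end usage of the prototype would look roughly like the sketch below; the import path is an assumption, and a process group is assumed to be initialized already:

```python
import torch
import torch.nn as nn
# assumed import path for the prototype composable APIs
from torch.distributed._composable import mark_root_module, replicate

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
# mark the module for replication, then install the pre-/after-forward hooks
# on the root; the first forward lazily builds the internal _ddp object
replicated = mark_root_module(replicate(model))
out = replicated(torch.randn(4, 8))
out.sum().backward()  # gradients are all-reduced via the reused DDP machinery
```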
    def test_replicate(self):
        dist.init_process_group(
            backend="gloo",
let's test on both gloo and nccl
        local_batch_size = 1
        global_batch_size = self.world_size * local_batch_size
        model, input, target = self._prepare_module(global_batch_size)
        replicate_model = mark_root_module(replicate(deepcopy(model)))
Let's add more test cases:
- replicate one submodule instead of the root module
- replicate more than one submodule
Also add a test case where some submodules of the replicated local root module are annotated with fully_shard().
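A hypothetical sketch of what the submodule test case could look like; the import path, helper name, and assertion are placeholders, not code from this PR:

```python
import torch
import torch.nn as nn
# same assumed import path as in the test file
from torch.distributed._composable import mark_root_module, replicate

def check_replicate_submodules() -> None:
    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
    # replicate two submodules instead of the root module
    replicate(model[0])
    replicate(model[2])
    replicated = mark_root_module(model)
    out = replicated(torch.randn(2, 4))
    out.sum().backward()
    # the real test would compare the resulting gradients against a DDP baseline
```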
    from typing import List, Tuple


    class DistributedState:
will this be cleaned up?
        for module in modules:
            self.modules.append(module)
            replicate.state(module)._distributed_state = self
            replicate.state(module)._params_collected = False
Wondering how the other states in the DDP constructor are populated?
        for module in self.modules:
            self._recursive_collect_params(module)

        self._ddp = _ddp.DistributedDataParallel(self._param_list)
Oh, all the state in the DDP constructor is owned by self._ddp... I assume we will not keep this for long, since it is still monkey patching; instead, we will make all the state owned by the replicate.state() object?
            self, module: nn.Module, input: Tuple[torch.Tensor]
        ) -> None:
            self.init_helper()
            self._ddp.pre_forward()
Same as above: we need to make pre_forward() accept a replicate.state() object once we get rid of self._ddp?
    @contract
    def replicate(
        module: nn.Module,  # NOTE: contract now supports single module only
        dist_state: ReplicateState = _default_state,
dist_state is used internally; could we remove it from the user-facing API here? Also, replicate needs to take arguments similar to the existing DDP API and work with the important features DDP provides, like static_graph, gradient_as_bucket_view, etc.
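A hypothetical sketch of what such a user-facing signature could look like, mirroring torch.nn.parallel.DistributedDataParallel's constructor arguments; this is a suggestion, not the API in this PR:

```python
import torch.nn as nn

def replicate(
    module: nn.Module,
    *,
    process_group=None,
    broadcast_buffers: bool = True,
    find_unused_parameters: bool = False,
    gradient_as_bucket_view: bool = False,
    static_graph: bool = False,
) -> nn.Module:
    ...  # hypothetical signature only; the body is out of scope here
```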
    def mark_root_module(
        module: nn.Module, dist_state: ReplicateState = _default_state
same as above, not exposing dist_state to users
    >>> module = nn.Linear(3, 3)
    >>> replicate(module)
nit: the example does not reflect mark_root_module
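For instance, the docstring example could be extended along these lines (a suggestion only; it assumes mark_root_module is applied to the same module):

```python
>>> module = nn.Linear(3, 3)
>>> replicate(module)
>>> mark_root_module(module)
>>> # forward/backward on module now synchronize gradients across ranks
```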
Stamping to get this version backed up.
Synced up offline; the following PRs will be sent out:
- remove mark_root_module()
- clean up the replicate() constructor to drop the distState argument and add important arguments such as static_graph, find_unused_parameters, etc.
- clean up self._ddp
- add more tests, e.g., interaction with fully_shard.py
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Pull Request resolved: pytorch#87649. Approved by: https://github.com/zhaojuanmao
Stack from ghstack (oldest at bottom):
- #87649 replicate and DistributedState