
Conversation

yhcharles
Contributor

yhcharles commented Oct 24, 2022

Stack from ghstack (oldest at bottom):

This PR adds the first version of the replicate() composable API. For this prototype version, I try to reuse as much code from the existing DistributedDataParallel as possible, and iterate on it in later changes. The basic idea of this prototype is:

  • create a ReplicateState object; it internally uses a ParameterList module to hold all parameters of the modules marked by the replicate() API
  • create an internal _ddp object, which reuses the existing DistributedDataParallel implementation and wraps the ParameterList object
  • install pre-forward and post-forward hooks on the root module, which call methods of _ddp to run initialization and forward (a usage sketch follows below)
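For illustration, here is a minimal usage sketch of the user-facing calls exercised by the unit test in this PR. The import path, toy model, and tensor shapes are assumptions for the example, not verbatim code from this change, and it presumes a process group has already been initialized.

import torch
import torch.nn as nn
from torch.distributed._composable import replicate, mark_root_module  # assumed import path

# assumes torch.distributed.init_process_group(...) has already been called (e.g. gloo backend)
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))  # toy model

# mark the model for replication and install the pre/post-forward hooks on the root
model = mark_root_module(replicate(model))

out = model(torch.randn(2, 8))   # first forward lazily builds the internal _ddp object
out.sum().backward()             # gradients are synchronized via the reused DDP machinery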

pytorch-bot bot commented Oct 24, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87649

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 2 Pending

As of commit d10b289:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

yhcharles added a commit that referenced this pull request Oct 24, 2022
ghstack-source-id: a120dbd
Pull Request resolved: #87649
yhcharles added a commit that referenced this pull request Oct 24, 2022
ghstack-source-id: 235abc0
Pull Request resolved: #87649
from typing import List, Tuple


class DistributedState:
Contributor

nit: @mrshenli could we move 'DistributedState' to contract.py?

Contributor Author

Sure, I can move it after @mrshenli lands contract.py.

Contributor

will this be cleaned up?



def replicate(
    *modules: nn.Module, dist_state: ReplicateState = _default_state
Contributor

nit: do we need to expose 'dist_state' to users?

Contributor Author

It's open for discussion. Maybe we don't need it, and can let all modules share the same dist_state.

@yhcharles yhcharles mentioned this pull request Oct 25, 2022
@yhcharles yhcharles marked this pull request as ready for review October 25, 2022 18:59
yhcharles added a commit that referenced this pull request Oct 25, 2022
ghstack-source-id: a634de0
Pull Request resolved: #87649
@yhcharles yhcharles requested a review from wanchaol October 26, 2022 03:14


def replicate(
    *modules: nn.Module, dist_state: ReplicateState = _default_state
Collaborator

What is this * for? I thought this could just be replicate(module)? In other words, do we expect users to pass in a list/tuple of modules? My impression is that a module is a tree-like structure and there is always a root.

Contributor Author

This allows the user to pass in a single module or multiple modules, for example:
replicate(m) or replicate(m1, m2)
In this case, m1 and m2 can be different modules in the tree.
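As a hedged illustration (the two-branch model below is hypothetical, not from this PR), both call forms register the modules with the same default state:

m = nn.ModuleDict({"encoder": nn.Linear(4, 4), "decoder": nn.Linear(4, 4)})
replicate(m["encoder"], m["decoder"])   # mark two submodules in one call
# roughly equivalent, assuming both calls share the default dist_state:
replicate(m["encoder"])
replicate(m["decoder"])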

class ReplicateState(DistributedState):
    def __init__(self) -> None:
        self.modules: List[nn.Module] = []
        self.parameters: List[nn.Parameter] = []
Collaborator

What's the purpose of this parameters field? Is it all the parameters of the module? How do we correlate these parameters to module.parameters() of the modules inside self.modules?

Contributor Author

The replicate() API can be called multiple times for different modules in a model. This parameters field keeps a reference to all parameters of these modules so that they can be managed together, for example for bucketizing in DDP's all-reduce.

    def forward_pre_hook(
        self, module: nn.Module, input: Tuple[torch.Tensor]
    ) -> None:
        if not self.has_initialized:
Collaborator

Will this forward_pre_hook be installed on all modules inside self.modules? Or is it just searching for the module inside self.modules and applying the pre-hook to that module?

Contributor Author

Do you mean hooking self.modules only, or recursively hooking all sub-modules of self.modules?

My current idea is the former, self.modules only; otherwise it may hurt performance. But this may change as we work out more implementation details.
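A minimal sketch of that "marked modules only" option (the helper name _install_pre_hooks is hypothetical; forward_pre_hook matches the method quoted above):

def _install_pre_hooks(state: ReplicateState) -> None:
    # hook only the modules explicitly passed to replicate(), not their children,
    # to avoid paying hook overhead on every submodule
    for module in state.modules:
        module.register_forward_pre_hook(state.forward_pre_hook)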

        self.parameters: List[nn.Parameter] = []
        self.has_initialized: bool = False

    def add_modules(self, *modules: nn.Module) -> None:
Collaborator

What's the workflow of add_modules and forward_pre_hook? Should they happen in a specific order? Could you add a test to demonstrate how the workflow should work?

Contributor Author

Sure, will do
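A rough sketch of the order being asked about, assuming the method names from the diff hunks in this review (the module and batch names are hypothetical):

state = ReplicateState()
state.add_modules(sub_a, sub_b)          # replicate(sub_a, sub_b) records the marked modules
root.register_forward_pre_hook(state.forward_pre_hook)  # installed by the root-marking step

out = root(batch)        # first forward: forward_pre_hook runs lazy init (collect params,
                         # build the internal _ddp) and then _ddp.pre_forward()
out.sum().backward()     # gradients are all-reduced by the reused DDP reducer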

yhcharles added a commit that referenced this pull request Oct 27, 2022
ghstack-source-id: 4a7cea5
Pull Request resolved: #87649
@yhcharles yhcharles changed the title Draft of replicate and DistributedState [WIP] Composable API: replicate and DistributedState Oct 27, 2022
@pytorch-bot pytorch-bot bot added the release notes: distributed (ddp) release notes category label Oct 27, 2022
@yhcharles yhcharles added the topic: not user facing topic category label Oct 27, 2022
Comment on lines 30 to 32
        self._param_list.extend(
            param for param in module.parameters() if param.requires_grad
        )
Contributor

Same here: if this is a top-down search, the implementation seems to collect duplicate parameters?

Comment on lines 30 to 32
        self._param_list.extend(
            param for param in module.parameters() if param.requires_grad
        )
Contributor

module.parameters() will include children's parameters. For a case like this, where b and d are marked to be replicated and e is marked to be sharded, what will self._param_list be?

        root
       /  |  \
      a   b   c
         / \
        d   e


    def _recursive_collect_params(self, module: nn.Module) -> None:
        if (
            getattr(module, "_distributed_state", None) is not None
Contributor

Also, it seems module._distributed_state is not None for modules that are marked as replicated at this point, so it always returns early and never fills in self._param_list?

Contributor Author

Updated
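One plausible shape of that update, as a hedged sketch (not necessarily the author's exact code): collect the module's own parameters, then recurse child by child, skipping any subtree whose state belongs to a different composable API.

def _recursive_collect_params(self, module: nn.Module) -> None:
    state = getattr(module, "_distributed_state", None)
    if state is not None and state is not self:
        return  # this subtree is managed by another API (e.g. fully_shard), skip it
    self._param_list.extend(
        p for p in module.parameters(recurse=False) if p.requires_grad
    )
    for child in module.children():
        self._recursive_collect_params(child)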

Comment on lines 30 to 32
        self._param_list.extend(
            param for param in module.parameters() if param.requires_grad
        )
Contributor

I guess the best way is to write some unit tests to verify the parameters are collected as expected

Contributor

@mrshenli left a comment

please add test coverage


    def forward(self, *inputs, **kwargs):
        self.pre_forward(*inputs, **kwargs)
        with torch.autograd.profiler.record_function(
Contributor

This is not the same as the existing DDP? The record_function there used to wrap what are now pre_forward and post_forward as well?

Contributor Author

That's true, it's not exactly the same.
Since the forward is now split into two other functions, I added record_function("DistributedDataParallel.pre_forward") and record_function("DistributedDataParallel.post_forward") in each of them. Does this make sense?
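A minimal sketch of what that instrumentation could look like (the function bodies are placeholders, not the actual _ddp.py code):

def pre_forward(self):
    with torch.autograd.profiler.record_function("DistributedDataParallel.pre_forward"):
        ...  # e.g. rebuild buckets / sync state before the wrapped forward

def post_forward(self, output):
    with torch.autograd.profiler.record_function("DistributedDataParallel.post_forward"):
        ...  # e.g. prepare for backward (find unused parameters)
        return output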

Contributor

@mrshenli left a comment

It would be great to clean up _ddp.py soon and remove unnecessary features. We usually don't leave duplicated code in PyTorch. I understand the intention is to move fast, but even in that case, let's only allow that for a short time window.

        )
        return output

    def forward(self, *inputs, **kwargs):
Contributor

If this _ddp.py file is just for the composable API, do we even need this forward method?

Contributor Author

Will clean it up soon

yhcharles added a commit that referenced this pull request Nov 4, 2022
ghstack-source-id: 98bfc9e
Pull Request resolved: #87649
yhcharles added a commit that referenced this pull request Nov 4, 2022
ghstack-source-id: 1f2c39a
Pull Request resolved: #87649
yhcharles added a commit that referenced this pull request Nov 5, 2022
ghstack-source-id: 7578906
Pull Request resolved: #87649
yhcharles added a commit that referenced this pull request Nov 7, 2022
ghstack-source-id: 6618b28
Pull Request resolved: #87649

    def test_replicate(self):
        dist.init_process_group(
            backend="gloo",
Contributor

let's test on both gloo and nccl

        local_batch_size = 1
        global_batch_size = self.world_size * local_batch_size
        model, input, target = self._prepare_module(global_batch_size)
        replicate_model = mark_root_module(replicate(deepcopy(model)))
Contributor

let's add more test cases:

  1. replicate one submodule instead of the root module
  2. replicate more than one submodule

Contributor

Also add a test case where some submodules of the replicated local root module are annotated by fully_shard().

from typing import List, Tuple


class DistributedState:
Contributor

will this be cleaned up?

        for module in modules:
            self.modules.append(module)
            replicate.state(module)._distributed_state = self
            replicate.state(module)._params_collected = False
Contributor

Wondering how the other states in the DDP constructor are populated?

        for module in self.modules:
            self._recursive_collect_params(module)

        self._ddp = _ddp.DistributedDataParallel(self._param_list)
Contributor

Oh, all the states in the DDP constructor are owned by self._ddp...

I assume we will not keep this for long, as it is still monkey patching. Instead, we will make all the states owned by the replicate.state() object?

        self, module: nn.Module, input: Tuple[torch.Tensor]
    ) -> None:
        self.init_helper()
        self._ddp.pre_forward()
Contributor

Same as above, we need to make pre_forward() accept a replicate.state() object once we get rid of self._ddp?

@contract
def replicate(
    module: nn.Module,  # NOTE: contract now supports single module only
    dist_state: ReplicateState = _default_state,
Contributor

dist_state is used internally; could we remove it from the user-facing API here?

Also, replicate needs to have similar arguments to the existing DDP API and work with the important features that DDP provides, like static_graph, gradient_as_bucket_view, etc. (see the hypothetical sketch below).
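For concreteness, a hypothetical sketch of the argument surface being requested (not what this PR implements; the keyword names mirror existing DistributedDataParallel constructor options):

def replicate(
    module: nn.Module,
    *,
    broadcast_buffers: bool = True,
    find_unused_parameters: bool = False,
    gradient_as_bucket_view: bool = False,
    static_graph: bool = False,
) -> nn.Module:
    ...  # forward these options to the internal DDP state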



def mark_root_module(
    module: nn.Module, dist_state: ReplicateState = _default_state
Contributor

same as above, not exposing dist_state to users

Comment on lines +99 to +100
>>> module = nn.Linear(3, 3)
>>> replicate(module)
Contributor

nit: the example does not reflect the 'mark_root_module'

Contributor

@zhaojuanmao left a comment

Stamp to get this version checked in.

Synced up offline; the following PRs will be sent out:

  1. remove mark_root_module()
  2. clean up the replicate() constructor to drop the dist_state argument and add important arguments such as static_graph, find_unused_parameters, etc.
  3. clean up self._ddp
  4. add more tests, e.g. interaction with fully_shard.py

yhcharles
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 17, 2022
pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@yhcharles yhcharles changed the title [WIP] Composable API: replicate and DistributedState Composable API: replicate and DistributedState Nov 17, 2022
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
Pull Request resolved: pytorch#87649
Approved by: https://github.com/zhaojuanmao
@facebook-github-bot facebook-github-bot deleted the gh/yhcharles/6/head branch June 8, 2023 19:24

Labels

ciflow/trunk · Merged · release notes: distributed (ddp) · topic: not user facing
