Update and expose ZeroRedundancyOptimizer docs (#52937)
Summary: Pull Request resolved: #52937

Test Plan: Imported from OSS

Reviewed By: blefaudeux

Differential Revision: D26696938

Pulled By: mrshenli

fbshipit-source-id: dafb00e5c9f0c0c602f471fdcb6416bde74f806b
mrshenli authored and facebook-github-bot committed Mar 2, 2021
1 parent a176c73 commit a586c02
Showing 4 changed files with 123 additions and 67 deletions.
8 changes: 8 additions & 0 deletions docs/source/distributed.optim.rst
@@ -0,0 +1,8 @@
.. role:: hidden
    :class: hidden-section

Distributed Optimizers
======================

.. autoclass:: torch.distributed.optim.ZeroRedundancyOptimizer
:members:
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -58,6 +58,7 @@ Features described in this documentation are classified by release status:
torch.cuda.amp <amp>
torch.backends <backends>
torch.distributed <distributed>
torch.distributed.optim <distributed.optim>
torch.distributions <distributions>
torch.fft <fft>
futures
2 changes: 1 addition & 1 deletion test/distributed/optim/test_zero_redundancy_optimizer.py
@@ -366,7 +366,7 @@ def closure():
_ = optimizer.step(closure=closure)

# Update the optimizer state on the reference rank
optimizer.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)
optimizer.consolidate_state_dict(to=RECIPIENT_RANK)

# Fetch the state on the reference rank
# - check that it has the correct size
179 changes: 113 additions & 66 deletions torch/distributed/optim/zero_redundancy_optimizer.py
@@ -21,7 +21,7 @@

# Credits: classy_vision/generic/distributed_util.py
def _recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
"""
r"""
Recursively searches lists, tuples, dicts and copies tensors to device if
possible. Non-tensor values are passed as-is in the result.
@@ -55,7 +55,7 @@ def _broadcast_object(
group: object = dist.group.WORLD,
dist_device: torch.device = torch.device("cpu"),
) -> Any:
"""
r"""
Either broadcast from master to the fleet (default),
or use the src setting as the original rank.
"""
@@ -85,35 +85,53 @@ def _get_global_rank(group: Any, rank: int) -> int:


class ZeroRedundancyOptimizer(Optimizer):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
optimizer and shards its state as described by ZeRO_.
::
opt = ZeroRedundancyOptimizer(params, optim=torch.optim.Adam, lr=0.01)
We use a greedy algorithm to pack a number of parameters at each rank.
Each parameter belongs to a single rank and is not divided among ranks.
The partition is arbitrary and does not correspond to the information flow for instance.
After each rank completed their parameter update, they broadcast
the new version of the parameters to all other ranks to synchronize
the parameters for next round forward/backward computation.
r"""
This class wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
and shards its states across ranks in the group as described by
ZeRO_. The optimizer instance in each rank is only responsible for
updating ``1 / world_size`` parameters and hence only needs to keep
``1 / world_size`` optimizer states. After parameters are updated locally,
each rank will broadcast its parameters to all other peers to keep all
model replicas in the same state. ``ZeroRedundancyOptimizer`` can be used
in conjuction with :class:`torch.nn.parallel.DistributedDataparallel` to
reduce per-rank peak memory consumption.
``ZeroRedundancyOptimizer`` use a greedy algorithm to pack a number of
parameters at each rank. Each parameter belongs to a single rank and is not
divided among ranks. The partition is arbitrary and might not match the
the parameter registration or usage order.
Arguments:
params (list of tensors):
parameters to be optimized
Keyword Args:
optim (torch.nn.Optimizer):
optimizer to shard
params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
group (group):
torch.distributed group (default: group.WORLD)
parameters_as_bucket_views (bool):
whether to pack the parameters into bigger buckets, which speeds up communications.
If this is enabled, `params.data` should not be modified outside of ZeroRedundancyOptimizer.
**default: all trailing arguments will be forwarded to the requested optimizer
Keyword Args:
optim_class (:class:`torch.optim.Optimizer`): the class of the local
optimizer.
group (``ProcessGroup``, optional): ``torch.distributed``
``ProcessGroup`` (default: ``group.WORLD`` initialized by
:meth:`torch.distributed.init_process_group`).
parameters_as_bucket_view (bool): when enabled, parameters will
be packed into larger buckets to speed up communication and
``param.data`` fields will point to bucket views at different
offsets. When disabled, each individual parameter will be
communicated separately, but ``param.data`` will stay intact.
**default: all trailing arguments will be forwarded to the given optimizer.
Example::
>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
>>> ddp.parameters(),
>>> optim=torch.optim.Adam,
>>> lr=0.01
>>> )
>>> ddp(inputs).sum().backward()
>>> opt.step()
.. warning:: ZeroRedundancyOptimizer is experimental and subject to change.
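
The docstring's example above elides process-group setup. The following is a
minimal, self-contained sketch of the same pattern, runnable on a single CPU
host with the gloo backend; the ``optim=`` keyword and the two-rank layout
follow the example above and are assumptions, not a guaranteed API::

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    from torch.distributed.optim import ZeroRedundancyOptimizer
    from torch.nn.parallel import DistributedDataParallel as DDP

    def worker(rank: int, world_size: int) -> None:
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 1))
        ddp = DDP(model)  # CPU/gloo, so no device_ids

        # Each rank's wrapped optimizer keeps only ~1/world_size of the state.
        opt = ZeroRedundancyOptimizer(ddp.parameters(), optim=torch.optim.Adam, lr=0.01)

        inputs = torch.randn(16, 32)
        ddp(inputs).sum().backward()
        opt.step()  # local shard update, then parameter broadcast to peers

        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(worker, args=(2,), nprocs=2)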
@@ -175,18 +193,22 @@ def _clear_cache(self) -> None:
self._param_to_index_cache.clear()

def add_param_group(self, param_group: dict) -> None:
"""Add a param group to the :class:`Optimizer` s `param_groups`.
r"""
Add a param group to the :class:`Optimizer` s ``param_groups``.
This can be useful when fine tuning a pre-trained network as frozen layers can be made
trainable and added to the :class:`Optimizer` as training progresses.
This can be useful when fine tuning a pre-trained network, as frozen
layers can be made trainable and added to the :class:`Optimizer` as
training progresses.
Arguments:
param_group (dict): Specifies what Tensors should be optimized along with group
specific optimization options
.. warning: This handles updating the shards on all partitions, but needs to be called on all ranks.
Calling this on a subset of the ranks will cause the training to hang, because communication primitives
are called depending on the managed parameters, and expect all the ranks to participate.
param_group (dict): Specifies what Tensors should be optimized
along with group specific optimization options.
.. warning:: This method handles updating the shards on all partitions,
but needs to be called on all ranks. Calling this on a subset of the
ranks will cause the training to hang, because communication
primitives are called depending on the managed parameters, and
expect all the ranks to participate on the same set of parameters.
"""

super().add_param_group(param_group)
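
As a hedged illustration of the fine-tuning scenario described in the
docstring (building on the ``worker`` sketch earlier; the layer index is
invented for this example), unfreezing a layer and registering it must happen
on every rank::

    frozen = ddp.module[2]            # hypothetical layer frozen at construction
    for p in frozen.parameters():
        p.requires_grad = True
    # Collective with respect to the managed parameters: call on all ranks.
    opt.add_param_group({"params": list(frozen.parameters()), "lr": 1e-3})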
@@ -202,10 +224,15 @@ def add_param_group(self, param_group: dict) -> None:
if self.parameters_as_bucket_view:
self._setup_flat_buffers()

def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
"""Update the consolidated state_dict list, one per rank.
def consolidate_state_dict(self, to: int = 0) -> None:
r"""
Update the consolidated state_dict list, one per rank.
.. warning: This needs to be called on all replicas"""
Arguments:
to (int): the rank that receives the global states (default: 0).
.. warning:: This needs to be called on all replicas.
"""

# Sync lr and other attributes in case they have been updated
self._sync_param_groups(self.param_groups, self.optim.param_groups)
@@ -224,7 +251,7 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
global_rank = _get_global_rank(self.group, rank)

# This rank collects the whole state
if self.rank == recipient_rank:
if self.rank == to:
if rank == self.rank:
self._all_states.append(
_recursive_copy_to_device(
@@ -257,7 +284,7 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
dist_device=self._device,
)

elif rank != recipient_rank:
elif rank != to:
# Discard this tensor/rank; broadcast was used for compatibility reasons
_ = _broadcast_object(
empty_messenger,
Expand All @@ -267,9 +294,11 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
)
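
A sketch of the checkpoint flow this method enables, continuing the
``worker`` sketch above (note the keyword renamed from ``recipient_rank`` to
``to`` in this commit)::

    opt.consolidate_state_dict(to=0)        # collective: every rank must call it
    if dist.get_rank() == 0:
        # Only the receiving rank holds the consolidated, global state.
        torch.save(opt.state_dict(), "zero_opt.pt")
    dist.barrier()                          # keep peers from racing ahead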

def partition_parameters(self) -> List[List[Dict]]:
"""Partitions parameters across distributed data parallel ranks.
r"""
Partitions parameters across distributed data parallel ranks.
Returns: a list of ``param_groups`` (which is a list of dict) where each
Returns:
a list of ``param_groups`` (which is a list of dict) where each
element of the list contains the param_groups for a rank. Element 0
corresponds to rank 0, etc. We need all the ranks for the broadcast
inside ``step()``.
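
The greedy rule can also be shown in isolation. The following standalone
sketch mirrors the idea (each parameter goes, whole, to the currently
least-loaded rank); it is an illustration, not the exact implementation::

    import torch

    def greedy_partition(params, world_size):
        sizes = [0] * world_size
        shards = [[] for _ in range(world_size)]
        for p in sorted(params, key=lambda t: t.numel(), reverse=True):
            rank = sizes.index(min(sizes))     # least-loaded rank so far
            shards[rank].append(p)
            sizes[rank] += p.numel()
        return shards

    params = [torch.empty(n) for n in (1000, 10, 500, 500, 90)]
    shards = greedy_partition(params, world_size=2)
    print([sum(p.numel() for p in s) for s in shards])  # [1090, 1010]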
@@ -294,7 +323,8 @@ def partition_parameters(self) -> List[List[Dict]]:

@property
def _per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:
"""Sorted list of all the params, first per device then per rank.
r"""
Sorted list of all the params, first per device then per rank.
Within each list, params are sorted by number of elements to allow for easy bucketing.
"""
@@ -318,7 +348,7 @@ def _per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:

@property
def _param_to_rank(self) -> Dict[torch.Tensor, int]:
"""Look up table to match a given param with a data parallel rank"""
r"""Look up table to match a given param with a data parallel rank"""
if len(self._param_rank_cache) == 0:
for rank, param_groups in enumerate(self.partition_parameters()):
for param_group in param_groups:
@@ -328,7 +358,10 @@ def _param_to_rank(self) -> Dict[torch.Tensor, int]:

@property
def _param_to_index(self) -> Dict[int, int]:
"""Hash table in between parameter indices in the global optimizer scheme, and the actual params"""
r"""
Hash table in between parameter indices in the global optimizer scheme,
and the actual params.
"""
if len(self._param_to_index_cache) == 0:
self._param_to_index_cache = {
id(p): i for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))
@@ -338,22 +371,27 @@ def _param_to_index(self) -> Dict[int, int]:

@property
def _index_to_param(self) -> Dict[int, torch.Tensor]:
"""Hash table in between parameter indices in the global optimizer scheme, and the actual params"""
r"""
Hash table in between parameter indices in the global optimizer scheme,
and the actual params.
"""
if len(self._index_to_param_cache) == 0:
self._index_to_param_cache = {i: p for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}

return self._index_to_param_cache
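
What these two lookup tables amount to, in a toy setting (note the global
index order is the concatenation of every group's ``params`` list)::

    from itertools import chain

    import torch

    param_groups = [{"params": [torch.empty(2), torch.empty(3)]},
                    {"params": [torch.empty(4)]}]
    param_to_index = {id(p): i
                      for i, p in enumerate(chain(*(g["params"] for g in param_groups)))}
    index_to_param = dict(enumerate(chain(*(g["params"] for g in param_groups))))
    assert index_to_param[2].numel() == 4
    assert param_to_index[id(index_to_param[0])] == 0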

def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
"""Performs a single optimization step (parameter update).
r"""
Performs a single optimization step (parameter update).
Arguments:
closure (callable): A closure that reevaluates the model and
returns the loss. Optional for most optimizers.
Returns:
optional loss, depends on the underlying optimizer
.. note: Any extra parameter is passed to the base optimizer as-is"""
.. note:: Any extra parameter is passed to the base optimizer as-is.
"""

# Check whether the model trainability graph changed
trainable_mask = list(map(_is_trainable, self._all_params))
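
For completeness, a closure-style call as the docstring describes, a small
sketch again assuming the ``worker`` context from the first example::

    def closure():
        opt.zero_grad()
        loss = ddp(inputs).sum()
        loss.backward()
        return loss

    loss = opt.step(closure=closure)   # extra kwargs go to the wrapped optimizer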
@@ -397,7 +435,8 @@ def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) ->
return loss

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Restore the global parameter groups as well as the shard.
r"""
Restore the global parameter groups as well as the shard.
Arguments:
state_dict (dict): optimizer state. Should be an object returned
@@ -420,7 +459,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
ZeroRedundancyOptimizer._sync_param_groups(self.param_groups, self.optim.param_groups)
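
The restore path matching the checkpoint sketch above might look as follows
(every rank loads the consolidated dict and keeps only its own shard; the
file name comes from the earlier sketch)::

    state = torch.load("zero_opt.pt", map_location="cpu")
    opt.load_state_dict(state)   # call on every rank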

def local_state_dict(self) -> Dict:
"""Gets this rank's ``state_dict``.
r"""
Gets this rank's ``state_dict``.
Returns:
The state of the optimizer as a :class:`dict`.
@@ -433,16 +473,19 @@ def local_state_dict(self) -> Dict:
return self.optim.state_dict()

def state_dict(self) -> Dict[str, Any]:
"""
r"""
Returns:
the last known global optimizer state, which consist of a list of the shards.
the last known global optimizer state, which consists of a list of
the shards.
.. warning::
If the state has not been consolidated, this returns a shard's worth, not the global state.
If the state has not been consolidated, this returns a shard's worth,
not the global state.
.. warning::
Returning the global state is limited to the replica which was responsible for the consolidation.
The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
Returning the global state is limited to the replica which was
responsible for the consolidation. The state may also not be up to
date, depending on when :meth:`consolidate_state_dict` was last called.
"""

if len(self._all_states) == 0:
@@ -476,28 +519,31 @@ def state_dict(self) -> Dict[str, Any]:

@staticmethod
def rank_local_state_dict(rank: int, state_dict: dict) -> dict:
"""Returns the local_state_dict for a given rank.
r"""
Returns the local_state_dict for a given rank.
Arguments:
rank (int): rank to get local_state_dict for
state_dict (dict): global state_dict
rank (int): rank to get ``local_state_dict`` for
state_dict (dict): global ``state_dict``
"""
param_groups = state_dict["param_groups"][state_dict["partition"][rank][0] : state_dict["partition"][rank][1]]
return {"state": state_dict["state"][rank], "param_groups": param_groups}

@staticmethod
def _sync_param_groups(source: List[Dict[Any, Any]], destination: List[Dict[Any, Any]]) -> None:
"""Sync learning rate and other optimizer attributes (needed to support schedulers)."""
r"""Sync learning rate and other optimizer attributes (needed to support schedulers)."""

for source_group, destination_group in zip(source, destination):
# Sync everything but the parameters
for k in filter(lambda x: x != "params", source_group.keys()):
destination_group[k] = source_group[k]
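
In isolation, the sync above copies every hyperparameter except ``params``
(for example an ``lr`` changed by a scheduler); a toy rendition::

    src = [{"params": ["full"], "lr": 0.1, "momentum": 0.9}]
    dst = [{"params": ["shard"], "lr": 0.01, "momentum": 0.0}]
    for s, d in zip(src, dst):
        for k in filter(lambda x: x != "params", s.keys()):
            d[k] = s[k]
    assert dst == [{"params": ["shard"], "lr": 0.1, "momentum": 0.9}]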

def _setup_flat_buffers(self) -> None:
"""Make all params which are on the same device and tied to the same rank views of a single buffer.
This is used at construction time, and anytime parameter trainability is changed (frozen or unfrozen) and
`_update_trainable` is called.
r"""
Make all params which are on the same device and tied to the same rank
views of a single buffer. This is used at construction time, and anytime
parameter trainability is changed (frozen or unfrozen) and
``_update_trainable`` is called.
"""

for device, per_rank_params in self._per_device_params.items():
@@ -535,8 +581,9 @@ def _setup_flat_buffers(self) -> None:
self.buckets[device].append(torch.zeros(1, device=device))
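
The bucket-view trick itself, shown as a standalone sketch: copy each
parameter into a slice of one flat buffer, then repoint ``param.data`` at
that slice::

    import torch

    params = [torch.randn(3), torch.randn(5)]
    bucket = torch.empty(sum(p.numel() for p in params))
    offset = 0
    for p in params:
        n = p.numel()
        bucket[offset:offset + n].copy_(p.data.view(-1))
        p.data = bucket[offset:offset + n].view_as(p.data)   # now a view
        offset += n
    assert params[0].data_ptr() == bucket.data_ptr()   # first slice starts at 0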

def _update_trainable(self) -> None:
"""Updates the partitioning and communication patterns if the trainability (`requires_grad`)
of some parameters changed.
r"""
Updates the partitioning and communication patterns if the trainability
(``requires_grad``) of some parameters changed.
"""

# Create the optim which will work on the param shard
