Skip to content

Commit

Permalink
Fix the doc of PostLocalSGDState (#72792)
Browse files Browse the repository at this point in the history
Summary:
The first arg of `PostLocalSGDState` ctor, `process_group`, cannot be empty. Here to simplify the usage, does not even create a subgroup explicitly.

See the example in unit test: https://github.com/pytorch/pytorch/blob/4feef6c97092cfde7d57a97d8390a79551e92369/torch/testing/_internal/distributed/distributed_test.py#L4260

Pull Request resolved: #72792

Reviewed By: samdow

Differential Revision: D34213221

Pulled By: rohan-varma

fbshipit-source-id: 078343f3ee138e175bf835897f190032eb970662
(cherry picked from commit bf90af7)
  • Loading branch information
wayi1 authored and pytorchmergebot committed Feb 15, 2022
1 parent 8a43aa9 commit 8b08478
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
3 changes: 1 addition & 2 deletions torch/distributed/algorithms/model_averaging/averagers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ class PeriodicModelAverager(ModelAverager):
>>> module, device_ids=[rank], output_device=rank
>>> )
>>> # Register a post-localSGD communication hook.
>>> subgroup, subgroups = dist.new_subgroups()
>>> state = PostLocalSGDState(subgroup=subgroup, start_localSGD_iter=100)
>>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
>>> model.register_comm_hook(state, post_localSGD_hook)
>>>
>>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
Expand Down
3 changes: 1 addition & 2 deletions torch/distributed/optim/post_localSGD_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
>>> )
>>>
>>> # Register a post-localSGD communication hook.
>>> subgroup, subgroups = dist.new_subgroups()
>>> state = PostLocalSGDState(subgroup=subgroup, start_localSGD_iter=100)
>>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
>>> model.register_comm_hook(state, post_localSGD_hook)
>>>
>>> # Create a post-localSGD optimizer that wraps a local optimizer.
Expand Down

0 comments on commit 8b08478

Please sign in to comment.