Improve new_group example in the context of SyncBatchNorm (#48897)
Summary:
Closes #48804
Improves the documentation example in the SyncBN docs to clearly show that every rank must call `new_group()` for every process subgroup being created, even if that rank is not going to be part of a particular subgroup.
Each rank then picks the right group, i.e. the group it is part of, and passes that into the SyncBN APIs, as shown in the sketch below.
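
A minimal sketch of the pattern the updated example documents (the world size of 8, the 4-rank split, and names like `my_group` are illustrative assumptions; `torch.distributed` is assumed to be initialized on every rank before this runs):

```python
import torch.distributed as dist
import torch.nn as nn

# Assumption: dist.init_process_group(...) has already been called on all 8 ranks.
world_ranks = list(range(8))
r1, r2 = world_ranks[:4], world_ranks[4:]

# Every rank calls new_group() for every subgroup being created,
# even for a subgroup that this rank will not belong to.
process_groups = [dist.new_group(pids) for pids in (r1, r2)]

# Each rank then keeps only the group it is actually part of ...
my_group = process_groups[0] if dist.get_rank() < 4 else process_groups[1]

# ... and passes that group into the SyncBN APIs.
sync_bn = nn.SyncBatchNorm(100, process_group=my_group)
model = nn.SyncBatchNorm.convert_sync_batchnorm(
    nn.Sequential(nn.Conv2d(3, 100, kernel_size=3), nn.BatchNorm2d(100)),
    process_group=my_group,
)
```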

Doc rendering:

<img width="786" alt="syncbn_update" src="https://user-images.githubusercontent.com/8039770/101271959-b211ab80-373c-11eb-8b6d-d56483fd9f5d.png">

Pull Request resolved: #48897

Reviewed By: zou3519

Differential Revision: D25493181

Pulled By: rohan-varma

fbshipit-source-id: a7e93fc8cc07ec7797e5dbc356f1c3877342cfa3
rohan-varma authored and facebook-github-bot committed Dec 11, 2020
1 parent f10b53d commit c0a0845
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions torch/nn/modules/batchnorm.py
@@ -434,8 +434,14 @@ class SyncBatchNorm(_BatchNorm):
>>> # With Learnable Parameters
>>> m = nn.SyncBatchNorm(100)
>>> # creating process group (optional)
->>> # process_ids is a list of int identifying rank ids.
->>> process_group = torch.distributed.new_group(process_ids)
+>>> # ranks is a list of int identifying rank ids.
+>>> ranks = list(range(8))
+>>> r1, r2 = ranks[:4], ranks[4:]
+>>> # Note: every rank calls into new_group for every
+>>> # process group created, even if that rank is not
+>>> # part of the group.
+>>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
+>>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
>>> # Without Learnable Parameters
>>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
>>> input = torch.randn(20, 100, 35, 45, 10)
@@ -564,8 +570,14 @@ def convert_sync_batchnorm(cls, module, process_group=None):
>>> torch.nn.BatchNorm1d(100),
>>> ).cuda()
>>> # creating process group (optional)
->>> # process_ids is a list of int identifying rank ids.
->>> process_group = torch.distributed.new_group(process_ids)
+>>> # ranks is a list of int identifying rank ids.
+>>> ranks = list(range(8))
+>>> r1, r2 = ranks[:4], ranks[4:]
+>>> # Note: every rank calls into new_group for every
+>>> # process group created, even if that rank is not
+>>> # part of the group.
+>>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
+>>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
>>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
"""
