diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index e76e307d36a6..48e58d637ea6 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -434,8 +434,14 @@ class SyncBatchNorm(_BatchNorm):
         >>> # With Learnable Parameters
         >>> m = nn.SyncBatchNorm(100)
         >>> # creating process group (optional)
-        >>> # process_ids is a list of int identifying rank ids.
-        >>> process_group = torch.distributed.new_group(process_ids)
+        >>> # ranks is a list of int identifying rank ids.
+        >>> ranks = list(range(8))
+        >>> r1, r2 = ranks[:4], ranks[4:]
+        >>> # Note: every rank calls into new_group for every
+        >>> # process group created, even if that rank is not
+        >>> # part of the group.
+        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
+        >>> process_group = process_groups[0 if torch.distributed.get_rank() <= 3 else 1]
         >>> # Without Learnable Parameters
         >>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group)
         >>> input = torch.randn(20, 100, 35, 45, 10)
@@ -564,8 +570,14 @@ def convert_sync_batchnorm(cls, module, process_group=None):
             >>>            torch.nn.BatchNorm1d(100),
             >>>          ).cuda()
             >>> # creating process group (optional)
-            >>> # process_ids is a list of int identifying rank ids.
-            >>> process_group = torch.distributed.new_group(process_ids)
+            >>> # ranks is a list of int identifying rank ids.
+            >>> ranks = list(range(8))
+            >>> r1, r2 = ranks[:4], ranks[4:]
+            >>> # Note: every rank calls into new_group for every
+            >>> # process group created, even if that rank is not
+            >>> # part of the group.
+            >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
+            >>> process_group = process_groups[0 if torch.distributed.get_rank() <= 3 else 1]
             >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
 
         """
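For context, below is a minimal end-to-end sketch of how the updated example is meant to be exercised. It is illustrative only and not part of the diff: it assumes 8 processes with one GPU each, launched via torchrun (which supplies RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT for env:// initialization); the file name and print output are hypothetical.

# sync_bn_groups.py -- illustrative sketch, not part of the diff.
# Assumed launch: torchrun --nproc_per_node=8 sync_bn_groups.py
import torch
import torch.distributed as dist
import torch.nn as nn

def main():
    # torchrun provides RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT for env:// init.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # Split the 8 ranks into two groups of 4, as in the docstring example.
    ranks = list(range(8))
    r1, r2 = ranks[:4], ranks[4:]
    # Every rank must call new_group for every group created,
    # even for groups it does not belong to.
    process_groups = [dist.new_group(pids) for pids in [r1, r2]]
    process_group = process_groups[0 if rank <= 3 else 1]

    # Batch-norm statistics are synchronized only across the four
    # ranks in this process's group, not the whole world.
    m = nn.SyncBatchNorm(100, process_group=process_group).cuda()
    # For actual training, wrap the model in DistributedDataParallel
    # with a single GPU per process; the bare forward here just shows
    # that all members of a group must call forward together.
    out = m(torch.randn(20, 100, 35, 45, device="cuda"))
    print(f"rank {rank}: output shape {tuple(out.shape)}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()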