Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT

try:
import torchvision
HAS_TORCHVISION = True
except ImportError:
HAS_TORCHVISION = False


skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

BACKEND = os.environ["BACKEND"]
TEMP_DIR = os.environ["TEMP_DIR"]
INIT_METHOD = os.getenv("INIT_METHOD", "env://")
Expand Down Expand Up @@ -1528,6 +1537,19 @@ def test_DistributedDataParallel_SyncBatchNorm(self):
gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus))
self._test_DistributedDataParallel_SyncBatchNorm(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))

@skipIfNoTorchVision
def test_SyncBatchNorm_process_group(self):
# When adopting `convert_sync_batchnorm` to convert a `nn.modules`,
# it need to recursively pass the `process_group` in the module when the `SyncBatchNorm`
# is nested in a sub-module or sub-sub-module (e.g. resnet50 in torchvision.models).

process_ids = 0
process_group = torch.distributed.new_group([process_ids])
res50_model = torchvision.models.resnet50()
res50_model_sync = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(res50_model), process_group)
process_group_sync = res50_model_sync.layer1[0].bn1.process_group
self.assertEqual(process_group_sync, process_group)

if BACKEND == "gloo" or BACKEND == "nccl":
WORLD_SIZE = os.environ["WORLD_SIZE"]

Expand Down
2 changes: 1 addition & 1 deletion torch/nn/modules/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,6 @@ def convert_sync_batchnorm(cls, module, process_group=None):
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
for name, child in module.named_children():
module_output.add_module(name, cls.convert_sync_batchnorm(child))
module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
del module
return module_output