From 4ef551df0f1e47ebc61fc78308376337fb8701df Mon Sep 17 00:00:00 2001
From: Rohan Varma
Date: Wed, 21 Oct 2020 14:46:51 -0700
Subject: [PATCH] Avoid scatter for single-device case in DDP

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46304

In the case that a single process operates on only one GPU, the scatter is
unnecessary: we can replace it with a recursive version of `to` that
transfers the input tensors to the correct device. The implementation of
`_recursive_to` is modeled after `scatter` in
https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/scatter_gather.py
in order to keep parity with the previous conventions (i.e., tensors inside
custom types are not moved).

ghstack-source-id: 114861410

Differential Revision: [D24296377](https://our.internmc.facebook.com/intern/diff/D24296377/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or
comments, please review them on
[Phabricator](https://our.internmc.facebook.com/intern/diff/D24296377/)!
---
 torch/nn/parallel/distributed.py              |  40 ++++++-
 torch/nn/parallel/scatter_gather.py           |   4 +-
 .../_internal/distributed/distributed_test.py | 120 +++++++++++++++++-
 3 files changed, 153 insertions(+), 11 deletions(-)

diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 1de0d033674b..0fceb2137a3b 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -16,7 +16,7 @@
 from torch.distributed.distributed_c10d import ReduceOp
 from ..modules import Module
 from .replicate import replicate
-from .scatter_gather import scatter_kwargs, gather
+from .scatter_gather import scatter_kwargs, gather, is_namedtuple
 from .parallel_apply import parallel_apply
 from torch._utils import _get_device_index, _get_all_device_indices
 
@@ -666,10 +666,11 @@ def forward(self, *inputs, **kwargs):
             self._check_global_requires_backward_grad_sync(is_joined_rank=False)
 
         if self.device_ids:
-            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
+                inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0])
                 output = self.module(*inputs[0], **kwargs[0])
             else:
+                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                 outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
                 output = self.gather(outputs, self.output_device)
         else:
@@ -694,6 +695,41 @@ def forward(self, *inputs, **kwargs):
     def scatter(self, inputs, kwargs, device_ids):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
 
+    def _recursive_to(self, inputs, target_gpu):
+        r"""
+        Recursively moves input to the target_gpu.
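+        Tensors found in (possibly nested) tuples, lists and dicts are moved
+        and their containers are rebuilt around them; objects of any other
+        type are returned unchanged, so tensors inside custom types stay put.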
+        """
+        def to_map(obj):
+            if isinstance(obj, torch.Tensor):
+                return (obj.to(target_gpu), )
+            if is_namedtuple(obj):
+                return [type(obj)(*args) for args in zip(*map(to_map, obj))]
+            if isinstance(obj, tuple) and len(obj) > 0:
+                return list(zip(*map(to_map, obj)))
+            if isinstance(obj, list) and len(obj) > 0:
+                return [list(i) for i in zip(*map(to_map, obj))]
+            if isinstance(obj, dict) and len(obj) > 0:
+                return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
+            return [obj]
+
+        # Avoid reference cycle
+        try:
+            res = to_map(inputs)
+        finally:
+            to_map = None
+        return res
+
+    def to_kwargs(self, inputs, kwargs, device_id):
+        inputs = self._recursive_to(inputs, device_id) if inputs else []
+        kwargs = self._recursive_to(kwargs, device_id) if kwargs else []
+        if len(inputs) < len(kwargs):
+            inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+        elif len(kwargs) < len(inputs):
+            kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+        inputs = tuple(inputs)
+        kwargs = tuple(kwargs)
+        return inputs, kwargs
+
     def parallel_apply(self, replicas, inputs, kwargs):
         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
 
diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py
index cfb9a70a10c9..771fbba68f02 100644
--- a/torch/nn/parallel/scatter_gather.py
+++ b/torch/nn/parallel/scatter_gather.py
@@ -1,7 +1,7 @@
 import torch
 from ._functions import Scatter, Gather
 
-def _is_namedtuple(obj):
+def is_namedtuple(obj):
     # Check if type was created from collections.namedtuple or a typing.NamedTuple.
     return (
         isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
@@ -17,7 +17,7 @@ def scatter(inputs, target_gpus, dim=0):
     def scatter_map(obj):
         if isinstance(obj, torch.Tensor):
             return Scatter.apply(target_gpus, None, dim, obj)
-        if _is_namedtuple(obj):
+        if is_namedtuple(obj):
             return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
         if isinstance(obj, tuple) and len(obj) > 0:
             return list(zip(*map(scatter_map, obj)))
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index f0e06434bba9..a3a5e43c7054 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -67,6 +67,13 @@ def __eq__(self, other):
     [1, 2, True, "string", [4, 5, "nested"]],
 ]
 
+# Dummy NamedTuple data structures to test DDP support for NamedTuple types.
+EXPECTED_FIELDS = ("a", "b")
+TestNamedTupleInput_0 = namedtuple("NamedTuple", EXPECTED_FIELDS)
+
+class TestNamedTupleInput_1(NamedTuple):
+    a: torch.Tensor
+    b: torch.Tensor
 
 skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
 
@@ -3810,16 +3817,115 @@ def forward(self, x):
     @require_backends_available({"gloo", "nccl"})
     @skip_if_lt_x_gpu(2)
     @skip_if_rocm
-    def test_ddp_namedtuple(self):
-        expected_fields = ("a", "b")
-        TestNamedTupleInput_0 = namedtuple("NamedTuple", expected_fields)
+    def test_ddp_device(self):
+        m = nn.Linear(10, 10).to(self.rank)
+        expected_len = 2
+
+        class TensorWrapper:
+            __slots__ = ['t', 'moved_to_gpu']
+
+            def __init__(self, t):
+                self.t = t
+                self.moved_to_gpu = False
+
+        # Handlers for the validation we want to perform based on the
+        # input type.
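+        # Each validator checks that tensors arrived on this rank's device
+        # (or, for the custom TensorWrapper type, that they were left on the
+        # CPU) and returns a tensor to feed through the model.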
+
+        def tuple_and_list_validator(x):
+            self.assertEqual(len(x), expected_len)
+            self.assertEqual(1, len(set(t.device for t in x)))
+            self.assertEqual(x[0].device.index, self.rank)
+            return x[0] + x[1]
+
+        def namedtuple_validator(x):
+            self.assertEqual(x._fields, EXPECTED_FIELDS)
+            self.assertEqual(x.a.device.index, x.b.device.index)
+            self.assertEqual(x.a.device.index, self.rank)
+            return x.a + x.b
+
+        def custom_type_validator(x):
+            self.assertTrue(x.moved_to_gpu or (str(x.t.device) == "cpu"))
+            x.t = x.t.to(self.rank)
+            x.moved_to_gpu = True
+            return x.t
+
+        def dict_validator(x):
+            self.assertTrue(EXPECTED_FIELDS[0] in x.keys())
+            self.assertTrue(EXPECTED_FIELDS[1] in x.keys())
+            self.assertEqual(1, len(set(t.device for t in x.values())))
+            self.assertEqual(x[EXPECTED_FIELDS[0]].device.index, self.rank)
+            return x[EXPECTED_FIELDS[0]] + x[EXPECTED_FIELDS[1]]
+
+        validators = {
+            TensorWrapper: custom_type_validator,
+            tuple: tuple_and_list_validator,
+            list: tuple_and_list_validator,
+            TestNamedTupleInput_0: namedtuple_validator,
+            TestNamedTupleInput_1: namedtuple_validator,
+            dict: dict_validator,
+        }
+
+        class ToyModel(torch.nn.Module):
+            def __init__(_self):  # noqa: B902
+                super().__init__()
+                _self.lin = nn.Linear(10, 10, bias=False)
+
+            def forward(_self, x, expected_type):  # noqa: B902
+                # Similar to scatter, the recursive to in the single-device
+                # case does not move tensors if they are in a custom type.
+                self.assertTrue(isinstance(x, expected_type))
+                fwd_tensor = validators[expected_type](x)
+                return _self.lin(fwd_tensor)
+
+        model = torch.nn.parallel.DistributedDataParallel(
+            ToyModel().to(self.rank), device_ids=[self.rank]
+        )
+
+        def train_iter(inp, input_type):
+            for _ in range(4):
+                out = model(inp, input_type)
+                out.sum().backward()
+
+        # CPU tuple input, should be moved to the proper device before the
+        # call to forward.
+        inp = tuple(torch.randn(10, 10) for _ in range(expected_len))
+        train_iter(inp, tuple)
+
+        # CPU list input, should be moved to the proper device before the
+        # call to forward.
+        inp = [torch.randn(10, 10) for _ in range(expected_len)]
+        train_iter(inp, list)
+        # Custom type containing a tensor. The type is maintained, but the
+        # device is not propagated (which is what happens with scatter too).
+        inp = TensorWrapper(torch.randn(10, 10))
+        train_iter(inp, TensorWrapper)
+        # NamedTuple input. The type should be maintained and tensor inputs
+        # should be moved to the correct device as in scatter.
         batch = 5
         dim = 10
+        a = torch.rand(batch, dim)
+        b = torch.rand(batch, dim)
+
+        inp = TestNamedTupleInput_0(a, b)
+        train_iter(inp, type(inp))
+
+        inp = TestNamedTupleInput_1(a, b)
+        train_iter(inp, type(inp))
+
+        # Dictionary input.
+        inp = {
+            EXPECTED_FIELDS[0]: a,
+            EXPECTED_FIELDS[1]: b,
+        }
+        train_iter(inp, type(inp))
 
-        class TestNamedTupleInput_1(NamedTuple):
-            a: torch.tensor
-            b: torch.tensor
+    @require_backend({"gloo", "nccl"})
+    @require_backends_available({"gloo", "nccl"})
+    @skip_if_lt_x_gpu(2)
+    @skip_if_rocm
+    def test_ddp_namedtuple(self):
+        batch = 5
+        dim = 10
 
         a = torch.rand(batch, dim, device=self.rank)
         b = torch.rand(batch, dim, device=self.rank)
@@ -3835,7 +3941,7 @@ def forward(_self, input, expected_type):  # noqa
                 isinstance(input, expected_type),
                 f"Expected type {expected_type} but got {type(input)}",
             )
-            self.assertEqual(input._fields, expected_fields)
+            self.assertEqual(input._fields, EXPECTED_FIELDS)
             self.assertEqual(a, input.a)
             self.assertEqual(b, input.b)
             return _self.lin(torch.mul(input.a, input.b))
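
A usage sketch (not part of the change itself) of the new single-device path
via `to_kwargs`. It assumes this patch is applied, one CUDA device is
available, and a single-process gloo group; the MASTER_ADDR/MASTER_PORT
values are placeholders.

    import os
    import torch
    import torch.distributed as dist
    import torch.nn as nn

    # Minimal single-process setup so DistributedDataParallel can be built.
    os.environ.setdefault("MASTER_ADDR", "localhost")  # placeholder
    os.environ.setdefault("MASTER_PORT", "29500")      # placeholder
    dist.init_process_group("gloo", rank=0, world_size=1)

    ddp = nn.parallel.DistributedDataParallel(
        nn.Linear(10, 10).cuda(0), device_ids=[0]
    )

    # Nested CPU inputs: a bare tensor, a list of tensors, a dict of tensors.
    inputs = (
        torch.randn(4, 10),
        [torch.randn(4, 10), torch.randn(4, 10)],
        {"a": torch.randn(4, 10)},
    )

    # to_kwargs wraps _recursive_to and, like scatter, returns one argument
    # tuple and one kwargs dict for the single target device.
    moved_inputs, moved_kwargs = ddp.to_kwargs(inputs, {}, 0)
    args = moved_inputs[0]
    assert all(
        t.device == torch.device("cuda:0")
        for t in (args[0], *args[1], *args[2].values())
    )
    # Empty kwargs are padded so inputs and kwargs stay aligned.
    assert moved_kwargs == ({},)

A custom type such as the TensorWrapper in the test above would pass through
to_kwargs unchanged, which is why that test moves the wrapped tensor inside
forward.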