diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 894175e8c12f..1adc46da7d4e 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -16,7 +16,7 @@ from torch.distributed.distributed_c10d import ReduceOp
 
 from ..modules import Module
 from .replicate import replicate
-from .scatter_gather import scatter_kwargs, gather
+from .scatter_gather import scatter_kwargs, gather, is_namedtuple
 from .parallel_apply import parallel_apply
 from torch._utils import _get_device_index, _get_all_device_indices
 
@@ -660,10 +660,11 @@ def forward(self, *inputs, **kwargs):
             self._check_global_requires_backward_grad_sync(is_joined_rank=False)
 
         if self.device_ids:
-            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
+                inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0])
                 output = self.module(*inputs[0], **kwargs[0])
             else:
+                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                 outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
                 output = self.gather(outputs, self.output_device)
         else:
@@ -688,6 +689,41 @@ def forward(self, *inputs, **kwargs):
     def scatter(self, inputs, kwargs, device_ids):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
 
+    def _recursive_to(self, inputs, target_gpu):
+        r"""
+        Recursively moves input to the target_gpu.
+        """
+        def to_map(obj):
+            if isinstance(obj, torch.Tensor):
+                return (obj.to(target_gpu), )
+            if is_namedtuple(obj):
+                return list(type(obj)(*args) for args in zip(*map(to_map, obj)))
+            if isinstance(obj, tuple) and len(obj) > 0:
+                return list(zip(*map(to_map, obj)))
+            if isinstance(obj, list) and len(obj) > 0:
+                return list(map(list, zip(*map(to_map, obj))))
+            if isinstance(obj, dict) and len(obj) > 0:
+                return list(map(type(obj), zip(*map(to_map, obj.items()))))
+            return [obj]
+
+        # Avoid reference cycle
+        try:
+            res = to_map(inputs)
+        finally:
+            to_map = None
+        return res
+
+    def to_kwargs(self, inputs, kwargs, device_id):
+        inputs = self._recursive_to(inputs, device_id) if inputs else []
+        kwargs = self._recursive_to(kwargs, device_id) if kwargs else []
+        if len(inputs) < len(kwargs):
+            inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+        elif len(kwargs) < len(inputs):
+            kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+        inputs = tuple(inputs)
+        kwargs = tuple(kwargs)
+        return inputs, kwargs
+
     def parallel_apply(self, replicas, inputs, kwargs):
         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
 
diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py
index a90d85f037c3..fe5b61463f2d 100644
--- a/torch/nn/parallel/scatter_gather.py
+++ b/torch/nn/parallel/scatter_gather.py
@@ -1,7 +1,7 @@
 import torch
 from ._functions import Scatter, Gather
 
-def _is_namedtuple(obj):
+def is_namedtuple(obj):
     # Check if type was created from collections.namedtuple or a typing.NamedTuple.
     return (
         isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
@@ -17,7 +17,7 @@ def scatter(inputs, target_gpus, dim=0):
     def scatter_map(obj):
         if isinstance(obj, torch.Tensor):
             return Scatter.apply(target_gpus, None, dim, obj)
-        if _is_namedtuple(obj):
+        if is_namedtuple(obj):
             return list(type(obj)(*args) for args in zip(*map(scatter_map, obj)))
         if isinstance(obj, tuple) and len(obj) > 0:
             return list(zip(*map(scatter_map, obj)))
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index f1c0e2a3a4a8..7de443cd6cd6 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -67,6 +67,13 @@ def __eq__(self, other):
     [1, 2, True, "string", [4, 5, "nested"]],
 ]
 
+# Dummy NamedTuple data structures to test DDP support for NamedTuple types.
+EXPECTED_FIELDS = ("a", "b")
+TestNamedTupleInput_0 = namedtuple("NamedTuple", EXPECTED_FIELDS)
+
+class TestNamedTupleInput_1(NamedTuple):
+    a: torch.tensor
+    b: torch.tensor
 
 
 skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
@@ -3785,16 +3792,90 @@ def forward(self, x):
     @require_backends_available({"gloo", "nccl"})
     @skip_if_lt_x_gpu(2)
     @skip_if_rocm
-    def test_ddp_namedtuple(self):
-        expected_fields = ("a", "b")
-        TestNamedTupleInput_0 = namedtuple("NamedTuple", expected_fields)
+    def test_ddp_device(self):
+        m = nn.Linear(10, 10).to(self.rank)
+        expected_len = 2
+
+        class TensorWrapper:
+            __slots__ = ['t']
+
+            def __init__(self, t):
+                self.t = t
+
+        # Handlers for specific types of validation we want to do based on
+        # the input type.
+
+        def tuple_and_list_validator(x):
+            self.assertTrue(len(x), expected_len)
+            self.assertEqual(1, len(set(t.device for t in x)))
+            self.assertEqual(x[0].device.index, self.rank)
+            return x[0] + x[1]
+
+        def namedtuple_validator(x):
+            self.assertEqual(x._fields, EXPECTED_FIELDS)
+            self.assertEqual(x.a.device.index, x.b.device.index)
+            self.assertEqual(x.a.device.index, self.rank)
+            return x.a + x.b
+
+        def custom_type_validator(x):
+            self.assertEqual(str(x.t.device), "cpu")
+            x.t = x.t.to(self.rank)
+            return x.t
+
+        validators = {
+            TensorWrapper: custom_type_validator,
+            tuple: tuple_and_list_validator,
+            list: tuple_and_list_validator,
+            TestNamedTupleInput_0: namedtuple_validator,
+            TestNamedTupleInput_1: namedtuple_validator,
+        }
+
+        class ToyModel(torch.nn.Module):
+            def __init__(_self):  # noqa: B902
+                super().__init__()
+                _self.lin = nn.Linear(10, 10, bias=False)
+
+            def forward(_self, x, expected_type):  # noqa: B902
+                # Similar to scatter, the recursive to in the single-device
+                # case does not move tensors if they are in a custom type.
+                self.assertTrue(isinstance(x, expected_type))
+                fwd_tensor = validators[expected_type](x)
+                return _self.lin(fwd_tensor)
+
+        model = torch.nn.parallel.DistributedDataParallel(
+            ToyModel().to(self.rank), device_ids=[self.rank]
+        )
+        # CPU tuple input, should be moved to the proper device before call
+        # to forward.
+        inp = tuple(torch.randn(10, 10) for _ in range(expected_len))
+        model(inp, tuple)
+        # List CPU input, should be moved to proper device before call to
+        # forward.
+        inp = [torch.randn(10, 10) for _ in range(expected_len)]
+        model(inp, list)
+        # Custom type containing tensor. The type is maintained, but the
+        # device is not propagated (which is what happens with scatter too)
+        inp = TensorWrapper(torch.randn(10, 10))
+        model(inp, TensorWrapper)
 
         batch = 5
         dim = 10
+        a = torch.rand(batch, dim)
+        b = torch.rand(batch, dim)
+
+        inp = TestNamedTupleInput_0(a, b)
+        model(inp, type(inp))
 
-        class TestNamedTupleInput_1(NamedTuple):
-            a: torch.tensor
-            b: torch.tensor
+        inp = TestNamedTupleInput_1(a, b)
+        model(inp, type(inp))
+
+    @require_backend({"gloo", "nccl"})
+    @require_backends_available({"gloo", "nccl"})
+    @skip_if_lt_x_gpu(2)
+    @skip_if_rocm
+    def test_ddp_namedtuple(self):
+        batch = 5
+        dim = 10
 
         a = torch.rand(batch, dim, device=self.rank)
         b = torch.rand(batch, dim, device=self.rank)
@@ -3810,7 +3891,7 @@ def forward(_self, input, expected_type):  # noqa
                     isinstance(input, expected_type),
                     f"Expected type {expected_type} but got {type(input)}",
                 )
-                self.assertEqual(input._fields, expected_fields)
+                self.assertEqual(input._fields, EXPECTED_FIELDS)
                 self.assertEqual(a, input.a)
                 self.assertEqual(b, input.b)
                 return _self.lin(torch.mul(input.a, input.b))
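
Reviewer sketch (not part of the diff): a simplified, self-contained illustration of the type-preserving device move that the new _recursive_to / to_kwargs path gives single-device modules. It uses a hypothetical helper name, move_to, leaves out the scatter-compatible packaging the real helper does (wrapping results so to_kwargs can index them like scatter output), and falls back to CPU when no GPU is available.

import collections

import torch

Pair = collections.namedtuple("Pair", ["a", "b"])

def move_to(obj, device):
    # Hypothetical stand-in for the recursion in _recursive_to: tensors are
    # moved; namedtuples, tuples, lists and dicts are rebuilt with the same
    # type; anything else (e.g. custom wrapper classes) is returned unchanged.
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):  # namedtuple
        return type(obj)(*(move_to(o, device) for o in obj))
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to(o, device) for o in obj)
    if isinstance(obj, dict):
        return type(obj)((k, move_to(v, device)) for k, v in obj.items())
    return obj

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
inp = Pair(torch.randn(5, 10), torch.randn(5, 10))
moved = move_to(inp, device)
assert isinstance(moved, Pair)                      # namedtuple type is preserved
assert moved.a.device == moved.b.device == device   # tensors land on the target device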