diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 894175e8c12f..1adc46da7d4e 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -16,7 +16,7 @@ from torch.distributed.distributed_c10d import ReduceOp
 
 from ..modules import Module
 from .replicate import replicate
-from .scatter_gather import scatter_kwargs, gather
+from .scatter_gather import scatter_kwargs, gather, is_namedtuple
 from .parallel_apply import parallel_apply
 from torch._utils import _get_device_index, _get_all_device_indices
 
@@ -660,10 +660,11 @@ def forward(self, *inputs, **kwargs):
             self._check_global_requires_backward_grad_sync(is_joined_rank=False)
 
         if self.device_ids:
-            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
+                inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0])
                 output = self.module(*inputs[0], **kwargs[0])
             else:
+                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                 outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
                 output = self.gather(outputs, self.output_device)
         else:
@@ -688,6 +689,41 @@ def forward(self, *inputs, **kwargs):
     def scatter(self, inputs, kwargs, device_ids):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
 
+    def _recursive_to(self, inputs, target_gpu):
+        r"""
+        Recursively moves input to the target_gpu.
+        """
+        def to_map(obj):
+            if isinstance(obj, torch.Tensor):
+                return (obj.to(target_gpu), )
+            if is_namedtuple(obj):
+                return list(type(obj)(*args) for args in zip(*map(to_map, obj)))
+            if isinstance(obj, tuple) and len(obj) > 0:
+                return list(zip(*map(to_map, obj)))
+            if isinstance(obj, list) and len(obj) > 0:
+                return list(map(list, zip(*map(to_map, obj))))
+            if isinstance(obj, dict) and len(obj) > 0:
+                return list(map(type(obj), zip(*map(to_map, obj.items()))))
+            return [obj]
+
+        # Avoid reference cycle
+        try:
+            res = to_map(inputs)
+        finally:
+            to_map = None
+        return res
+
+    def to_kwargs(self, inputs, kwargs, device_id):
+        inputs = self._recursive_to(inputs, device_id) if inputs else []
+        kwargs = self._recursive_to(kwargs, device_id) if kwargs else []
+        if len(inputs) < len(kwargs):
+            inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+        elif len(kwargs) < len(inputs):
+            kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+        inputs = tuple(inputs)
+        kwargs = tuple(kwargs)
+        return inputs, kwargs
+
     def parallel_apply(self, replicas, inputs, kwargs):
         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
 
diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py
index a90d85f037c3..fe5b61463f2d 100644
--- a/torch/nn/parallel/scatter_gather.py
+++ b/torch/nn/parallel/scatter_gather.py
@@ -1,7 +1,7 @@
 import torch
 from ._functions import Scatter, Gather
 
-def _is_namedtuple(obj):
+def is_namedtuple(obj):
     # Check if type was created from collections.namedtuple or a typing.NamedTuple.
     return (
         isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
@@ -17,7 +17,7 @@ def scatter(inputs, target_gpus, dim=0):
     def scatter_map(obj):
         if isinstance(obj, torch.Tensor):
             return Scatter.apply(target_gpus, None, dim, obj)
-        if _is_namedtuple(obj):
+        if is_namedtuple(obj):
             return list(type(obj)(*args) for args in zip(*map(scatter_map, obj)))
         if isinstance(obj, tuple) and len(obj) > 0:
             return list(zip(*map(scatter_map, obj)))
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index f1c0e2a3a4a8..e2f8d29f77a8 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -3781,6 +3781,56 @@ def forward(self, x):
            else:
                ddp(inp).sum().backward()
 
+    @require_backend({"gloo", "nccl"})
+    @require_backends_available({"gloo", "nccl"})
+    @skip_if_lt_x_gpu(2)
+    @skip_if_rocm
+    def test_ddp_device(self):
+        m = nn.Linear(10, 10).to(self.rank)
+        expected_len = 2
+
+        class TensorWrapper:
+            __slots__ = ['t']
+
+            def __init__(self, t):
+                self.t = t
+
+        class ToyModel(torch.nn.Module):
+            def __init__(_self):  # noqa: B902
+                super().__init__()
+                _self.lin = nn.Linear(10, 10, bias=False)
+
+            def forward(_self, x, expected_type):  # noqa: B902
+                # Similar to scatter, the recursive to in the single-device
+                # case does not move tensors if they are in a custom type.
+                self.assertTrue(isinstance(x, expected_type))
+                if expected_type == TensorWrapper:
+                    self.assertEqual(str(x.t.device), "cpu")
+                    x.t = x.t.to(self.rank)
+                    return _self.lin(x.t)
+                else:
+                    self.assertTrue(len(x), expected_len)
+                    self.assertTrue(x[0].device == x[1].device)
+                    self.assertEqual(x[0].device.index, self.rank)
+                    t = x[0] + x[1]
+                    return _self.lin(t)
+
+        model = torch.nn.parallel.DistributedDataParallel(
+            ToyModel().to(self.rank), device_ids=[self.rank]
+        )
+        # CPU input, should be moved to the proper device before call to
+        # forward.
+        inp = tuple(torch.randn(10, 10) for _ in range(expected_len))
+        model(inp, tuple)
+        # List CPU input, should be moved to proper device before call to
+        # forward.
+        inp = [torch.randn(10, 10) for _ in range(expected_len)]
+        model(inp, list)
+        # Custom type containing tensor. The type is maintained, but the
+        # device is not propagated (which is what happens with scatter too)
+        inp = TensorWrapper(torch.randn(10, 10))
+        model(inp, TensorWrapper)
+
     @require_backend({"gloo", "nccl"})
     @require_backends_available({"gloo", "nccl"})
     @skip_if_lt_x_gpu(2)
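Not part of the patch: a minimal usage sketch of what the `to_kwargs()` / `_recursive_to()` path enables for single-device DDP, namely that CPU inputs are moved onto the module's device before the wrapped `forward()` runs, mirroring what `scatter()` already did on the multi-device branch and what `test_ddp_device` above exercises. The launcher details here (`demo_worker`, the NCCL backend, the `MASTER_ADDR`/`MASTER_PORT` values, and the two-GPU `mp.spawn`) are illustrative assumptions, not anything introduced by this diff.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def demo_worker(rank, world_size):
    # Hypothetical rendezvous settings for a local multi-GPU run.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # Single-device DDP: device_ids has exactly one entry, so forward()
    # takes the to_kwargs() path instead of scatter().
    model = DDP(nn.Linear(10, 10).to(rank), device_ids=[rank])

    # A plain CPU tensor argument is recursively moved to GPU `rank`
    # before the wrapped module's forward() runs; no manual .to(rank)
    # is needed on the caller's side.
    inp = torch.randn(20, 10)
    out = model(inp)
    out.sum().backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    # Assumes at least 2 GPUs are available on the host.
    mp.spawn(demo_worker, args=(2,), nprocs=2)
```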