Avoid scatter for single-device case in DDP
Pull Request resolved: #46304

In the case that a single process operates on only one GPU, we can skip the scatter that DDP normally performs on the inputs in `forward` and instead use a recursive version of `to` that transfers the input tensors to the correct device.

The implementation of `_recursive_to` is modeled after `scatter` in https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/scatter_gather.py, in order to keep parity with the previous conventions (i.e. tensors held inside custom types are not moved).
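
As an illustration of the approach, here is a minimal standalone sketch (not the committed `_recursive_to`, which additionally wraps results so that `to_kwargs` can index them like scatter output; the helper name `recursive_to` and the sample batch below are hypothetical):

import torch

def recursive_to(obj, device):
    # Move every tensor nested in tuples/lists/dicts to `device`.
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):  # namedtuple
        return type(obj)(*(recursive_to(v, device) for v in obj))
    if isinstance(obj, tuple):
        return tuple(recursive_to(v, device) for v in obj)
    if isinstance(obj, list):
        return [recursive_to(v, device) for v in obj]
    if isinstance(obj, dict):
        return {k: recursive_to(v, device) for k, v in obj.items()}
    return obj  # custom types are returned untouched, tensors inside them included

device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch = {"x": torch.randn(4, 10), "extra": [torch.arange(4), "label"]}
moved = recursive_to(batch, device)
assert moved["x"].device == torch.device(device)
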
ghstack-source-id: 114298122

Differential Revision: [D24296377](https://our.internmc.facebook.com/intern/diff/D24296377/)
rohan-varma committed Oct 14, 2020
1 parent f2e5ae4 commit 94c9299
Showing 3 changed files with 90 additions and 4 deletions.
40 changes: 38 additions & 2 deletions torch/nn/parallel/distributed.py
@@ -16,7 +16,7 @@
from torch.distributed.distributed_c10d import ReduceOp
from ..modules import Module
from .replicate import replicate
from .scatter_gather import scatter_kwargs, gather
from .scatter_gather import scatter_kwargs, gather, is_namedtuple
from .parallel_apply import parallel_apply
from torch._utils import _get_device_index, _get_all_device_indices

@@ -660,10 +660,11 @@ def forward(self, *inputs, **kwargs):
self._check_global_requires_backward_grad_sync(is_joined_rank=False)

if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0])
output = self.module(*inputs[0], **kwargs[0])
else:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
@@ -688,6 +689,41 @@ def forward(self, *inputs, **kwargs):
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

def _recursive_to(self, inputs, target_gpu):
r"""
Recursively moves input to the target_gpu.
"""
def to_map(obj):
if isinstance(obj, torch.Tensor):
return (obj.to(target_gpu), )
if is_namedtuple(obj):
return list(type(obj)(*args) for args in zip(*map(to_map, obj)))
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(to_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
return list(map(list, zip(*map(to_map, obj))))
if isinstance(obj, dict) and len(obj) > 0:
return list(map(type(obj), zip(*map(to_map, obj.items()))))
return [obj]

# Avoid reference cycle
try:
res = to_map(inputs)
finally:
to_map = None
return res

def to_kwargs(self, inputs, kwargs, device_id):
inputs = self._recursive_to(inputs, device_id) if inputs else []
kwargs = self._recursive_to(kwargs, device_id) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs

def parallel_apply(self, replicas, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

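For context, a hedged sketch of what the new `to_kwargs` hands back to `forward` in the single-device case (the tensors `a`, `b` and the device string below are hypothetical): `_recursive_to` wraps the moved structure so it can be indexed like scatter output, and the shorter of inputs/kwargs is padded so both `inputs[0]` and `kwargs[0]` are valid.

import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
a, b = torch.randn(2, 3), torch.randn(2, 3)

# Roughly what to_kwargs((a, b), {}, device) returns:
inputs = ((a.to(device), b.to(device)),)   # one entry for the single device
kwargs = ({},)                             # padded so kwargs[0] exists
# forward() can then call: self.module(*inputs[0], **kwargs[0])
assert inputs[0][0].device == torch.device(device)
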
4 changes: 2 additions & 2 deletions torch/nn/parallel/scatter_gather.py
@@ -1,7 +1,7 @@
import torch
from ._functions import Scatter, Gather

def _is_namedtuple(obj):
def is_namedtuple(obj):
# Check if type was created from collections.namedtuple or a typing.NamedTuple.
return (
isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
@@ -17,7 +17,7 @@ def scatter(inputs, target_gpus, dim=0):
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
return Scatter.apply(target_gpus, None, dim, obj)
if _is_namedtuple(obj):
if is_namedtuple(obj):
return list(type(obj)(*args) for args in zip(*map(scatter_map, obj)))
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
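As a quick illustration of the duck-typing check `is_namedtuple` relies on (the helper name `looks_like_namedtuple` below is hypothetical), both `collections.namedtuple` and `typing.NamedTuple` instances pass, while plain tuples do not:

import collections
import typing

Point = collections.namedtuple("Point", ["x", "y"])

class Box(typing.NamedTuple):
    w: int
    h: int

def looks_like_namedtuple(obj):
    # Same duck-typing as is_namedtuple: a tuple with _asdict and _fields.
    return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")

assert looks_like_namedtuple(Point(1, 2))
assert looks_like_namedtuple(Box(3, 4))
assert not looks_like_namedtuple((1, 2))
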
50 changes: 50 additions & 0 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -3781,6 +3781,56 @@ def forward(self, x):
else:
ddp(inp).sum().backward()

@require_backend({"gloo", "nccl"})
@require_backends_available({"gloo", "nccl"})
@skip_if_lt_x_gpu(2)
@skip_if_rocm
def test_ddp_device(self):
m = nn.Linear(10, 10).to(self.rank)
expected_len = 2

class TensorWrapper:
__slots__ = ['t']

def __init__(self, t):
self.t = t

class ToyModel(torch.nn.Module):
def __init__(_self): # noqa: B902
super().__init__()
_self.lin = nn.Linear(10, 10, bias=False)

def forward(_self, x, expected_type): # noqa: B902
# Similar to scatter, the recursive to in the single-device
# case does not move tensors if they are in a custom type.
self.assertTrue(isinstance(x, expected_type))
if expected_type == TensorWrapper:
self.assertEqual(str(x.t.device), "cpu")
x.t = x.t.to(self.rank)
return _self.lin(x.t)
else:
self.assertEqual(len(x), expected_len)
self.assertTrue(x[0].device == x[1].device)
self.assertEqual(x[0].device.index, self.rank)
t = x[0] + x[1]
return _self.lin(t)

model = torch.nn.parallel.DistributedDataParallel(
ToyModel().to(self.rank), device_ids=[self.rank]
)
# CPU input, should be moved to the proper device before call to
# forward.
inp = tuple(torch.randn(10, 10) for _ in range(expected_len))
model(inp, tuple)
# List CPU input, should be moved to proper device before call to
# forward.
inp = [torch.randn(10, 10) for _ in range(expected_len)]
model(inp, list)
# Custom type containing tensor. The type is maintained, but the
# device is not propagated (which is what happens with scatter too)
inp = TensorWrapper(torch.randn(10, 10))
model(inp, TensorWrapper)

@require_backend({"gloo", "nccl"})
@require_backends_available({"gloo", "nccl"})
@skip_if_lt_x_gpu(2)
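For reference, a hedged usage sketch of the path this commit optimizes (the `run` helper and rendezvous settings are assumptions, not part of the commit): with `device_ids=[rank]`, a CPU input is moved to `cuda:rank` by the recursive `to` before `forward` runs, instead of going through scatter.

import os
import torch
import torch.distributed as dist
import torch.nn as nn

def run(rank: int, world_size: int) -> None:
    # Assumed single-GPU-per-process setup; adjust rendezvous for your environment.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    model = nn.parallel.DistributedDataParallel(
        nn.Linear(10, 10).to(rank), device_ids=[rank]
    )
    inp = torch.randn(2, 10)   # deliberately left on CPU
    out = model(inp)           # moved to cuda:rank by the recursive `to`, not scatter
    out.sum().backward()
    dist.destroy_process_group()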
