Avoid scatter for single-device case in DDP
Pull Request resolved: #46304

When a single process operates on only one GPU, the scatter performed in DDP's `forward` is
unnecessary; we can replace it with a recursive version of `to` that transfers the input
tensors to the correct device.

The implementation of `_recursive_to` is modeled after `scatter` in https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/scatter_gather.py to keep parity with the existing conventions (e.g., tensors held inside custom types are not moved).
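
For illustration, here is a minimal standalone sketch of the recursive move described above. `move_to_device` is a hypothetical helper written for this summary, not the PR's `_recursive_to` (which additionally wraps results so they line up with `scatter`'s output shape); like `scatter`, it leaves unrecognized custom types untouched.

```python
import torch

def move_to_device(obj, device):
    # Hypothetical sketch: tensors are moved, standard containers are rebuilt,
    # and anything else (including custom types) is returned unchanged.
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):  # namedtuple
        return type(obj)(*(move_to_device(o, device) for o in obj))
    if isinstance(obj, tuple):
        return tuple(move_to_device(o, device) for o in obj)
    if isinstance(obj, list):
        return [move_to_device(o, device) for o in obj]
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    return obj

if torch.cuda.is_available():
    batch = {"x": torch.randn(4, 10), "meta": [torch.arange(4), "labels"]}
    moved = move_to_device(batch, torch.device("cuda", 0))
    print(moved["x"].device)   # cuda:0
    print(moved["meta"][1])    # "labels" (non-tensors pass through)
```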
ghstack-source-id: 114861410

Differential Revision: [D24296377](https://our.internmc.facebook.com/intern/diff/D24296377/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D24296377/)!
rohan-varma committed Oct 21, 2020
1 parent 8328630 commit 4ef551d
Showing 3 changed files with 153 additions and 11 deletions.
40 changes: 38 additions & 2 deletions torch/nn/parallel/distributed.py
@@ -16,7 +16,7 @@
from torch.distributed.distributed_c10d import ReduceOp
from ..modules import Module
from .replicate import replicate
-from .scatter_gather import scatter_kwargs, gather
+from .scatter_gather import scatter_kwargs, gather, is_namedtuple
from .parallel_apply import parallel_apply
from torch._utils import _get_device_index, _get_all_device_indices

@@ -666,10 +666,11 @@ def forward(self, *inputs, **kwargs):
self._check_global_requires_backward_grad_sync(is_joined_rank=False)

if self.device_ids:
-inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
+inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0])
output = self.module(*inputs[0], **kwargs[0])
else:
+inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
@@ -694,6 +695,41 @@ def forward(self, *inputs, **kwargs):
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

def _recursive_to(self, inputs, target_gpu):
r"""
Recursively moves input to the target_gpu.
"""
def to_map(obj):
if isinstance(obj, torch.Tensor):
return (obj.to(target_gpu), )
if is_namedtuple(obj):
return [type(obj)(*args) for args in zip(*map(to_map, obj))]
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(to_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
return [list(i) for i in zip(*map(to_map, obj))]
if isinstance(obj, dict) and len(obj) > 0:
return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
return [obj]

# Avoid reference cycle
try:
res = to_map(inputs)
finally:
to_map = None
return res

def to_kwargs(self, inputs, kwargs, device_id):
inputs = self._recursive_to(inputs, device_id) if inputs else []
kwargs = self._recursive_to(kwargs, device_id) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs

def parallel_apply(self, replicas, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

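To make the new code path concrete, here is a hedged usage sketch of `to_kwargs` (the tensors and the pre-built `ddp` instance with `device_ids=[0]` are assumed for illustration). It returns one-element tuples so the existing `inputs[0]` / `kwargs[0]` indexing in `forward` keeps working:

```python
# Illustration only; `ddp` is assumed to be an existing DistributedDataParallel
# instance constructed with device_ids=[0].
x, y = torch.randn(8, 10), torch.randn(8, 10)
m = torch.ones(8, 10)
inputs, kwargs = ddp.to_kwargs((x, y), {"mask": m}, 0)
# inputs == ((x_on_cuda0, y_on_cuda0),)   a single "replica" of positional args
# kwargs == ({"mask": m_on_cuda0},)       a single "replica" of keyword args
out = ddp.module(*inputs[0], **kwargs[0])
```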
4 changes: 2 additions & 2 deletions torch/nn/parallel/scatter_gather.py
@@ -1,7 +1,7 @@
import torch
from ._functions import Scatter, Gather

-def _is_namedtuple(obj):
+def is_namedtuple(obj):
# Check if type was created from collections.namedtuple or a typing.NamedTuple.
return (
isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
@@ -17,7 +17,7 @@ def scatter(inputs, target_gpus, dim=0):
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
return Scatter.apply(target_gpus, None, dim, obj)
-if _is_namedtuple(obj):
+if is_namedtuple(obj):
return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
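As a quick usage sketch of the renamed helper (not part of the diff; the import path assumes the post-commit module layout):

```python
from collections import namedtuple
from torch.nn.parallel.scatter_gather import is_namedtuple

Point = namedtuple("Point", ["x", "y"])
print(is_namedtuple(Point(1, 2)))  # True: created via namedtuple, so it has _fields/_asdict
print(is_namedtuple((1, 2)))       # False: a plain tuple lacks those attributes
```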
120 changes: 113 additions & 7 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -67,6 +67,13 @@ def __eq__(self, other):
[1, 2, True, "string", [4, 5, "nested"]],
]

# Dummy NamedTuple data structures to test DDP support for NamedTuple types.
EXPECTED_FIELDS = ("a", "b")
TestNamedTupleInput_0 = namedtuple("NamedTuple", EXPECTED_FIELDS)

class TestNamedTupleInput_1(NamedTuple):
a: torch.tensor
b: torch.tensor

skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

@@ -3810,16 +3817,115 @@ def forward(self, x):
@require_backends_available({"gloo", "nccl"})
@skip_if_lt_x_gpu(2)
@skip_if_rocm
-def test_ddp_namedtuple(self):
-expected_fields = ("a", "b")
-TestNamedTupleInput_0 = namedtuple("NamedTuple", expected_fields)
+def test_ddp_device(self):
m = nn.Linear(10, 10).to(self.rank)
expected_len = 2

class TensorWrapper:
__slots__ = ['t', 'moved_to_gpu']

def __init__(self, t):
self.t = t
self.moved_to_gpu = False

# Handlers for specific types of validation we want to do based on
# the input type.

def tuple_and_list_validator(x):
self.assertEqual(len(x), expected_len)
self.assertEqual(1, len(set(t.device for t in x)))
self.assertEqual(x[0].device.index, self.rank)
return x[0] + x[1]

def namedtuple_validator(x):
self.assertEqual(x._fields, EXPECTED_FIELDS)
self.assertEqual(x.a.device.index, x.b.device.index)
self.assertEqual(x.a.device.index, self.rank)
return x.a + x.b

def custom_type_validator(x):
self.assertTrue(x.moved_to_gpu or (str(x.t.device) == "cpu"))
x.t = x.t.to(self.rank)
x.moved_to_gpu = True
return x.t

def dict_validator(x):
self.assertTrue(EXPECTED_FIELDS[0] in x.keys())
self.assertTrue(EXPECTED_FIELDS[1] in x.keys())
self.assertEqual(1, len(set(t.device for t in x.values())))
self.assertEqual(x[EXPECTED_FIELDS[0]].device.index, self.rank)
return x[EXPECTED_FIELDS[0]] + x[EXPECTED_FIELDS[1]]

validators = {
TensorWrapper: custom_type_validator,
tuple: tuple_and_list_validator,
list: tuple_and_list_validator,
TestNamedTupleInput_0: namedtuple_validator,
TestNamedTupleInput_1: namedtuple_validator,
dict: dict_validator,
}

class ToyModel(torch.nn.Module):
def __init__(_self): # noqa: B902
super().__init__()
_self.lin = nn.Linear(10, 10, bias=False)

def forward(_self, x, expected_type): # noqa: B902
# Similar to scatter, the recursive to in the single-device
# case does not move tensors if they are in a custom type.
self.assertTrue(isinstance(x, expected_type))
fwd_tensor = validators[expected_type](x)
return _self.lin(fwd_tensor)

model = torch.nn.parallel.DistributedDataParallel(
ToyModel().to(self.rank), device_ids=[self.rank]
)

def train_iter(inp, input_type):
for _ in range(4):
out = model(inp, input_type)
out.sum().backward()

# CPU tuple input, should be moved to the proper device before call
# to forward.
inp = tuple(torch.randn(10, 10) for _ in range(expected_len))
train_iter(inp, tuple)

# List CPU input, should be moved to proper device before call to
# forward.
inp = [torch.randn(10, 10) for _ in range(expected_len)]
train_iter(inp, list)
# Custom type containing tensor. The type is maintained, but the
# device is not propagated (which is what happens with scatter too)
inp = TensorWrapper(torch.randn(10, 10))
train_iter(inp, TensorWrapper)
# NamedTuple input. The type should be maintained and tensor inputs
# should be moved to the correct device as in scatter.
batch = 5
dim = 10
a = torch.rand(batch, dim)
b = torch.rand(batch, dim)

inp = TestNamedTupleInput_0(a, b)
train_iter(inp, type(inp))

inp = TestNamedTupleInput_1(a, b)
train_iter(inp, type(inp))

# dictionary input.
inp = {
EXPECTED_FIELDS[0]: a,
EXPECTED_FIELDS[1]: b,
}
train_iter(inp, type(inp))

-class TestNamedTupleInput_1(NamedTuple):
-a: torch.tensor
-b: torch.tensor
@require_backend({"gloo", "nccl"})
@require_backends_available({"gloo", "nccl"})
@skip_if_lt_x_gpu(2)
@skip_if_rocm
def test_ddp_namedtuple(self):
batch = 5
dim = 10

a = torch.rand(batch, dim, device=self.rank)
b = torch.rand(batch, dim, device=self.rank)
@@ -3835,7 +3941,7 @@ def forward(_self, input, expected_type): # noqa
isinstance(input, expected_type),
f"Expected type {expected_type} but got {type(input)}",
)
-self.assertEqual(input._fields, expected_fields)
+self.assertEqual(input._fields, EXPECTED_FIELDS)
self.assertEqual(a, input.a)
self.assertEqual(b, input.b)
return _self.lin(torch.mul(input.a, input.b))
