Correctly convert namedtuples in DDP
Pull Request resolved: #44220

Closes #44009
Currently, if a dataloader returns objects created with `collections.namedtuple` (or `typing.NamedTuple`), DDP's scatter logic incorrectly casts them to plain tuples. As a result, the forward pass can fail at runtime when the module expects a named tuple and accesses its fields by name.

Fix this in `scatter_gather.py` to resolve the issue reported in #44009.
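For context, a rough standalone sketch (not part of this commit, using a made-up `Batch` namedtuple) of why the generic tuple path loses the namedtuple type, and the gist of what the fix restores:

```python
from collections import namedtuple

Batch = namedtuple("Batch", ("a", "b"))
inp = Batch(a=[1, 2], b=[3, 4])

# Roughly what the generic tuple branch in scatter() does: zip() rebuilds the
# scattered chunks as plain tuples, so the field names are lost.
broken = list(zip(*inp))   # [(1, 3), (2, 4)]
# broken[0].a              # AttributeError: 'tuple' object has no attribute 'a'

# The gist of the fix: reconstruct each chunk with the original namedtuple
# type so field access still works after scattering.
fixed = [type(inp)(*chunk) for chunk in zip(*inp)]
print(fixed[0].a)          # 1
```

The actual change in `scatter_gather.py` (shown below) additionally routes each field through `scatter_map`, so tensor fields are still chunked across devices.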
ghstack-source-id: 111478085

Differential Revision: [D23536752](https://our.internmc.facebook.com/intern/diff/D23536752/)
rohan-varma committed Sep 4, 2020
1 parent 0c01f13 commit 5055c45
Showing 2 changed files with 53 additions and 0 deletions.
test/distributed/test_distributed.py (42 additions, 0 deletions)
@@ -1,4 +1,5 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from collections import namedtuple
import copy
import errno
import fcntl
@@ -3130,6 +3131,47 @@ def test_broadcast_object_list(self):
dist.broadcast_object_list(objects, src=0)
self.assertEqual(objects, collectives_object_test_list)

@require_backend({"nccl", "gloo"})
@require_n_gpus_for_nccl_backend(int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"])
def test_ddp_namedtuple(self):
TestNamedTupleInput_0 = namedtuple("NamedTuple", ("a", "b"))

batch = 5
dim = 10

class TestNamedTupleInput_1(NamedTuple):
a: torch.tensor
b: torch.tensor

class NamedTupleModule(torch.nn.Module):
def __init__(_self): # noqa
super().__init__()
_self.lin = nn.Linear(10, 1)

def forward(_self, input, expected_type): # noqa
# Without NamedTuple support, this would be of type tuple.
self.assertTrue(
isinstance(input, expected_type),
f"Expected type {expected_type} but got {type(input)}",
)
return _self.lin(torch.mul(input.a, input.b))

model = torch.nn.parallel.DistributedDataParallel(
NamedTupleModule().cuda(self.rank), device_ids=[self.rank],
)
inp = TestNamedTupleInput_0(
torch.rand(batch, dim, device=self.rank),
torch.rand(batch, dim, device=self.rank),
)
# The following would fail if DDP does not propagate NamedTuples correctly.
model(inp, type(inp))

inp = TestNamedTupleInput_1(
torch.rand(batch, dim, device=self.rank),
torch.rand(batch, dim, device=self.rank),
)
model(inp, type(inp))


if BACKEND == "gloo" or BACKEND == "nccl":
WORLD_SIZE = os.environ["WORLD_SIZE"]
torch/nn/parallel/scatter_gather.py (11 additions, 0 deletions)
@@ -1,6 +1,12 @@
import torch
from ._functions import Scatter, Gather

def _is_namedtuple(obj):
# Check if type was created from collections.namedtuple or a typing.NamedTuple.
return (
isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
)


def scatter(inputs, target_gpus, dim=0):
r"""
@@ -11,6 +17,11 @@ def scatter(inputs, target_gpus, dim=0):
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
return Scatter.apply(target_gpus, None, dim, obj)
if _is_namedtuple(obj):
# De-listify the base case output, since scatter(1, [0])
# would return ([1]), but for namedtuple(a=1, b=1) we would not
# want to reconstruct it as namedtuple(a=[1], b=[1])
return [type(obj)(*map(lambda li: li[0], map(scatter_map, obj)))]
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
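For reference, a small standalone sketch (re-declaring the helper above outside PyTorch) showing that the duck-typed check matches both `collections.namedtuple` and `typing.NamedTuple` instances, but not plain tuples:

```python
from collections import namedtuple
from typing import NamedTuple

def _is_namedtuple(obj):
    # Same duck-typed check as the helper added above.
    return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")

Point = namedtuple("Point", ("x", "y"))

class Pair(NamedTuple):
    a: int
    b: int

print(_is_namedtuple(Point(1, 2)))  # True
print(_is_namedtuple(Pair(1, 2)))   # True
print(_is_namedtuple((1, 2)))       # False
```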
