Correctly convert namedtuples in DDP
Pull Request resolved: #44220

Closes #44009
Currently, if a dataloader returns objects created with collections.namedtuple (or typing.NamedTuple), they are incorrectly cast to plain tuples when DDP scatters its inputs. As a result, if we have data of these types, there can be runtime errors during the forward pass when the module expects a namedtuple.

Fix this in `scatter_gather.py` to resolve the issue reported in #44009.
ghstack-source-id: 112050439

Differential Revision: [D23536752](https://our.internmc.facebook.com/intern/diff/D23536752/)
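
A minimal, self-contained sketch of the idea behind the fix (not the actual scatter implementation; `Batch`, `split`, and `num_chunks` are hypothetical names used only for illustration): detect namedtuple instances and rebuild one instance of the same type per chunk, rather than letting them decay into plain tuples.

import collections

import torch

Batch = collections.namedtuple("Batch", ("a", "b"))  # hypothetical input type


def _is_namedtuple(obj):
    # True for instances of both collections.namedtuple and typing.NamedTuple types.
    return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")


def split(obj, num_chunks):
    # Split a (possibly nested) input into num_chunks pieces, preserving namedtuple types.
    if isinstance(obj, torch.Tensor):
        return list(obj.chunk(num_chunks))
    if _is_namedtuple(obj):
        # Split every field, then zip the per-chunk pieces back into the
        # original namedtuple type so downstream code can still use obj.a / obj.b.
        return [type(obj)(*pieces) for pieces in zip(*(split(f, num_chunks) for f in obj))]
    if isinstance(obj, tuple):
        # Plain tuples (checked after the namedtuple case) become per-chunk tuples.
        return [tuple(pieces) for pieces in zip(*(split(o, num_chunks) for o in obj))]
    return [obj for _ in range(num_chunks)]


inp = Batch(torch.rand(10), torch.rand(10))
chunks = split(inp, 2)
assert all(isinstance(c, Batch) for c in chunks)  # each chunk keeps the namedtuple type
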
rohan-varma committed Sep 15, 2020
1 parent ace81b6 commit b26bfdf
Showing 3 changed files with 84 additions and 0 deletions.
31 changes: 31 additions & 0 deletions test/test_cuda.py
@@ -1,6 +1,7 @@
import collections
import io
import tempfile
from typing import NamedTuple
import unittest
import sys
from itertools import repeat, chain, product
@@ -14,6 +15,7 @@
import torch.cuda
import torch.cuda.comm as comm
from torch import multiprocessing as mp
from torch.nn.parallel import scatter_gather
from torch._six import inf, nan, container_abcs

from test_torch import AbstractTestCases
@@ -3049,6 +3051,35 @@ def test_matmul_device_mismatch(self):
        with self.assertRaisesRegex(RuntimeError, "expected (it|them) to be on GPU"):
            torch.addmm(s, m1, m2)

    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
    def test_scatter_namedtuple(self):
        # tests ability to scatter namedtuples and retrieve a list where each
        # element is of the expected namedtuple type.
        TestNamedTupleInput_0 = collections.namedtuple("NamedTuple", ("a", "b"))
        inp = TestNamedTupleInput_0(
            torch.rand(10, device=0),
            torch.rand(10, device=0),
        )
        num_gpus = torch.cuda.device_count()
        target_gpus = [torch.device(i) for i in range(num_gpus)]
        scatter_out = scatter_gather.scatter(inp, target_gpus)

        for x in scatter_out:
            self.assertTrue(isinstance(x, type(inp)))

        class TestNamedTupleInput_1(NamedTuple):
            a: torch.tensor
            b: torch.tensor

        inp = TestNamedTupleInput_1(
            torch.rand(10, device=0),
            torch.rand(10, device=0)
        )

        scatter_out = scatter_gather.scatter(inp, target_gpus)
        for x in scatter_out:
            self.assertTrue(isinstance(x, type(inp)))


if __name__ == '__main__':
    run_tests()
11 changes: 11 additions & 0 deletions torch/nn/parallel/scatter_gather.py
@@ -1,6 +1,12 @@
import torch
from ._functions import Scatter, Gather

def _is_namedtuple(obj):
    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
    return (
        isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
    )


def scatter(inputs, target_gpus, dim=0):
r"""
@@ -11,6 +17,11 @@ def scatter(inputs, target_gpus, dim=0):
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if _is_namedtuple(obj):
            # De-listify the base case output, since scatter(1, [0])
            # would return ([1]), but for namedtuple(a=1, b=1) we would not
            # want to reconstruct it as namedtuple(a=[1], b=[1])
            return [type(obj)(*map(lambda li: li[0], map(scatter_map, obj)))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
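
For context, a usage sketch of the behavior the new branch enables (assumes at least one CUDA device; mirrors the test_scatter_namedtuple test above, with the hypothetical namedtuple type `Batch`):

import collections

import torch
from torch.nn.parallel import scatter_gather

Batch = collections.namedtuple("Batch", ("a", "b"))  # hypothetical input type
inp = Batch(torch.rand(10, device=0), torch.rand(10, device=0))
devices = [torch.device(i) for i in range(torch.cuda.device_count())]

out = scatter_gather.scatter(inp, devices)
# Before this patch each element of `out` was a plain tuple; with it, each
# element is a Batch, so modules can keep accessing fields as x.a and x.b.
assert all(isinstance(x, Batch) for x in out)
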
42 changes: 42 additions & 0 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -1,5 +1,6 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import copy
from collections import namedtuple
import fcntl
import itertools
import random
@@ -3229,3 +3230,44 @@ def test_broadcast_object_list(self):
        self.assertNotEqual(objects, collectives_object_test_list)
        dist.broadcast_object_list(objects, src=0)
        self.assertEqual(objects, collectives_object_test_list)

    @require_backend({"nccl", "gloo"})
    @require_n_gpus_for_nccl_backend(int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"])
    def test_ddp_namedtuple(self):
        TestNamedTupleInput_0 = namedtuple("NamedTuple", ("a", "b"))

        batch = 5
        dim = 10

        class TestNamedTupleInput_1(NamedTuple):
            a: torch.tensor
            b: torch.tensor

        class NamedTupleModule(torch.nn.Module):
            def __init__(_self): # noqa
                super().__init__()
                _self.lin = nn.Linear(10, 1)

            def forward(_self, input, expected_type): # noqa
                # Without NamedTuple support, this would be of type tuple.
                self.assertTrue(
                    isinstance(input, expected_type),
                    f"Expected type {expected_type} but got {type(input)}",
                )
                return _self.lin(torch.mul(input.a, input.b))

        model = torch.nn.parallel.DistributedDataParallel(
            NamedTupleModule().cuda(self.rank), device_ids=[self.rank],
        )
        inp = TestNamedTupleInput_0(
            torch.rand(batch, dim, device=self.rank),
            torch.rand(batch, dim, device=self.rank),
        )
        # The following would fail if DDP does not propagate NamedTuples correctly.
        model(inp, type(inp))

        inp = TestNamedTupleInput_1(
            torch.rand(batch, dim, device=self.rank),
            torch.rand(batch, dim, device=self.rank),
        )
        model(inp, type(inp))
