Correctly convert namedtuples in DDP #44220

Closed · wants to merge 5 commits

Changes from 4 commits
31 changes: 31 additions & 0 deletions test/test_cuda.py
@@ -1,6 +1,7 @@
import collections
import io
import tempfile
from typing import NamedTuple
import unittest
import sys
from itertools import repeat, chain, product
@@ -14,6 +15,7 @@
import torch.cuda
import torch.cuda.comm as comm
from torch import multiprocessing as mp
from torch.nn.parallel import scatter_gather
from torch._six import inf, nan, container_abcs

from test_torch import AbstractTestCases
@@ -3049,6 +3051,35 @@ def test_matmul_device_mismatch(self):
        with self.assertRaisesRegex(RuntimeError, "expected (it|them) to be on GPU"):
            torch.addmm(s, m1, m2)

    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
    def test_scatter_namedtuple(self):
        # Tests the ability to scatter namedtuples and retrieve a list where each
        # element is of the expected namedtuple type.
        TestNamedTupleInput_0 = collections.namedtuple("NamedTuple", ("a", "b"))
        inp = TestNamedTupleInput_0(
            torch.rand(10, device=0),
            torch.rand(10, device=0),
        )
        num_gpus = torch.cuda.device_count()
        target_gpus = [torch.device(i) for i in range(num_gpus)]
        scatter_out = scatter_gather.scatter(inp, target_gpus)

        for x in scatter_out:
            self.assertTrue(isinstance(x, type(inp)))

        class TestNamedTupleInput_1(NamedTuple):
            a: torch.Tensor
            b: torch.Tensor

        inp = TestNamedTupleInput_1(
            torch.rand(10, device=0),
            torch.rand(10, device=0),
        )

        scatter_out = scatter_gather.scatter(inp, target_gpus)
        for x in scatter_out:
            self.assertTrue(isinstance(x, type(inp)))
Contributor:

Can we also assert that the _fields attribute is preserved correctly, and that the values are also as expected?

Member (Author):

Thanks, adding the tests for the values actually uncovered a bug in the implementation, which I just fixed.



if __name__ == '__main__':
    run_tests()
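
To illustrate the reviewer's request above, here is a hypothetical, standalone sketch (not part of the PR) that emulates a two-way split on CPU with Tensor.chunk instead of calling scatter, and asserts that each shard keeps the namedtuple type, its _fields, and the matching values:

# Hypothetical sketch: emulate a 2-way scatter of a namedtuple on CPU and assert
# that each shard keeps the namedtuple type, its _fields, and the matching values.
import collections

import torch

Pair = collections.namedtuple("Pair", ("a", "b"))
inp = Pair(torch.arange(4.0), torch.arange(4.0, 8.0))

# Split every field into two chunks along dim 0, then rebuild one Pair per chunk.
chunks_a = inp.a.chunk(2, dim=0)
chunks_b = inp.b.chunk(2, dim=0)
shards = [Pair(a, b) for a, b in zip(chunks_a, chunks_b)]

for i, shard in enumerate(shards):
    assert isinstance(shard, Pair)
    assert shard._fields == inp._fields             # field names are preserved
    assert torch.equal(shard.a, inp.a.chunk(2)[i])  # values match the i-th slice
    assert torch.equal(shard.b, inp.b.chunk(2)[i])

The real test would run the same assertions against the output of scatter_gather.scatter across multiple GPUs.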
11 changes: 11 additions & 0 deletions torch/nn/parallel/scatter_gather.py
@@ -1,6 +1,12 @@
import torch
from ._functions import Scatter, Gather

def _is_namedtuple(obj):
    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
    return (
        isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
    )


def scatter(inputs, target_gpus, dim=0):
    r"""
@@ -11,6 +17,11 @@ def scatter(inputs, target_gpus, dim=0):
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if _is_namedtuple(obj):
            # De-listify the base case output, since scatter(1, [0]) would
            # return [1], but for namedtuple(a=1, b=1) we would not want to
            # reconstruct it as namedtuple(a=[1], b=[1]).
            return [type(obj)(*map(lambda li: li[0], map(scatter_map, obj)))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
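
As a side note for readers, the duck-typing check above can be exercised on its own. The following hypothetical snippet (not part of the PR) shows that both namedtuple flavors pass the check, plain tuples and lists do not, and that type(obj)(*fields) rebuilds an instance of the original namedtuple class, which is what scatter_map relies on:

# Hypothetical demo of the duck-typing check used above: a namedtuple (from either
# collections.namedtuple or typing.NamedTuple) is a tuple with _asdict and _fields.
import collections
from typing import NamedTuple


def _is_namedtuple(obj):
    return (
        isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
    )


PointA = collections.namedtuple("PointA", ("x", "y"))


class PointB(NamedTuple):
    x: int
    y: int


assert _is_namedtuple(PointA(1, 2))
assert _is_namedtuple(PointB(1, 2))
assert not _is_namedtuple((1, 2))   # plain tuples are excluded
assert not _is_namedtuple([1, 2])   # so are lists

# Rebuilding from mapped fields keeps the original type, which scatter_map relies on.
mapped = [v * 10 for v in PointA(1, 2)]
rebuilt = type(PointA(1, 2))(*mapped)
assert isinstance(rebuilt, PointA) and rebuilt == PointA(10, 20)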
44 changes: 44 additions & 0 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -1,5 +1,6 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import copy
from collections import namedtuple
import fcntl
import itertools
import random
@@ -3229,3 +3230,46 @@ def test_broadcast_object_list(self):
        self.assertNotEqual(objects, collectives_object_test_list)
        dist.broadcast_object_list(objects, src=0)
        self.assertEqual(objects, collectives_object_test_list)

    @require_backend({"gloo", "nccl"})
    @require_backends_available({"gloo", "nccl"})
    @skip_if_lt_x_gpu(2)
    @skip_if_rocm
    def test_ddp_namedtuple(self):
        TestNamedTupleInput_0 = namedtuple("NamedTuple", ("a", "b"))

        batch = 5
        dim = 10

        class TestNamedTupleInput_1(NamedTuple):
            a: torch.Tensor
            b: torch.Tensor

        class NamedTupleModule(torch.nn.Module):
            def __init__(_self):  # noqa
                super().__init__()
                _self.lin = nn.Linear(10, 1)

            def forward(_self, input, expected_type):  # noqa
                # Without NamedTuple support, this would be of type tuple.
                self.assertTrue(
                    isinstance(input, expected_type),
                    f"Expected type {expected_type} but got {type(input)}",
                )
                return _self.lin(torch.mul(input.a, input.b))

        model = torch.nn.parallel.DistributedDataParallel(
            NamedTupleModule().cuda(self.rank), device_ids=[self.rank],
        )
        inp = TestNamedTupleInput_0(
            torch.rand(batch, dim, device=self.rank),
            torch.rand(batch, dim, device=self.rank),
        )
        # The following would fail if DDP does not propagate NamedTuples correctly.
        model(inp, type(inp))

        inp = TestNamedTupleInput_1(
            torch.rand(batch, dim, device=self.rank),
            torch.rand(batch, dim, device=self.rank),
        )
        model(inp, type(inp))
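
Finally, a hypothetical end-to-end sketch (not the PR's test) of how a namedtuple input flows through DistributedDataParallel. It assumes two CUDA devices, the nccl backend, and that port 29501 is free; with the scatter change above, forward() receives a Batch instance rather than a plain tuple:

# Hypothetical two-process DDP run with a namedtuple input; assumes 2 CUDA devices.
import os
from collections import namedtuple

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn

Batch = namedtuple("Batch", ("a", "b"))


class NamedTupleModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(10, 1)

    def forward(self, batch):
        # DDP scatters the input onto this rank's device; the namedtuple type
        # should be preserved by the change in scatter_gather.py.
        assert isinstance(batch, Batch), f"got {type(batch)}"
        return self.lin(batch.a * batch.b)


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"  # assumption: this port is free
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    model = nn.parallel.DistributedDataParallel(
        NamedTupleModule().cuda(rank), device_ids=[rank]
    )
    batch = Batch(
        torch.rand(5, 10, device=rank),
        torch.rand(5, 10, device=rank),
    )
    model(batch).sum().backward()
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2  # assumption: two GPUs are available
    mp.spawn(worker, args=(world_size,), nprocs=world_size)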