diff --git a/test/test_cuda.py b/test/test_cuda.py
index e3d94671e997..498fd199066f 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1,6 +1,7 @@
 import collections
 import io
 import tempfile
+from typing import NamedTuple
 import unittest
 import sys
 from itertools import repeat, chain, product
@@ -14,6 +15,7 @@
 import torch.cuda
 import torch.cuda.comm as comm
 from torch import multiprocessing as mp
+from torch.nn.parallel import scatter_gather
 from torch._six import inf, nan, container_abcs
 from test_torch import AbstractTestCases
 
@@ -3134,6 +3136,48 @@ def test_matmul_device_mismatch(self):
         with self.assertRaisesRegex(RuntimeError, "expected (it|them) to be on GPU"):
             torch.addmm(s, m1, m2)
 
+    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
+    def test_scatter_namedtuple(self):
+        # tests ability to scatter namedtuples and retrieve a list where each
+        # element is of the expected namedtuple type.
+        fields = ("a", "b")
+        TestNamedTupleInput_0 = collections.namedtuple("NamedTuple", fields)
+        num_gpus = torch.cuda.device_count()
+        a = torch.rand(num_gpus * 2, device=0)
+        b = torch.rand(num_gpus * 2, device=0)
+        a_tensors_for_gpu = [a[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
+        b_tensors_for_gpu = [b[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
+
+        inp = TestNamedTupleInput_0(a, b)
+        target_gpus = [torch.device(i) for i in range(num_gpus)]
+        scatter_out = scatter_gather.scatter(inp, target_gpus)
+
+        for i, x in enumerate(scatter_out):
+            self.assertTrue(isinstance(x, type(inp)))
+            self.assertEqual(x._fields, fields)
+            expected_a = a_tensors_for_gpu[i]
+            expected_b = b_tensors_for_gpu[i]
+            self.assertEqual(expected_a, x.a)
+            self.assertEqual(expected_b, x.b)
+
+        class TestNamedTupleInput_1(NamedTuple):
+            a: torch.Tensor
+            b: torch.Tensor
+
+        a = torch.rand(num_gpus * 2, device=0)
+        b = torch.rand(num_gpus * 2, device=0)
+        a_tensors_for_gpu = [a[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
+        b_tensors_for_gpu = [b[2 * i : 2 * i + 2].to(i) for i in range(num_gpus)]
+        inp = TestNamedTupleInput_1(a, b)
+
+        scatter_out = scatter_gather.scatter(inp, target_gpus)
+        for i, x in enumerate(scatter_out):
+            self.assertTrue(isinstance(x, type(inp)))
+            self.assertEqual(x._fields, fields)
+            expected_a = a_tensors_for_gpu[i]
+            expected_b = b_tensors_for_gpu[i]
+            self.assertEqual(expected_a, x.a)
+            self.assertEqual(expected_b, x.b)
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py
index 1635d40e29e8..a90d85f037c3 100644
--- a/torch/nn/parallel/scatter_gather.py
+++ b/torch/nn/parallel/scatter_gather.py
@@ -1,6 +1,12 @@
 import torch
 from ._functions import Scatter, Gather
 
+def _is_namedtuple(obj):
+    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
+    return (
+        isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
+    )
+
 
 def scatter(inputs, target_gpus, dim=0):
     r"""
@@ -11,6 +17,8 @@ def scatter(inputs, target_gpus, dim=0):
     def scatter_map(obj):
         if isinstance(obj, torch.Tensor):
             return Scatter.apply(target_gpus, None, dim, obj)
+        if _is_namedtuple(obj):
+            return list(type(obj)(*args) for args in zip(*map(scatter_map, obj)))
         if isinstance(obj, tuple) and len(obj) > 0:
             return list(zip(*map(scatter_map, obj)))
         if isinstance(obj, list) and len(obj) > 0:
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 5d43882ff024..01cddee92365 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -1,4 +1,5 @@
 import copy
+from collections import namedtuple
 import itertools
 import random
 import math
@@ -3423,3 +3424,47 @@ def forward(self, x):
             # Synchronize since we run multiple iterations of this test, to
             # isolate failure hangs.
             torch.cuda.synchronize(device=self.rank)
+
+    @require_backend({"gloo", "nccl"})
+    @require_backends_available({"gloo", "nccl"})
+    @skip_if_lt_x_gpu(2)
+    @skip_if_rocm
+    def test_ddp_namedtuple(self):
+        expected_fields = ("a", "b")
+        TestNamedTupleInput_0 = namedtuple("NamedTuple", expected_fields)
+
+        batch = 5
+        dim = 10
+
+        class TestNamedTupleInput_1(NamedTuple):
+            a: torch.Tensor
+            b: torch.Tensor
+
+        a = torch.rand(batch, dim, device=self.rank)
+        b = torch.rand(batch, dim, device=self.rank)
+
+        class NamedTupleModule(torch.nn.Module):
+            def __init__(_self):  # noqa
+                super().__init__()
+                _self.lin = nn.Linear(10, 1)
+
+            def forward(_self, input, expected_type):  # noqa
+                # Without NamedTuple support, this would be of type tuple.
+                self.assertTrue(
+                    isinstance(input, expected_type),
+                    f"Expected type {expected_type} but got {type(input)}",
+                )
+                self.assertEqual(input._fields, expected_fields)
+                self.assertEqual(a, input.a)
+                self.assertEqual(b, input.b)
+                return _self.lin(torch.mul(input.a, input.b))
+
+        model = torch.nn.parallel.DistributedDataParallel(
+            NamedTupleModule().cuda(self.rank), device_ids=[self.rank]
+        )
+        inp = TestNamedTupleInput_0(a, b)
+        # The following would fail if DDP does not propagate NamedTuples correctly.
+        model(inp, type(inp))
+
+        inp = TestNamedTupleInput_1(a, b)
+        model(inp, type(inp))
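
A minimal standalone sketch of the behavior the new scatter branch enables, assuming at least two visible CUDA GPUs; the `Pair` namedtuple and variable names below are illustrative and not part of the patch:

    import collections

    import torch
    from torch.nn.parallel import scatter_gather

    # Illustrative namedtuple type; any collections.namedtuple or typing.NamedTuple works.
    Pair = collections.namedtuple("Pair", ("a", "b"))

    num_gpus = torch.cuda.device_count()
    target_gpus = [torch.device(i) for i in range(num_gpus)]
    inp = Pair(torch.rand(num_gpus * 2, device=0), torch.rand(num_gpus * 2, device=0))

    # With the patch, each scattered chunk keeps the namedtuple type instead of
    # degrading to a plain tuple, so per-device field access by name still works.
    for i, chunk in enumerate(scatter_gather.scatter(inp, target_gpus)):
        assert isinstance(chunk, Pair)
        assert chunk.a.device.index == i and chunk.b.device.index == i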