From b26bfdf07509a2dada5a9054e9da80bd38454015 Mon Sep 17 00:00:00 2001
From: Rohan Varma
Date: Mon, 14 Sep 2020 18:43:20 -0700
Subject: [PATCH] Correctly convert namedtuples in DDP

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44220

Closes https://github.com/pytorch/pytorch/issues/44009

Currently, if a dataloader returns objects created with collections.namedtuple,
they are incorrectly cast to plain tuples. As a result, if the module expects a
namedtuple in its forward pass, feeding it data of these types can raise
runtime errors.

Fix this in `scatter_gather.py` to resolve the issue reported in
https://github.com/pytorch/pytorch/issues/44009

ghstack-source-id: 112050439

Differential Revision: [D23536752](https://our.internmc.facebook.com/intern/diff/D23536752/)
---
 test/test_cuda.py                                             | 31 ++++++++++++++
 torch/nn/parallel/scatter_gather.py                           | 11 +++++
 .../_internal/distributed/distributed_test.py                 | 42 +++++++++++++++++++
 3 files changed, 84 insertions(+)

diff --git a/test/test_cuda.py b/test/test_cuda.py
index d748361ede6a..3c83408b00a1 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1,6 +1,7 @@
 import collections
 import io
 import tempfile
+from typing import NamedTuple
 import unittest
 import sys
 from itertools import repeat, chain, product
@@ -14,6 +15,7 @@
 import torch.cuda
 import torch.cuda.comm as comm
 from torch import multiprocessing as mp
+from torch.nn.parallel import scatter_gather
 from torch._six import inf, nan, container_abcs
 from test_torch import AbstractTestCases
 
@@ -3049,6 +3051,35 @@ def test_matmul_device_mismatch(self):
         with self.assertRaisesRegex(RuntimeError, "expected (it|them) to be on GPU"):
             torch.addmm(s, m1, m2)
 
+    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
+    def test_scatter_namedtuple(self):
+        # Tests the ability to scatter namedtuples and retrieve a list where each
+        # element is of the expected namedtuple type.
+        TestNamedTupleInput_0 = collections.namedtuple("NamedTuple", ("a", "b"))
+        inp = TestNamedTupleInput_0(
+            torch.rand(10, device=0),
+            torch.rand(10, device=0),
+        )
+        num_gpus = torch.cuda.device_count()
+        target_gpus = [torch.device(i) for i in range(num_gpus)]
+        scatter_out = scatter_gather.scatter(inp, target_gpus)
+
+        for x in scatter_out:
+            self.assertTrue(isinstance(x, type(inp)))
+
+        class TestNamedTupleInput_1(NamedTuple):
+            a: torch.tensor
+            b: torch.tensor
+
+        inp = TestNamedTupleInput_1(
+            torch.rand(10, device=0),
+            torch.rand(10, device=0),
+        )
+
+        scatter_out = scatter_gather.scatter(inp, target_gpus)
+        for x in scatter_out:
+            self.assertTrue(isinstance(x, type(inp)))
+
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py
index 1635d40e29e8..a659ff37de8f 100644
--- a/torch/nn/parallel/scatter_gather.py
+++ b/torch/nn/parallel/scatter_gather.py
@@ -1,6 +1,12 @@
 import torch
 from ._functions import Scatter, Gather
 
+def _is_namedtuple(obj):
+    # Check if the type was created from collections.namedtuple or a typing.NamedTuple.
+    return (
+        isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
+    )
+
 
 def scatter(inputs, target_gpus, dim=0):
     r"""
@@ -11,6 +17,11 @@ def scatter(inputs, target_gpus, dim=0):
     def scatter_map(obj):
         if isinstance(obj, torch.Tensor):
             return Scatter.apply(target_gpus, None, dim, obj)
+        if _is_namedtuple(obj):
+            # De-listify the base case output, since scatter(1, [0])
+            # would return ([1]), but for namedtuple(a=1, b=1) we would not
+            # want to reconstruct it as namedtuple(a=[1], b=[1]).
+            return [type(obj)(*map(lambda li: li[0], map(scatter_map, obj)))]
         if isinstance(obj, tuple) and len(obj) > 0:
             return list(zip(*map(scatter_map, obj)))
         if isinstance(obj, list) and len(obj) > 0:
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index a09ff17f4b8f..6ceebd8c872e 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -1,5 +1,6 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 import copy
+from collections import namedtuple
 import fcntl
 import itertools
 import random
@@ -3229,3 +3230,44 @@ def test_broadcast_object_list(self):
         self.assertNotEqual(objects, collectives_object_test_list)
         dist.broadcast_object_list(objects, src=0)
         self.assertEqual(objects, collectives_object_test_list)
+
+    @require_backend({"nccl", "gloo"})
+    @require_n_gpus_for_nccl_backend(int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"])
+    def test_ddp_namedtuple(self):
+        TestNamedTupleInput_0 = namedtuple("NamedTuple", ("a", "b"))
+
+        batch = 5
+        dim = 10
+
+        class TestNamedTupleInput_1(NamedTuple):
+            a: torch.tensor
+            b: torch.tensor
+
+        class NamedTupleModule(torch.nn.Module):
+            def __init__(_self):  # noqa
+                super().__init__()
+                _self.lin = nn.Linear(10, 1)
+
+            def forward(_self, input, expected_type):  # noqa
+                # Without NamedTuple support, this would be of type tuple.
+                self.assertTrue(
+                    isinstance(input, expected_type),
+                    f"Expected type {expected_type} but got {type(input)}",
+                )
+                return _self.lin(torch.mul(input.a, input.b))
+
+        model = torch.nn.parallel.DistributedDataParallel(
+            NamedTupleModule().cuda(self.rank), device_ids=[self.rank],
+        )
+        inp = TestNamedTupleInput_0(
+            torch.rand(batch, dim, device=self.rank),
+            torch.rand(batch, dim, device=self.rank),
+        )
+        # The following would fail if DDP does not propagate namedtuples correctly.
+        model(inp, type(inp))
+
+        inp = TestNamedTupleInput_1(
+            torch.rand(batch, dim, device=self.rank),
+            torch.rand(batch, dim, device=self.rank),
+        )
+        model(inp, type(inp))
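Reviewer note (not part of the patch): below is a minimal, CPU-only sketch of why `scatter_map` needs a dedicated namedtuple branch. The helper `toy_scatter`, the fixed two-way `Tensor.chunk` split, and the `Point` type are stand-ins invented for illustration; in the real code the tensor base case goes through `Scatter.apply` onto the target GPUs.

```python
# Minimal sketch under stated assumptions, not the patch's exact implementation:
# chunking on CPU stands in for Scatter.apply, and a fixed two-way split replaces
# the per-GPU logic.
import collections

import torch

Point = collections.namedtuple("Point", ("a", "b"))  # hypothetical example type


def _is_namedtuple(obj):
    # Same duck-typing check the patch adds: namedtuple classes (from
    # collections.namedtuple or typing.NamedTuple) are tuples with _asdict/_fields.
    return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")


def toy_scatter(obj, num_chunks=2):
    """Return a list of `num_chunks` shards of `obj`, preserving namedtuple types."""
    if isinstance(obj, torch.Tensor):
        # Base case: split the tensor (stand-in for placing chunks on GPUs).
        return list(obj.chunk(num_chunks))
    if _is_namedtuple(obj):
        # Scatter each field, then rebuild one namedtuple per shard via
        # type(obj)(*fields); a namedtuple constructor takes positional fields,
        # so the generic tuple branch below would lose the type.
        per_field = [toy_scatter(field, num_chunks) for field in obj]
        return [type(obj)(*fields) for fields in zip(*per_field)]
    if isinstance(obj, tuple) and len(obj) > 0:
        # Generic tuple branch: without the check above, a namedtuple would fall
        # through to here and come back as a plain tuple (no .a / .b access).
        return list(zip(*(toy_scatter(field, num_chunks) for field in obj)))
    # Non-tensor leaves are simply replicated.
    return [obj for _ in range(num_chunks)]


shards = toy_scatter(Point(torch.rand(10), torch.rand(10)))
assert all(isinstance(s, Point) for s in shards)  # type preserved, fields accessible
assert shards[0].a.shape == (5,)
```

The sketch zips shards across fields for clarity; the hunk above instead de-listifies with `li[0]`, as described in its comment, so the two are illustrative rather than line-for-line equivalent.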