From 40f397ea6c48449aecbf638081cc33178e5669d2 Mon Sep 17 00:00:00 2001
From: Rohan Varma
Date: Fri, 4 Sep 2020 13:16:19 -0700
Subject: [PATCH 1/2] Correctly convert namedtuples in DDP

Closes https://github.com/pytorch/pytorch/issues/44009

Currently, if a dataloader returns objects created with a
collections.namedtuple, they will incorrectly be cast to plain tuples. As a
result, if we have data of these types, there can be runtime errors during
the forward pass if the module is expecting a named tuple.

Fix this in `scatter_gather.py` to resolve the issue reported in
https://github.com/pytorch/pytorch/issues/44009

Differential Revision: [D23536752](https://our.internmc.facebook.com/intern/diff/D23536752/)

[ghstack-poisoned]
---
 test/distributed/test_distributed.py | 42 ++++++++++++++++++++++++++++
 torch/nn/parallel/scatter_gather.py  | 11 ++++++++
 2 files changed, 53 insertions(+)

diff --git a/test/distributed/test_distributed.py b/test/distributed/test_distributed.py
index 2996c53cd0a8..250194b992c0 100644
--- a/test/distributed/test_distributed.py
+++ b/test/distributed/test_distributed.py
@@ -1,4 +1,5 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
+from collections import namedtuple
 import copy
 import errno
 import fcntl
@@ -3130,6 +3131,47 @@ def test_broadcast_object_list(self):
         dist.broadcast_object_list(objects, src=0)
         self.assertEqual(objects, collectives_object_test_list)
 
+    @require_backend({"nccl", "gloo"})
+    @require_n_gpus_for_nccl_backend(int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"])
+    def test_ddp_namedtuple(self):
+        TestNamedTupleInput_0 = namedtuple("NamedTuple", ("a", "b"))
+
+        batch = 5
+        dim = 10
+
+        class TestNamedTupleInput_1(NamedTuple):
+            a: torch.tensor
+            b: torch.tensor
+
+        class NamedTupleModule(torch.nn.Module):
+            def __init__(_self):  # noqa
+                super().__init__()
+                _self.lin = nn.Linear(10, 1)
+
+            def forward(_self, input, expected_type):  # noqa
+                # Without NamedTuple support, this would be of type tuple.
+                self.assertTrue(
+                    isinstance(input, expected_type),
+                    f"Expected type {expected_type} but got {type(input)}",
+                )
+                return _self.lin(torch.mul(input.a, input.b))
+
+        model = torch.nn.parallel.DistributedDataParallel(
+            NamedTupleModule().cuda(self.rank), device_ids=[self.rank],
+        )
+        inp = TestNamedTupleInput_0(
+            torch.rand(batch, dim, device=self.rank),
+            torch.rand(batch, dim, device=self.rank),
+        )
+        # The following would fail if DDP does not propagate NamedTuples correctly.
+        model(inp, type(inp))
+
+        inp = TestNamedTupleInput_1(
+            torch.rand(batch, dim, device=self.rank),
+            torch.rand(batch, dim, device=self.rank),
+        )
+        model(inp, type(inp))
+
 
 if BACKEND == "gloo" or BACKEND == "nccl":
     WORLD_SIZE = os.environ["WORLD_SIZE"]

diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py
index 1635d40e29e8..e8cd0c73c699 100644
--- a/torch/nn/parallel/scatter_gather.py
+++ b/torch/nn/parallel/scatter_gather.py
@@ -1,6 +1,12 @@
 import torch
 from ._functions import Scatter, Gather
 
+def _is_namedtuple(obj):
+    # Check if type was created from collections.namedtuple.
+    return (
+        isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
+    )
+
 
 def scatter(inputs, target_gpus, dim=0):
     r"""
@@ -11,6 +17,11 @@ def scatter(inputs, target_gpus, dim=0):
     def scatter_map(obj):
         if isinstance(obj, torch.Tensor):
             return Scatter.apply(target_gpus, None, dim, obj)
+        if _is_namedtuple(obj):
+            # De-listify the base case output, since scatter(1, [0])
+            # would return ([1]), but for namedtuple(a=1, b=1) we would not
+            # want to reconstruct it as namedtuple(a=[1], b=[1]).
+            return [type(obj)(*map(lambda li: li[0], map(scatter_map, obj)))]
         if isinstance(obj, tuple) and len(obj) > 0:
             return list(zip(*map(scatter_map, obj)))
         if isinstance(obj, list) and len(obj) > 0:

From 3b17cfb4f8c66badaccc1cd2f98d050584f7d9c6 Mon Sep 17 00:00:00 2001
From: Rohan Varma
Date: Fri, 4 Sep 2020 13:23:01 -0700
Subject: [PATCH 2/2] Update on "Correctly convert namedtuples in DDP"

Closes https://github.com/pytorch/pytorch/issues/44009

Currently, if a dataloader returns objects created with a
`collections.namedtuple` or `typing.NamedTuple`, they will incorrectly be
cast to plain tuples. As a result, if we have data of these types, there can
be runtime errors during the forward pass if the module is expecting a
named tuple.

Fix this in `scatter_gather.py` to resolve the issue reported in
https://github.com/pytorch/pytorch/issues/44009

Differential Revision: [D23536752](https://our.internmc.facebook.com/intern/diff/D23536752/)

[ghstack-poisoned]
---
 torch/nn/parallel/scatter_gather.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py
index e8cd0c73c699..a659ff37de8f 100644
--- a/torch/nn/parallel/scatter_gather.py
+++ b/torch/nn/parallel/scatter_gather.py
@@ -2,7 +2,7 @@
 from ._functions import Scatter, Gather
 
 def _is_namedtuple(obj):
-    # Check if type was created from collections.namedtuple.
+    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
     return (
         isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
     )
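Note (not part of the patches above): as a quick illustration of the duck-typing check this change relies on, here is a minimal, self-contained sketch. The `Point` and `TypedPoint` names are made up for the example; the `_is_namedtuple` body mirrors the one added in `scatter_gather.py`. It shows why both `collections.namedtuple` and `typing.NamedTuple` instances pass the check while plain tuples do not, and how reconstructing with `type(obj)(*...)` preserves the named-tuple type the way the new `scatter_map` branch does.

```python
# Minimal sketch (hypothetical example types, no torch required).
from collections import namedtuple
from typing import NamedTuple


def _is_namedtuple(obj):
    # Same check as the patch: named tuples are tuple subclasses that
    # expose _asdict and _fields; plain tuples expose neither.
    return (
        isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
    )


Point = namedtuple("Point", ("a", "b"))  # collections.namedtuple flavor


class TypedPoint(NamedTuple):  # typing.NamedTuple flavor
    a: int
    b: int


assert _is_namedtuple(Point(1, 2))
assert _is_namedtuple(TypedPoint(1, 2))
assert not _is_namedtuple((1, 2))  # plain tuple is rejected

# Rebuilding with type(obj)(*fields), as the new scatter_map branch does,
# keeps the named-tuple type instead of collapsing it to a plain tuple.
p = Point(1, 2)
rebuilt = type(p)(*[field for field in p])
assert isinstance(rebuilt, Point) and rebuilt.a == 1 and rebuilt.b == 2
```

Checking for `_asdict` and `_fields` rather than a concrete base class is what lets one code path cover both named-tuple flavors, since both generate ordinary tuple subclasses carrying those attributes.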