Correctly convert namedtuples in DDP #44220

Closed · wants to merge 5 commits
Changes from 2 commits
31 changes: 31 additions & 0 deletions test/test_cuda.py
@@ -1,6 +1,7 @@
import collections
import io
import tempfile
from typing import NamedTuple
import unittest
import sys
from itertools import repeat, chain, product
@@ -14,6 +15,7 @@
import torch.cuda
import torch.cuda.comm as comm
from torch import multiprocessing as mp
from torch.nn.parallel import scatter_gather
from torch._six import inf, nan, container_abcs

from test_torch import AbstractTestCases
@@ -3049,6 +3051,35 @@ def test_matmul_device_mismatch(self):
with self.assertRaisesRegex(RuntimeError, "expected (it|them) to be on GPU"):
torch.addmm(s, m1, m2)

@unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
def test_scatter_namedtuple(self):
# tests ability to scatter namedtuples and retrieve a list where each
# element is of the expected namedtuple type.
TestNamedTupleInput_0 = collections.namedtuple("NamedTuple", ("a", "b"))
inp = TestNamedTupleInput_0(
torch.rand(10, device=0),
torch.rand(10, device=0),
)
num_gpus = torch.cuda.device_count()
target_gpus = [torch.device(i) for i in range(num_gpus)]
scatter_out = scatter_gather.scatter(inp, target_gpus)

for x in scatter_out:
self.assertTrue(isinstance(x, type(inp)))

class TestNamedTupleInput_1(NamedTuple):
a: torch.Tensor
b: torch.Tensor

inp = TestNamedTupleInput_1(
torch.rand(10, device=0),
torch.rand(10, device=0)
)

scatter_out = scatter_gather.scatter(inp, target_gpus)
for x in scatter_out:
self.assertTrue(isinstance(x, type(inp)))
Contributor:
Can we also assert that the _fields attribute is preserved correctly, along with the appropriate values?

Member (Author):
Thanks, adding the tests for the values actually uncovered a bug in the implementation that I have since fixed.
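
As an illustration of what the requested assertions could look like, here is a hypothetical extension of test_scatter_namedtuple (not part of the diff shown). It assumes scatter splits each tensor field along dim 0, so concatenating the pieces should recover the original values:

for x in scatter_out:
    # Field names must survive the round trip through scatter.
    self.assertEqual(x._fields, inp._fields)

# Reassembling the scattered pieces of each field on a single device
# should reproduce the original tensors.
self.assertEqual(torch.cat([x.a.to(0) for x in scatter_out]), inp.a)
self.assertEqual(torch.cat([x.b.to(0) for x in scatter_out]), inp.b)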



if __name__ == '__main__':
run_tests()
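
Both tests above rely on scatter being able to distinguish a namedtuple from a plain tuple. A minimal duck-typing check, shown here as a sketch rather than as the PR's actual helper, covers both collections.namedtuple and typing.NamedTuple instances:

def _is_namedtuple(obj):
    # Instances of collections.namedtuple and typing.NamedTuple types are
    # tuple subclasses that expose _asdict and _fields; plain tuples lack both.
    return (
        isinstance(obj, tuple)
        and hasattr(obj, "_asdict")
        and hasattr(obj, "_fields")
    )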
176 changes: 45 additions & 131 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -1,6 +1,6 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from collections import namedtuple
import copy
import fcntl
import itertools
import random
@@ -2910,136 +2910,7 @@ def test_ddp_join_model_equivalence(self):
for (_, local_tensor), (_, dist_tensor) in zip(
local_model.state_dict().items(), net.module.state_dict().items()
):
with net.join():
pass
# We need a barrier since otherwise non-participating processes exit too early
# and cause a timeout.
self._barrier(timeout=60)

@require_backend({"nccl", "gloo"})
@require_n_gpus_for_nccl_backend(int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"])
def test_broadcast_object_list(self):
src_rank = 0
objects = collectives_object_test_list if self.rank == src_rank else [None for _ in collectives_object_test_list]

# Single object test
single_obj_list = [objects[0]]
if self.rank != src_rank:
self.assertNotEqual(single_obj_list[0], collectives_object_test_list[0])
dist.broadcast_object_list(single_obj_list, src=0)
self.assertEqual(single_obj_list[0], collectives_object_test_list[0])

# Multiple input objects test
if self.rank != src_rank:
self.assertNotEqual(objects, collectives_object_test_list)
dist.broadcast_object_list(objects, src=0)
self.assertEqual(objects, collectives_object_test_list)

@require_backend({"nccl", "gloo"})
@require_n_gpus_for_nccl_backend(int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"])
def test_ddp_namedtuple(self):
TestNamedTupleInput_0 = namedtuple("NamedTuple", ("a", "b"))

batch = 5
dim = 10

class TestNamedTupleInput_1(NamedTuple):
a: torch.Tensor
b: torch.Tensor

class NamedTupleModule(torch.nn.Module):
def __init__(_self): # noqa
super().__init__()
_self.lin = nn.Linear(10, 1)

def forward(_self, input, expected_type): # noqa
# Without NamedTuple support, this would be of type tuple.
self.assertTrue(
isinstance(input, expected_type),
f"Expected type {expected_type} but got {type(input)}",
)
return _self.lin(torch.mul(input.a, input.b))

model = torch.nn.parallel.DistributedDataParallel(
NamedTupleModule().cuda(self.rank), device_ids=[self.rank],
)
inp = TestNamedTupleInput_0(
torch.rand(batch, dim, device=self.rank),
torch.rand(batch, dim, device=self.rank),
)
# The following would fail if DDP does not propagate NamedTuples correctly.
model(inp, type(inp))

inp = TestNamedTupleInput_1(
torch.rand(batch, dim, device=self.rank),
torch.rand(batch, dim, device=self.rank),
)
model(inp, type(inp))


if BACKEND == "gloo" or BACKEND == "nccl":
WORLD_SIZE = os.environ["WORLD_SIZE"]

class TestDistBackend(MultiProcessTestCase, _DistTestBase):

# Needed since MultiProcessTestCase assumes a world_size of 4, but we
# run these tests under other various world_sizes.
@property
def world_size(self):
return os.environ["WORLD_SIZE"]

@classmethod
def setUpClass(cls):
os.environ["MASTER_ADDR"] = str(MASTER_ADDR)
os.environ["MASTER_PORT"] = str(MASTER_PORT)
os.environ["WORLD_SIZE"] = str(WORLD_SIZE)
super().setUpClass()

def setUp(self):
super().setUp()
global INIT_METHOD
# initialize Barrier.
Barrier.init()
# We rely on tearDown for deleting the temporary file
# TODO: this temporary file should be deduped with the file_name
# in MultiProcessTestCase as part of supporting spawn mode for these tests.
# https://github.com/pytorch/pytorch/issues/36663
self.temporary_file = None
if INIT_METHOD.startswith("file://"):
self.temporary_file = tempfile.NamedTemporaryFile(delete=False)
INIT_METHOD = "file://{}".format(self.temporary_file.name)

# TODO: enable spawn mode https://github.com/pytorch/pytorch/issues/36663
self._fork_processes()

def tearDown(self):
super(MultiProcessTestCase, self).tearDown()
super(TestDistBackend, self).tearDown()

# Clean up temporary file if we used one.
if self.temporary_file:
try:
os.unlink(self.temporary_file.name)
except OSError as err:
# ENOENT is OK because the test is supposed to clean it up.
if err.errno != errno.ENOENT:
raise

@classmethod
def _run(cls, rank, test_name, file_name):
self = cls(test_name)
self.rank = rank
self.file_name = file_name
try:
dist.init_process_group(
init_method=INIT_METHOD,
backend=BACKEND,
world_size=int(WORLD_SIZE),
rank=self.rank
)
except RuntimeError as e:
if "recompile" in e.args[0]:
sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)
self.assertEqual(local_tensor, dist_tensor)

def _run_uneven_inputs_test(
self, test_case, iteration_mapping, find_unused_params,
@@ -3359,3 +3230,46 @@ def test_broadcast_object_list(self):
self.assertNotEqual(objects, collectives_object_test_list)
dist.broadcast_object_list(objects, src=0)
self.assertEqual(objects, collectives_object_test_list)

@require_backend({"gloo", "nccl"})
@require_backends_available({"gloo", "nccl"})
@skip_if_lt_x_gpu(2)
@skip_if_rocm
def test_ddp_namedtuple(self):
TestNamedTupleInput_0 = namedtuple("NamedTuple", ("a", "b"))

batch = 5
dim = 10

class TestNamedTupleInput_1(NamedTuple):
a: torch.Tensor
b: torch.Tensor

class NamedTupleModule(torch.nn.Module):
# NOTE: _self is used instead of self so that the enclosing TestCase's
# self remains accessible for the assertion inside forward().
def __init__(_self): # noqa
super().__init__()
_self.lin = nn.Linear(10, 1)

def forward(_self, input, expected_type): # noqa
# Without NamedTuple support, this would be of type tuple.
self.assertTrue(
isinstance(input, expected_type),
f"Expected type {expected_type} but got {type(input)}",
)
return _self.lin(torch.mul(input.a, input.b))

model = torch.nn.parallel.DistributedDataParallel(
NamedTupleModule().cuda(self.rank), device_ids=[self.rank],
)
inp = TestNamedTupleInput_0(
torch.rand(batch, dim, device=self.rank),
torch.rand(batch, dim, device=self.rank),
)
# The following would fail if DDP does not propagate NamedTuples correctly.
model(inp, type(inp))

inp = TestNamedTupleInput_1(
torch.rand(batch, dim, device=self.rank),
torch.rand(batch, dim, device=self.rank),
)
model(inp, type(inp))
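
For context on what test_ddp_namedtuple exercises: once an input is identified as a namedtuple, scatter can split each field and rebuild one instance of the original type per device, so forward() sees the same type the caller passed in. The following self-contained sketch illustrates the re-wrapping idea with a simple chunk-based split; it stands in for, and is not, the PR's actual scatter machinery:

import collections

import torch


def _scatter_namedtuple(obj, num_chunks):
    # Split each tensor field into num_chunks pieces along dim 0, then
    # rebuild one instance of the original namedtuple type per chunk so
    # every piece keeps its type and field names.
    scattered_fields = [field.chunk(num_chunks) for field in obj]
    return [type(obj)(*fields) for fields in zip(*scattered_fields)]


# Each chunk is a Point, not a plain tuple, and the field values line up.
Point = collections.namedtuple("Point", ("a", "b"))
p = Point(torch.arange(4.0), torch.arange(4.0))
chunks = _scatter_namedtuple(p, 2)
assert all(isinstance(c, Point) for c in chunks)
assert torch.equal(torch.cat([c.a for c in chunks]), p.a)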