diff --git a/test/common.py b/test/common.py
index fe155cd82ec80..8beca56a09083 100644
--- a/test/common.py
+++ b/test/common.py
@@ -37,6 +37,8 @@ def run_tests():
     unittest.main(argv=UNITTEST_ARGS)
 
 
+PY3 = sys.version_info > (3, 0)
+
 IS_WINDOWS = sys.platform == "win32"
 
 TEST_NUMPY = True
diff --git a/test/test_nn.py b/test/test_nn.py
index 7ac837eaa0fca..227f8a9b375f0 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -30,7 +30,7 @@
     module_tests, criterion_tests, TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
     TEST_CUDNN_VERSION, loss_reference_fns, get_size_average, get_weight
 from common import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, \
-    TEST_SCIPY, download_file, IS_WINDOWS
+    TEST_SCIPY, download_file, IS_WINDOWS, PY3
 
 if TEST_SCIPY:
     from scipy import stats
@@ -1710,6 +1710,29 @@ def test_data_parallel_small_back(self):
         out = dp.data_parallel(l, i, (0, 1))
         self.assertEqual(out, l(i))
 
+    @unittest.skipIf(not TEST_MULTIGPU or not PY3, "multi-GPU not supported")
+    def test_data_parallel_model_no_refcycles(self):
+        # Python 2.7 will create reference cycles with the following
+        # Module on multiple GPUs, but Python 3 shouldn't unless
+        # there are refcycles on the PyTorch side (or the defined module)
+        import gc
+
+        class Model(nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.linear = nn.Linear(1, 1)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        gc.collect()
+        model = nn.DataParallel(Model().cuda())
+        data = Variable(torch.randn(1).cuda())
+        model(data)
+
+        refcycles = gc.collect()
+        self.assertEqual(refcycles, 0)
+
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_data_parallel_no_grad(self):
         test = self
diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py
index 1fb07c20c95b6..a3e73568a752a 100644
--- a/torch/nn/parallel/scatter_gather.py
+++ b/torch/nn/parallel/scatter_gather.py
@@ -22,7 +22,15 @@ def scatter_map(obj):
             return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
         return [obj for targets in target_gpus]
 
-    return scatter_map(inputs)
+    # After scatter_map is called, a scatter_map cell will exist. This cell
+    # has a reference to the actual function scatter_map, which has references
+    # to a closure that has a reference to the scatter_map cell (because the
+    # fn is recursive). To avoid this reference cycle, we set the function to
+    # None, clearing the cell
+    try:
+        return scatter_map(inputs)
+    finally:
+        scatter_map = None
 
 
 def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
@@ -50,4 +58,10 @@ def gather_map(outputs):
         if out is None:
             return None
         return type(out)(map(gather_map, zip(*outputs)))
-    return gather_map(outputs)
+
+    # Recursive function calls like this create reference cycles.
+    # Setting the function to None clears the refcycle.
+    try:
+        return gather_map(outputs)
+    finally:
+        gather_map = None
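
For reference, a minimal standalone sketch of the mechanism the comments in scatter_gather.py describe: a nested function that calls itself recursively keeps itself alive through its own closure cell (function -> closure cell -> function), so only the cyclic garbage collector can reclaim it, while rebinding the name in a finally block breaks the cycle and lets plain reference counting clean up. The helper names below (make_cycle, make_no_cycle, walk) are illustrative only and are not part of the patch above.

    import gc


    def make_cycle():
        # `walk` refers to itself through the enclosing scope, so the function
        # object holds a closure cell that points back at the function: a
        # reference cycle that survives after this frame returns.
        def walk(n):
            return 0 if n == 0 else 1 + walk(n - 1)
        return walk(3)


    def make_no_cycle():
        def walk(n):
            return 0 if n == 0 else 1 + walk(n - 1)
        try:
            return walk(3)
        finally:
            # Rebinding the cell drops the cell -> function edge, so plain
            # reference counting frees everything when the frame exits.
            walk = None


    if __name__ == "__main__":
        gc.disable()           # keep the collector from running in the background
        gc.collect()           # start from a clean slate
        make_cycle()
        print(gc.collect())    # > 0: the dead `walk` had to be cycle-collected
        make_no_cycle()
        print(gc.collect())    # expected 0: no garbage cycles left behind
        gc.enable()

The test_data_parallel_model_no_refcycles test added above applies the same check to a DataParallel forward pass: collect first, run the model, then assert that a second gc.collect() finds nothing to reclaim.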