From 075ff2325edf9e0388d3240f0fa8aa913decb220 Mon Sep 17 00:00:00 2001
From: Brian Hirsh
Date: Mon, 5 Oct 2020 17:16:04 -0700
Subject: [PATCH] adding complex support for distributed functions and . fix
 #45760

ghstack-source-id: ff291d0f75451c20b1cdfd7d93738fb252a60fc9
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45879
---
 torch/distributed/distributed_c10d.py         |  15 ++
 .../_internal/distributed/distributed_test.py | 184 +++++++++++++-----
 2 files changed, 148 insertions(+), 51 deletions(-)

diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index ae4338cd28fc..e34fadcdca4d 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -892,6 +892,8 @@ def all_reduce_multigpu(tensor_list,
     if _rank_not_in_group(group):
         return
 
+    tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
+
     opts = AllreduceOptions()
     opts.reduceOp = op
     if group == GroupMember.WORLD:
@@ -934,6 +936,8 @@ def all_reduce(tensor,
     if _rank_not_in_group(group):
         return
 
+    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)
+
     opts = AllreduceOptions()
     opts.reduceOp = op
     if group == GroupMember.WORLD:
@@ -985,6 +989,8 @@ def all_reduce_coalesced(tensors,
     if _rank_not_in_group(group):
         return
 
+    tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors]
+
     opts = AllreduceCoalescedOptions()
     opts.reduceOp = op
     if group == GroupMember.WORLD:
@@ -1149,6 +1155,9 @@ def all_gather_multigpu(output_tensor_lists,
     if _rank_not_in_group(group):
         return
 
+    output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists]
+    input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list]
+
     if group == GroupMember.WORLD:
         _check_default_pg()
         work = _default_pg.allgather(output_tensor_lists, input_tensor_list)
@@ -1414,6 +1423,9 @@ def all_gather(tensor_list,
     if _rank_not_in_group(group):
         return
 
+    tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
+    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)
+
     if group == GroupMember.WORLD:
         _check_default_pg()
         work = _default_pg.allgather([tensor_list], [tensor])
@@ -1480,6 +1492,9 @@ def all_gather_coalesced(output_tensor_lists,
     for output_tensor_list in output_tensor_lists:
         _check_tensor_list(output_tensor_list, "output_tensor_lists")
 
+    output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists]
+    input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list]
+
     if group == GroupMember.WORLD:
         _check_default_pg()
         work = _default_pg.allgather_coalesced(
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 01cddee92365..57e5b8367556 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -216,10 +216,10 @@ def _build_tensor(size, value=None, dtype=torch.float):
     return torch.empty(size, size, size, dtype=dtype).fill_(value)
 
 
-def _build_multidim_tensor(dim, dim_size, value=None):
+def _build_multidim_tensor(dim, dim_size, value=None, dtype=torch.float):
     if value is None:
         value = size
-    return torch.FloatTensor(size=[dim_size for _ in range(dim)]).fill_(value)
+    return torch.empty(size=[dim_size for _ in range(dim)], dtype=dtype).fill_(value)
 
 
 class Barrier(object):
@@ -1010,20 +1010,17 @@ def _test_all_reduce_helper(
         expected_value,
         cuda=False,
         rank_to_GPU=None,
+        dtype=torch.float,
     ):
         for src in group:
-            if rank == src:
-                tensor = _build_tensor(src + 1).fill_(master_value)
-                if cuda:
-                    tensor = tensor.cuda(rank_to_GPU[rank][0])
-                dist.all_reduce(tensor, op, group_id)
-                self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
-            else:
-                tensor = _build_tensor(src + 1).fill_(worker_value)
-                if cuda:
-                    tensor = tensor.cuda(rank_to_GPU[rank][0])
-                dist.all_reduce(tensor, op, group_id)
-                self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
+            curr_value = master_value if rank == src else worker_value
+
+            tensor = _build_tensor(src + 1, dtype=dtype).fill_(curr_value)
+            if cuda:
+                tensor = tensor.cuda(rank_to_GPU[rank][0])
+            dist.all_reduce(tensor, op, group_id)
+            expected_tensor = _build_tensor(src + 1, expected_value, dtype=dtype)
+            self.assertEqual(tensor, expected_tensor)
 
         self._barrier()
 
@@ -1060,6 +1057,41 @@ def test_all_reduce_sum_cuda(self):
             rank_to_GPU,
         )
 
+    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
+    def test_all_reduce_sum_complex(self):
+        group, group_id, rank = self._init_global_test()
+        self._test_all_reduce_helper(
+            group,
+            group_id,
+            rank,
+            dist.ReduceOp.SUM,
+            complex(2, 3),
+            complex(10, 11),
+            complex(2, 3) + (complex(10, 11) * (len(group) - 1)),
+            dtype=torch.cfloat,
+        )
+
+    @unittest.skipIf(
+        BACKEND != "gloo",
+        "Only Gloo backend will have CUDA allReduce tested",
+    )
+    @skip_if_no_gpu
+    def test_all_reduce_sum_cuda_complex(self):
+        group, group_id, rank = self._init_global_test()
+        rank_to_GPU = self._init_multigpu_helper()
+        self._test_all_reduce_helper(
+            group,
+            group_id,
+            rank,
+            dist.ReduceOp.SUM,
+            complex(2, 3),
+            complex(10, 11),
+            complex(2, 3) + (complex(10, 11) * (len(group) - 1)),
+            True,
+            rank_to_GPU,
+            dtype=torch.cfloat,
+        )
+
     @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
     def test_all_reduce_product(self):
         group, group_id, rank = self._init_global_test()
@@ -1198,9 +1230,10 @@ def test_sparse_all_reduce_sum_cuda(self):
     @staticmethod
     def _all_reduce_coalesced_sum_test_cases(group_size):
         return (
-            [2, 3],
-            [10, 11],
-            [2 + 10 * (group_size - 1), 3 + 11 * (group_size - 1)]
+            [2, 3, complex(2, 3)],
+            [10, 11, complex(10, 11)],
+            [2 + 10 * (group_size - 1), 3 + 11 * (group_size - 1), complex(2, 3) + complex(10, 11) * (group_size - 1)],
+            [torch.float, torch.float, torch.cfloat],
         )
 
     @staticmethod
@@ -1208,7 +1241,8 @@ def _all_reduce_coalesced_product_test_cases(group_size):
         return (
             [1, 2],
             [3, 4],
-            [1 * 3 ** (group_size - 1), 2 * 4 ** (group_size - 1)]
+            [1 * 3 ** (group_size - 1), 2 * 4 ** (group_size - 1)],
+            [torch.float, torch.float],
         )
 
     @staticmethod
@@ -1216,7 +1250,8 @@ def _all_reduce_coalesced_min_test_cases(group_size):
         return (
             [1, 4],
             [2, 3],
-            [1, 3]
+            [1, 3],
+            [torch.float, torch.float],
         )
 
     @staticmethod
@@ -1224,7 +1259,8 @@ def _all_reduce_coalesced_max_test_cases(group_size):
         return (
             [1, 4],
             [2, 3],
-            [2, 4]
+            [2, 4],
+            [torch.float, torch.float],
         )
 
     def _test_all_reduce_coalesced_helper(
@@ -1243,22 +1279,24 @@ def _test_all_reduce_coalesced_helper(
             dist.ReduceOp.MAX: self._all_reduce_coalesced_max_test_cases
         }[op]
 
-        master_values, worker_values, expected_values = test_case_func(len(group))
+        master_values, worker_values, expected_values, dtypes = test_case_func(len(group))
 
         for src in group:
+            curr_values = master_values if rank == src else worker_values
             tensors = [
-                _build_tensor(src + 1, val)
-                for val in (master_values if rank == src else worker_values)
+                _build_tensor(src + 1, val, dtype=dtype)
+                for dtype, val in zip(dtypes, curr_values)
             ]
             if cuda:
                 tensors = list(map(tensors, lambda t: t.cuda(rank_to_GPU[rank][0])))
             dist.all_reduce_coalesced(tensors, op, group_id)
+            expected_tensors = [
+                _build_tensor(src + 1, expected_value, dtype=dtype)
+                for dtype, expected_value in zip(dtypes, expected_values)
+            ]
             self.assertEqual(
                 tensors,
-                [
-                    _build_tensor(src + 1, expected_value)
-                    for expected_value in expected_values
-                ]
+                expected_tensors
             )
 
         self._barrier()
@@ -1519,17 +1557,18 @@ def test_gather_full_group(self):
 
     # ALL GATHER
     def _test_all_gather_helper(
-        self, group, group_id, rank, cuda=False, rank_to_GPU=None
+        self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
     ):
+
         for dest in group:
-            tensor = _build_tensor(dest + 1, rank)
-            tensors = [_build_tensor(dest + 1, -1) for i in group]
+            tensor = _build_tensor(dest + 1, rank, dtype=dtype)
+            tensors = [_build_tensor(dest + 1, -1, dtype=dtype) for i in group]
             if cuda:
                 tensor = tensor.cuda(rank_to_GPU[rank][0])
                 tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
             dist.all_gather(tensors, tensor, group_id)
 
-            expected_tensors = [_build_tensor(dest + 1, i) for i in group]
+            expected_tensors = [_build_tensor(dest + 1, i, dtype=dtype) for i in group]
             for t1, t2 in zip(tensors, expected_tensors):
                 self.assertEqual(t1, t2)
 
@@ -1548,6 +1587,19 @@ def test_all_gather_cuda(self):
         rank_to_GPU = self._init_multigpu_helper()
         self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU)
 
+    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
+    def test_all_gather_complex(self):
+        group, group_id, rank = self._init_global_test()
+        self._test_all_gather_helper(group, group_id, rank, dtype=torch.cfloat)
+
+    @unittest.skipIf(BACKEND != "nccl", "Only Nccl supports CUDA all gather")
+    @unittest.skipIf(BACKEND == "nccl", "CUDA all gather skipped for NCCL")
+    @skip_if_no_gpu
+    def test_all_gather_cuda_complex(self):
+        group, group_id, rank = self._init_global_test()
+        rank_to_GPU = self._init_multigpu_helper()
+        self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat)
+
     @skip_if_small_worldsize
     @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
     def test_all_gather_group(self):
@@ -1576,7 +1628,7 @@ def _run_all_gather_coalesced_and_verify(
         return True
 
     def _test_all_gather_coalesced_helper(
-            self, group, group_id, rank
+            self, group, group_id, rank, dtype=torch.float
     ):
         # TODO: Instead we should probably go through _rank_not_in_group
         # mechanism to disable sending tensors
@@ -1586,13 +1638,16 @@ def _test_all_gather_coalesced_helper(
             # [1], [2x2], [3x3x3] ... to be sent in one batch
             input_tensors = [
                 _build_multidim_tensor(
-                    tensor_id, tensor_id, rank + tensor_id) for tensor_id in range(
+                    tensor_id,
+                    tensor_id,
+                    rank + tensor_id,
+                    dtype=dtype) for tensor_id in range(
                     1, test_case_id)
             ]
             output_tensor_lists = [
                 [
                     _build_multidim_tensor(
-                        tensor_id, tensor_id, -1) for tensor_id in range(
+                        tensor_id, tensor_id, -1, dtype=dtype) for tensor_id in range(
                         1, test_case_id)
                 ] for _ in group
             ]
@@ -1601,7 +1656,8 @@ def _test_all_gather_coalesced_helper(
                     _build_multidim_tensor(
                         tensor_id,
                         tensor_id,
-                        rank_iter + tensor_id) for tensor_id in range(
+                        rank_iter + tensor_id,
+                        dtype=dtype) for tensor_id in range(
                         1, test_case_id)
                 ] for rank_iter in group
             ]
@@ -1618,6 +1674,12 @@ def test_all_gather_coalesced_simple(self):
         group, group_id, rank = self._init_global_test()
         self._test_all_gather_coalesced_helper(group, group_id, rank)
 
+    @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL")
+    @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI")
+    def test_all_gather_coalesced_complex(self):
+        group, group_id, rank = self._init_global_test()
+        self._test_all_gather_coalesced_helper(group, group_id, rank, dtype=torch.cfloat)
+
     @skip_if_small_worldsize
     @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL")
     @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI")
@@ -1991,21 +2053,16 @@ def _test_all_reduce_multigpu_helper(
         master_value,
         worker_value,
         expected_value,
+        dtype=torch.float,
     ):
         for src in group:
-            if rank == src:
-                tensors = [
-                    _build_tensor(src + 1, master_value).cuda(device=i)
-                    for i in rank_to_GPU[rank]
-                ]
-            else:
-                tensors = [
-                    _build_tensor(src + 1, worker_value).cuda(device=i)
-                    for i in rank_to_GPU[rank]
-                ]
-
+            curr_value = master_value if rank == src else worker_value
+            tensors = [
+                _build_tensor(src + 1, curr_value, dtype=dtype).cuda(device=i)
+                for i in rank_to_GPU[rank]
+            ]
             dist.all_reduce_multigpu(tensors, op, group_id)
-            expected_tensor = _build_tensor(src + 1, expected_value)
+            expected_tensor = _build_tensor(src + 1, expected_value, dtype=dtype)
             for tensor in tensors:
                 self.assertEqual(tensor, expected_tensor)
 
@@ -2028,6 +2085,24 @@ def test_all_reduce_multigpu(self):
             (2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]),
         )
 
+    @unittest.skipIf(BACKEND == "mpi", "MPI doesn't support broadcast multigpu")
+    @unittest.skipIf(BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL")
+    @skip_if_no_gpu
+    def test_all_reduce_multigpu_complex(self):
+        group, group_id, rank = self._init_global_test()
+        rank_to_GPU = self._init_multigpu_helper()
+        self._test_all_reduce_multigpu_helper(
+            group,
+            group_id,
+            rank,
+            rank_to_GPU,
+            dist.ReduceOp.SUM,
+            complex(2, 3),
+            complex(10, 11),
+            (complex(2, 3) + complex(10, 11) * (len(group) - 1)) * len(rank_to_GPU[0]),
+            dtype=torch.cfloat,
+        )
+
     def _test_reduce_multigpu_helper(
         self,
         group,
@@ -2074,10 +2149,10 @@ def test_reduce_multigpu(self):
             (2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]),
         )
 
-    def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU):
+    def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU, dtype=torch.float):
         for dest in group:
             tensors = [
-                _build_tensor(dest + 1).cuda(device=i) for i in rank_to_GPU[rank]
+                _build_tensor(dest + 1, dtype=dtype).cuda(device=i) for i in rank_to_GPU[rank]
             ]
 
             # construct expected output along with
@@ -2085,10 +2160,10 @@ def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU):
             output_tensors = []
             expected_output = []
             output_per_gpu = (
-                [_build_tensor(dest + 1, -1)] * len(rank_to_GPU[0]) * len(group)
+                [_build_tensor(dest + 1, -1, dtype=dtype)] * len(rank_to_GPU[0]) * len(group)
             )
             expected_per_gpu = (
-                [_build_tensor(dest + 1)] * len(rank_to_GPU[0]) * len(group)
+                [_build_tensor(dest + 1, dtype=dtype)] * len(rank_to_GPU[0]) * len(group)
             )
             for gpu in rank_to_GPU[rank]:
                 output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu])
@@ -2106,6 +2181,13 @@ def test_all_gather_multigpu(self):
         rank_to_GPU = self._init_multigpu_helper()
         self._test_all_gather_multigpu_helper(group, group_id, rank, rank_to_GPU)
 
+    @unittest.skipIf(BACKEND != "nccl", "Only Nccl backend supports allgather multigpu")
+    @skip_if_no_gpu
+    def test_all_gather_multigpu_complex(self):
+        group, group_id, rank = self._init_global_test()
+        rank_to_GPU = self._init_multigpu_helper()
+        self._test_all_gather_multigpu_helper(group, group_id, rank, rank_to_GPU, dtype=torch.cfloat)
+
     def _model_step(self, model):
         for param in model.parameters():
             if param.grad is not None:
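
The patch makes complex support transparent to callers: each collective views a complex input through torch.view_as_real before handing it to the process group, and because that view shares storage with the original tensor, the in-place reduction is reflected in the complex tensor the caller passed in. Below is a minimal usage sketch of the new behavior; the backend choice, init_method address, tensor shape, and world size are illustrative placeholders rather than values taken from this patch.

import torch
import torch.distributed as dist

def run(rank, world_size):
    # Placeholder rendezvous settings; any valid backend/init_method works.
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:23456", rank=rank, world_size=world_size
    )

    # Each rank contributes a complex tensor; with this patch, all_reduce
    # reduces over the real view internally, so the sum is applied
    # elementwise to the real and imaginary parts.
    t = torch.full((2, 2), complex(rank + 1, rank + 1), dtype=torch.cfloat)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)

    total = sum(range(1, world_size + 1))
    assert torch.allclose(t, torch.full((2, 2), complex(total, total), dtype=torch.cfloat))

    dist.destroy_process_group()

Each rank would invoke run(rank, world_size) in its own process, for example via torch.multiprocessing.spawn. Note that only reductions that are well defined on the interleaved real/imaginary representation (such as SUM) make sense for complex inputs; ordering-based ops like MIN and MAX are not meaningful for complex tensors.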