From 22f442bdbdfeb0d8ac992bfd9209da4f06afbad9 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Fri, 9 Oct 2020 07:23:18 -0700 Subject: [PATCH] adding complex support for distributed functions and . fix #45760 updated docs used standard python repl examples in the docs, tested the way that they render in the browser more doc fixes. Add an explicit error check for ReduceOps that do not support complex (Max and Min), + tests for that case make error checking a bit more pythonic ghstack-source-id: 57babd5380cf8eb464b66114a6532011b9aea4ac Pull Request resolved: https://github.com/pytorch/pytorch/pull/45879 --- torch/csrc/distributed/c10d/init.cpp | 2 + torch/distributed/distributed_c10d.py | 96 +++++++++ .../_internal/distributed/distributed_test.py | 196 +++++++++++++----- 3 files changed, 243 insertions(+), 51 deletions(-) diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index d15ea9d23412..e552f0e2a386 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -220,6 +220,8 @@ An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``, Note that ``BAND``, ``BOR``, and ``BXOR`` reductions are not available when using the ``NCCL`` backend. +Additionally, ``MAX``, ``MIN`` and ``PRODUCT`` are not supported for complex tensors. + The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``. They are used in specifying strategies for reduction collectives, e.g., :func:`reduce`, :func:`all_reduce_multigpu`, etc.)") diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index a125d8a1204b..1269d57721fd 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -44,6 +44,17 @@ except ImportError: _GLOO_AVAILABLE = False +# Some reduce ops are not supported by complex numbers. +# We currently provide complex support to the distributed API by viewing +# complex tensors as real (torch.view_as_real), meaning that calling +# these unsupported ops will return garbage values rather than error out. +# (e.g. max(2+3i, 3+2i) = 3+3i) +# We'd like calls to unsupported ops to error out accordingly, +# rather than returning garbage values. +def supports_complex(reduceOp: ReduceOp) -> bool: + denyList = [ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PRODUCT] + return reduceOp not in denyList + class Backend(object): """ @@ -970,6 +981,8 @@ def all_reduce_multigpu(tensor_list, After the call, all ``tensor`` in ``tensor_list`` is going to be bitwise identical in all processes. + Complex tensors are supported. + Only nccl and gloo backend is currently supported tensors should only be GPU tensors @@ -993,6 +1006,8 @@ def all_reduce_multigpu(tensor_list, if _rank_not_in_group(group): return + tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list] + opts = AllreduceOptions() opts.reduceOp = op if group == GroupMember.WORLD: @@ -1017,6 +1032,8 @@ def all_reduce(tensor, After the call ``tensor`` is going to be bitwise identical in all processes. + Complex tensors are supported. + Arguments: tensor (Tensor): Input and output of the collective. The function operates in-place. @@ -1030,11 +1047,39 @@ def all_reduce(tensor, Async work handle, if async_op is set to True. None, if not async_op or if not part of the group + Examples: + >>> # All tensors below are of torch.int64 type. + >>> # We have 2 process groups, 2 ranks. + >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank + >>> tensor + tensor([1, 2]) # Rank 0 + tensor([3, 4]) # Rank 1 + >>> dist.all_reduce(tensor, op=ReduceOp.SUM) + >>> tensor + tensor([4, 6]) # Rank 0 + tensor([4, 6]) # Rank 1 + + >>> # All tensors below are of torch.cfloat type. + >>> # We have 2 process groups, 2 ranks. + >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j) + >>> tensor + tensor([1.+1.j, 2.+2.j]) # Rank 0 + tensor([3.+3.j, 4.+4.j]) # Rank 1 + >>> dist.all_reduce(tensor, op=ReduceOp.SUM) + >>> tensor + tensor([4.+4.j, 6.+6.j]) # Rank 0 + tensor([4.+4.j, 6.+6.j]) # Rank 1 + """ _check_single_tensor(tensor, "tensor") if _rank_not_in_group(group): return + if tensor.is_complex(): + if not supports_complex(op): + raise RuntimeError(f"all_reduce does not support {op} on complex tensors") + tensor = torch.view_as_real(tensor) + opts = AllreduceOptions() opts.reduceOp = op if group == GroupMember.WORLD: @@ -1068,6 +1113,8 @@ def all_reduce_coalesced(tensors, After the call each tensor in tensors is going to bitwise identical in all processes. + Complex tensors are supported. + Arguments: tensors (List[Tensor]): Input and output of the collective. The function operates in-place. @@ -1086,6 +1133,11 @@ def all_reduce_coalesced(tensors, if _rank_not_in_group(group): return + if any([t.is_complex() for t in tensors]) and not supports_complex(op): + raise RuntimeError(f"all_reduce does not support {op} on complex tensors") + + tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors] + opts = AllreduceCoalescedOptions() opts.reduceOp = op if group == GroupMember.WORLD: @@ -1215,6 +1267,8 @@ def all_gather_multigpu(output_tensor_lists, Only nccl backend is currently supported tensors should only be GPU tensors + Complex tensors are supported. + Arguments: output_tensor_lists (List[List[Tensor]]): Output lists. It should contain correctly-sized tensors on each GPU to be used for output @@ -1250,6 +1304,9 @@ def all_gather_multigpu(output_tensor_lists, if _rank_not_in_group(group): return + output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists] + input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list] + if group == GroupMember.WORLD: _check_default_pg() work = _default_pg.allgather(output_tensor_lists, input_tensor_list) @@ -1498,6 +1555,8 @@ def all_gather(tensor_list, """ Gathers tensors from the whole group in a list. + Complex tensors are supported. + Arguments: tensor_list (list[Tensor]): Output list. It should contain correctly-sized tensors to be used for output of the collective. @@ -1509,12 +1568,44 @@ def all_gather(tensor_list, Async work handle, if async_op is set to True. None, if not async_op or if not part of the group + Examples: + >>> # All tensors below are of torch.int64 dtype. + >>> # We have 2 process groups, 2 ranks. + >>> tensor_list = [torch.zero(2, dtype=torch.int64) for _ in range(2)] + >>> tensor_list + [tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1 + >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank + >>> tensor + tensor([1, 2]) # Rank 0 + tensor([3, 4]) # Rank 1 + >>> dist.all_gather(tensor_list, tensor) + >>> tensor_list + [tensor([1, 2]), tensor([3, 4])] # Rank 0 + [tensor([1, 2]), tensor([3, 4])] # Rank 1 + + >>> # All tensors below are of torch.cfloat dtype. + >>> # We have 2 process groups, 2 ranks. + >>> tensor_list = [torch.zero(2, dtype=torch.cfloat) for _ in range(2)] + >>> tensor_list + [tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1 + >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j) + >>> tensor + tensor([1.+1.j, 2.+2.j]) # Rank 0 + tensor([3.+3.j, 4.+4.j]) # Rank 1 + >>> dist.all_gather(tensor_list, tensor) + >>> tensor_list + [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0 + [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1 + """ _check_tensor_list(tensor_list, "tensor_list") _check_single_tensor(tensor, "tensor") if _rank_not_in_group(group): return + tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list] + tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) + if group == GroupMember.WORLD: _check_default_pg() work = _default_pg.allgather([tensor_list], [tensor]) @@ -1533,6 +1624,8 @@ def all_gather_coalesced(output_tensor_lists, """ Gathers input tensors from the whole group in a list in a coalesced manner. + Complex tensors are supported. + Arguments: output_tensor_lists (list[list[Tensor]]): Output list. It should contain correctly-sized tensors to be used for output of the collective. @@ -1581,6 +1674,9 @@ def all_gather_coalesced(output_tensor_lists, for output_tensor_list in output_tensor_lists: _check_tensor_list(output_tensor_list, "output_tensor_lists") + output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists] + input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list] + if group == GroupMember.WORLD: _check_default_pg() work = _default_pg.allgather_coalesced( diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index facae489fbd0..2cbf0c81f24c 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -220,10 +220,10 @@ def _build_tensor(size, value=None, dtype=torch.float, device_id=None): return torch.empty(size, size, size, dtype=dtype).fill_(value).cuda(device_id) -def _build_multidim_tensor(dim, dim_size, value=None): +def _build_multidim_tensor(dim, dim_size, value=None, dtype=torch.float): if value is None: value = size - return torch.FloatTensor(size=[dim_size for _ in range(dim)]).fill_(value) + return torch.empty(size=[dim_size for _ in range(dim)], dtype=dtype).fill_(value) class Barrier(object): @@ -1247,20 +1247,17 @@ def _test_all_reduce_helper( expected_value, cuda=False, rank_to_GPU=None, + dtype=torch.float, ): for src in group: - if rank == src: - tensor = _build_tensor(src + 1).fill_(master_value) - if cuda: - tensor = tensor.cuda(rank_to_GPU[rank][0]) - dist.all_reduce(tensor, op, group_id) - self.assertEqual(tensor, _build_tensor(src + 1, expected_value)) - else: - tensor = _build_tensor(src + 1).fill_(worker_value) - if cuda: - tensor = tensor.cuda(rank_to_GPU[rank][0]) - dist.all_reduce(tensor, op, group_id) - self.assertEqual(tensor, _build_tensor(src + 1, expected_value)) + curr_value = master_value if rank == src else worker_value + + tensor = _build_tensor(src + 1, dtype=dtype).fill_(curr_value) + if cuda: + tensor = tensor.cuda(rank_to_GPU[rank][0]) + dist.all_reduce(tensor, op, group_id) + expected_tensor = _build_tensor(src + 1, expected_value, dtype=dtype) + self.assertEqual(tensor, expected_tensor) self._barrier() @@ -1297,6 +1294,47 @@ def test_all_reduce_sum_cuda(self): rank_to_GPU, ) + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") + def test_all_reduce_sum_complex(self): + group, group_id, rank = self._init_global_test() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + complex(2, 3), + complex(10, 11), + complex(2, 3) + (complex(10, 11) * (len(group) - 1)), + dtype=torch.cfloat, + ) + + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") + def test_all_reduce_max_complex_unsupported(self): + group, group_id, rank = self._init_global_test() + with self.assertRaisesRegex(RuntimeError, "all_reduce does not support"): + dist.all_reduce(_build_tensor(1, dtype=torch.cfloat), dist.ReduceOp.MAX, group_id) + + @unittest.skipIf( + BACKEND != "gloo", + "Only Gloo backend will have CUDA allReduce tested", + ) + @skip_if_no_gpu + def test_all_reduce_sum_cuda_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + complex(2, 3), + complex(10, 11), + complex(2, 3) + (complex(10, 11) * (len(group) - 1)), + True, + rank_to_GPU, + dtype=torch.cfloat, + ) + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") def test_all_reduce_product(self): group, group_id, rank = self._init_global_test() @@ -1435,9 +1473,10 @@ def test_sparse_all_reduce_sum_cuda(self): @staticmethod def _all_reduce_coalesced_sum_test_cases(group_size): return ( - [2, 3], - [10, 11], - [2 + 10 * (group_size - 1), 3 + 11 * (group_size - 1)] + [2, 3, complex(2, 3)], + [10, 11, complex(10, 11)], + [2 + 10 * (group_size - 1), 3 + 11 * (group_size - 1), complex(2, 3) + complex(10, 11) * (group_size - 1)], + [torch.float, torch.float, torch.cfloat], ) @staticmethod @@ -1445,7 +1484,8 @@ def _all_reduce_coalesced_product_test_cases(group_size): return ( [1, 2], [3, 4], - [1 * 3 ** (group_size - 1), 2 * 4 ** (group_size - 1)] + [1 * 3 ** (group_size - 1), 2 * 4 ** (group_size - 1)], + [torch.float, torch.float], ) @staticmethod @@ -1453,7 +1493,8 @@ def _all_reduce_coalesced_min_test_cases(group_size): return ( [1, 4], [2, 3], - [1, 3] + [1, 3], + [torch.float, torch.float], ) @staticmethod @@ -1461,9 +1502,16 @@ def _all_reduce_coalesced_max_test_cases(group_size): return ( [1, 4], [2, 3], - [2, 4] + [2, 4], + [torch.float, torch.float], ) + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") + def test_all_reduce_coalesced_max_complex_unsupported(self): + group, group_id, rank = self._init_global_test() + with self.assertRaisesRegex(RuntimeError, "all_reduce does not support"): + dist.all_reduce_coalesced([_build_tensor(1, dtype=torch.cfloat)], dist.ReduceOp.MAX, group_id) + def _test_all_reduce_coalesced_helper( self, group, @@ -1480,22 +1528,24 @@ def _test_all_reduce_coalesced_helper( dist.ReduceOp.MAX: self._all_reduce_coalesced_max_test_cases }[op] - master_values, worker_values, expected_values = test_case_func(len(group)) + master_values, worker_values, expected_values, dtypes = test_case_func(len(group)) for src in group: + curr_values = master_values if rank == src else worker_values tensors = [ - _build_tensor(src + 1, val) - for val in (master_values if rank == src else worker_values) + _build_tensor(src + 1, val, dtype=dtype) + for dtype, val in zip(dtypes, curr_values) ] if cuda: tensors = list(map(tensors, lambda t: t.cuda(rank_to_GPU[rank][0]))) dist.all_reduce_coalesced(tensors, op, group_id) + expected_tensors = [ + _build_tensor(src + 1, expected_value, dtype=dtype) + for dtype, expected_value in zip(dtypes, expected_values) + ] self.assertEqual( tensors, - [ - _build_tensor(src + 1, expected_value) - for expected_value in expected_values - ] + expected_tensors ) self._barrier() @@ -1756,17 +1806,18 @@ def test_gather_full_group(self): # ALL GATHER def _test_all_gather_helper( - self, group, group_id, rank, cuda=False, rank_to_GPU=None + self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float ): + for dest in group: - tensor = _build_tensor(dest + 1, rank) - tensors = [_build_tensor(dest + 1, -1) for i in group] + tensor = _build_tensor(dest + 1, rank, dtype=dtype) + tensors = [_build_tensor(dest + 1, -1, dtype=dtype) for i in group] if cuda: tensor = tensor.cuda(rank_to_GPU[rank][0]) tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors] dist.all_gather(tensors, tensor, group_id) - expected_tensors = [_build_tensor(dest + 1, i) for i in group] + expected_tensors = [_build_tensor(dest + 1, i, dtype=dtype) for i in group] for t1, t2 in zip(tensors, expected_tensors): self.assertEqual(t1, t2) @@ -1785,6 +1836,19 @@ def test_all_gather_cuda(self): rank_to_GPU = self._init_multigpu_helper() self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU) + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") + def test_all_gather_complex(self): + group, group_id, rank = self._init_global_test() + self._test_all_gather_helper(group, group_id, rank, dtype=torch.cfloat) + + @unittest.skipIf(BACKEND != "nccl", "Only Nccl supports CUDA all gather") + @unittest.skipIf(BACKEND == "nccl", "CUDA all gather skipped for NCCL") + @skip_if_no_gpu + def test_all_gather_cuda_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat) + @skip_if_small_worldsize @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") def test_all_gather_group(self): @@ -1813,7 +1877,7 @@ def _run_all_gather_coalesced_and_verify( return True def _test_all_gather_coalesced_helper( - self, group, group_id, rank + self, group, group_id, rank, dtype=torch.float ): # TODO: Instead we should probably go through _rank_not_in_group # mechanism to disable sending tensors @@ -1823,13 +1887,16 @@ def _test_all_gather_coalesced_helper( # [1], [2x2], [3x3x3] ... to be sent in one batch input_tensors = [ _build_multidim_tensor( - tensor_id, tensor_id, rank + tensor_id) for tensor_id in range( + tensor_id, + tensor_id, + rank + tensor_id, + dtype=dtype) for tensor_id in range( 1, test_case_id) ] output_tensor_lists = [ [ _build_multidim_tensor( - tensor_id, tensor_id, -1) for tensor_id in range( + tensor_id, tensor_id, -1, dtype=dtype) for tensor_id in range( 1, test_case_id) ] for _ in group ] @@ -1838,7 +1905,8 @@ def _test_all_gather_coalesced_helper( _build_multidim_tensor( tensor_id, tensor_id, - rank_iter + tensor_id) for tensor_id in range( + rank_iter + tensor_id, + dtype=dtype) for tensor_id in range( 1, test_case_id) ] for rank_iter in group ] @@ -1855,6 +1923,12 @@ def test_all_gather_coalesced_simple(self): group, group_id, rank = self._init_global_test() self._test_all_gather_coalesced_helper(group, group_id, rank) + @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL") + @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI") + def test_all_gather_coalesced_complex(self): + group, group_id, rank = self._init_global_test() + self._test_all_gather_coalesced_helper(group, group_id, rank, dtype=torch.cfloat) + @skip_if_small_worldsize @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL") @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI") @@ -2228,21 +2302,16 @@ def _test_all_reduce_multigpu_helper( master_value, worker_value, expected_value, + dtype=torch.float, ): for src in group: - if rank == src: - tensors = [ - _build_tensor(src + 1, master_value).cuda(device=i) - for i in rank_to_GPU[rank] - ] - else: - tensors = [ - _build_tensor(src + 1, worker_value).cuda(device=i) - for i in rank_to_GPU[rank] - ] - + curr_value = master_value if rank == src else worker_value + tensors = [ + _build_tensor(src + 1, curr_value, dtype=dtype).cuda(device=i) + for i in rank_to_GPU[rank] + ] dist.all_reduce_multigpu(tensors, op, group_id) - expected_tensor = _build_tensor(src + 1, expected_value) + expected_tensor = _build_tensor(src + 1, expected_value, dtype=dtype) for tensor in tensors: self.assertEqual(tensor, expected_tensor) @@ -2265,6 +2334,24 @@ def test_all_reduce_multigpu(self): (2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]), ) + @unittest.skipIf(BACKEND == "mpi", "MPI doesn't support broadcast multigpu") + @unittest.skipIf(BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL") + @skip_if_no_gpu + def test_all_reduce_multigpu_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_all_reduce_multigpu_helper( + group, + group_id, + rank, + rank_to_GPU, + dist.ReduceOp.SUM, + complex(2, 3), + complex(10, 11), + (complex(2, 3) + complex(10, 11) * (len(group) - 1)) * len(rank_to_GPU[0]), + dtype=torch.cfloat, + ) + def _test_reduce_multigpu_helper( self, group, @@ -2311,10 +2398,10 @@ def test_reduce_multigpu(self): (2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]), ) - def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU): + def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU, dtype=torch.float): for dest in group: tensors = [ - _build_tensor(dest + 1).cuda(device=i) for i in rank_to_GPU[rank] + _build_tensor(dest + 1, dtype=dtype).cuda(device=i) for i in rank_to_GPU[rank] ] # construct expected output along with @@ -2322,10 +2409,10 @@ def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU): output_tensors = [] expected_output = [] output_per_gpu = ( - [_build_tensor(dest + 1, -1)] * len(rank_to_GPU[0]) * len(group) + [_build_tensor(dest + 1, -1, dtype=dtype)] * len(rank_to_GPU[0]) * len(group) ) expected_per_gpu = ( - [_build_tensor(dest + 1)] * len(rank_to_GPU[0]) * len(group) + [_build_tensor(dest + 1, dtype=dtype)] * len(rank_to_GPU[0]) * len(group) ) for gpu in rank_to_GPU[rank]: output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu]) @@ -2343,6 +2430,13 @@ def test_all_gather_multigpu(self): rank_to_GPU = self._init_multigpu_helper() self._test_all_gather_multigpu_helper(group, group_id, rank, rank_to_GPU) + @unittest.skipIf(BACKEND != "nccl", "Only Nccl backend supports allgather multigpu") + @skip_if_no_gpu + def test_all_gather_multigpu_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_all_gather_multigpu_helper(group, group_id, rank, rank_to_GPU, dtype=torch.cfloat) + def _model_step(self, model): for param in model.parameters(): if param.grad is not None: