diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index d15ea9d23412..e552f0e2a386 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -220,6 +220,8 @@ An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``,
 Note that ``BAND``, ``BOR``, and ``BXOR`` reductions are not available when
 using the ``NCCL`` backend.
 
+Additionally, ``MAX``, ``MIN`` and ``PRODUCT`` are not supported for complex tensors.
+
 The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``.
 They are used in specifying strategies for reduction collectives, e.g.,
 :func:`reduce`, :func:`all_reduce_multigpu`, etc.)")
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index d0ccd08b22f6..e4f08b6d697b 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -44,6 +44,17 @@ except ImportError:
     _GLOO_AVAILABLE = False
 
 
+# Some reduce ops are not supported by complex numbers.
+# We currently provide complex support to the distributed API by viewing
+# complex tensors as real (torch.view_as_real), meaning that calling
+# these unsupported ops will return garbage values rather than error out.
+# (e.g. max(2+3i, 3+2i) = 3+3i)
+# We'd like calls to unsupported ops to error out accordingly,
+# rather than returning garbage values.
+def supports_complex(reduceOp: ReduceOp) -> bool:
+    denyList = [ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PRODUCT]
+    return reduceOp not in denyList
+
 class Backend(object):
     """
@@ -970,6 +981,8 @@ def all_reduce_multigpu(tensor_list,
     After the call, all ``tensor`` in ``tensor_list`` is going to be bitwise
     identical in all processes.
 
+    Complex tensors are supported.
+
     Only nccl and gloo backend is currently supported
     tensors should only be GPU tensors
@@ -993,6 +1006,8 @@ def all_reduce_multigpu(tensor_list,
     if _rank_not_in_group(group):
         return
 
+    tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
+
     opts = AllreduceOptions()
     opts.reduceOp = op
     if group == GroupMember.WORLD:
@@ -1017,6 +1032,8 @@ def all_reduce(tensor,
     After the call ``tensor`` is going to be bitwise identical in all processes.
 
+    Complex tensors are supported.
+
     Arguments:
         tensor (Tensor): Input and output of the collective. The function
             operates in-place.
@@ -1030,11 +1047,39 @@ def all_reduce(tensor,
         Async work handle, if async_op is set to True.
         None, if not async_op or if not part of the group
 
+    Examples:
+        >>> # All tensors below are of torch.int64 type.
+        >>> # We have 2 process groups, 2 ranks.
+        >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
+        >>> tensor
+        tensor([1, 2]) # Rank 0
+        tensor([3, 4]) # Rank 1
+        >>> dist.all_reduce(tensor, op=ReduceOp.SUM)
+        >>> tensor
+        tensor([4, 6]) # Rank 0
+        tensor([4, 6]) # Rank 1
+
+        >>> # All tensors below are of torch.cfloat type.
+        >>> # We have 2 process groups, 2 ranks.
+        >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
+        >>> tensor
+        tensor([1.+1.j, 2.+2.j]) # Rank 0
+        tensor([3.+3.j, 4.+4.j]) # Rank 1
+        >>> dist.all_reduce(tensor, op=ReduceOp.SUM)
+        >>> tensor
+        tensor([4.+4.j, 6.+6.j]) # Rank 0
+        tensor([4.+4.j, 6.+6.j]) # Rank 1
+
     """
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
         return
 
+    if tensor.is_complex():
+        if not supports_complex(op):
+            raise RuntimeError(f"all_reduce does not support {op} on complex tensors")
+        tensor = torch.view_as_real(tensor)
+
     opts = AllreduceOptions()
     opts.reduceOp = op
     if group == GroupMember.WORLD:
@@ -1068,6 +1113,8 @@ def all_reduce_coalesced(tensors,
     After the call each tensor in tensors is going to bitwise identical
     in all processes.
 
+    Complex tensors are supported.
+
     Arguments:
         tensors (List[Tensor]): Input and output of the collective. The function
             operates in-place.
@@ -1086,6 +1133,11 @@ def all_reduce_coalesced(tensors,
     if _rank_not_in_group(group):
         return
 
+    if any([t.is_complex() for t in tensors]) and not supports_complex(op):
+        raise RuntimeError(f"all_reduce does not support {op} on complex tensors")
+
+    tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors]
+
     opts = AllreduceCoalescedOptions()
     opts.reduceOp = op
     if group == GroupMember.WORLD:
@@ -1215,6 +1267,8 @@ def all_gather_multigpu(output_tensor_lists,
     Only nccl backend is currently supported
     tensors should only be GPU tensors
 
+    Complex tensors are supported.
+
     Arguments:
         output_tensor_lists (List[List[Tensor]]): Output lists. It should
             contain correctly-sized tensors on each GPU to be used for output
@@ -1250,6 +1304,9 @@ def all_gather_multigpu(output_tensor_lists,
     if _rank_not_in_group(group):
         return
 
+    output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists]
+    input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list]
+
     if group == GroupMember.WORLD:
         _check_default_pg()
         work = _default_pg.allgather(output_tensor_lists, input_tensor_list)
@@ -1498,6 +1555,8 @@ def all_gather(tensor_list,
     """
     Gathers tensors from the whole group in a list.
 
+    Complex tensors are supported.
+
     Arguments:
         tensor_list (list[Tensor]): Output list. It should contain
             correctly-sized tensors to be used for output of the collective.
@@ -1509,12 +1568,44 @@ def all_gather(tensor_list,
         Async work handle, if async_op is set to True.
         None, if not async_op or if not part of the group
 
+    Examples:
+        >>> # All tensors below are of torch.int64 dtype.
+        >>> # We have 2 process groups, 2 ranks.
+        >>> tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
+        >>> tensor_list
+        [tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
+        >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
+        >>> tensor
+        tensor([1, 2]) # Rank 0
+        tensor([3, 4]) # Rank 1
+        >>> dist.all_gather(tensor_list, tensor)
+        >>> tensor_list
+        [tensor([1, 2]), tensor([3, 4])] # Rank 0
+        [tensor([1, 2]), tensor([3, 4])] # Rank 1
+
+        >>> # All tensors below are of torch.cfloat dtype.
+        >>> # We have 2 process groups, 2 ranks.
+        >>> tensor_list = [torch.zeros(2, dtype=torch.cfloat) for _ in range(2)]
+        >>> tensor_list
+        [tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1
+        >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
+        >>> tensor
+        tensor([1.+1.j, 2.+2.j]) # Rank 0
+        tensor([3.+3.j, 4.+4.j]) # Rank 1
+        >>> dist.all_gather(tensor_list, tensor)
+        >>> tensor_list
+        [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0
+        [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1
+
     """
     _check_tensor_list(tensor_list, "tensor_list")
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
         return
 
+    tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
+    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)
+
     if group == GroupMember.WORLD:
         _check_default_pg()
         work = _default_pg.allgather([tensor_list], [tensor])
@@ -1533,6 +1624,8 @@ def all_gather_coalesced(output_tensor_lists,
     """
     Gathers input tensors from the whole group in a list in a coalesced manner.
 
+    Complex tensors are supported.
+
     Arguments:
         output_tensor_lists (list[list[Tensor]]): Output list. It should contain
             correctly-sized tensors to be used for output of the collective.
@@ -1581,6 +1674,9 @@ def all_gather_coalesced(output_tensor_lists,
     for output_tensor_list in output_tensor_lists:
         _check_tensor_list(output_tensor_list, "output_tensor_lists")
 
+    output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists]
+    input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list]
+
     if group == GroupMember.WORLD:
         _check_default_pg()
         work = _default_pg.allgather_coalesced(
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 35ce23d16742..f1c0e2a3a4a8 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -220,10 +220,10 @@ def _build_tensor(size, value=None, dtype=torch.float, device_id=None):
     return torch.empty(size, size, size, dtype=dtype).fill_(value).cuda(device_id)
 
 
-def _build_multidim_tensor(dim, dim_size, value=None):
+def _build_multidim_tensor(dim, dim_size, value=None, dtype=torch.float):
     if value is None:
         value = size
-    return torch.FloatTensor(size=[dim_size for _ in range(dim)]).fill_(value)
+    return torch.empty(size=[dim_size for _ in range(dim)], dtype=dtype).fill_(value)
 
 
 class Barrier(object):
@@ -1241,20 +1241,17 @@ def _test_all_reduce_helper(
         expected_value,
         cuda=False,
         rank_to_GPU=None,
+        dtype=torch.float,
     ):
         for src in group:
-            if rank == src:
-                tensor = _build_tensor(src + 1).fill_(master_value)
-                if cuda:
-                    tensor = tensor.cuda(rank_to_GPU[rank][0])
-                dist.all_reduce(tensor, op, group_id)
-                self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
-            else:
-                tensor = _build_tensor(src + 1).fill_(worker_value)
-                if cuda:
-                    tensor = tensor.cuda(rank_to_GPU[rank][0])
-                dist.all_reduce(tensor, op, group_id)
-                self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
+            curr_value = master_value if rank == src else worker_value
+
+            tensor = _build_tensor(src + 1, dtype=dtype).fill_(curr_value)
+            if cuda:
+                tensor = tensor.cuda(rank_to_GPU[rank][0])
+            dist.all_reduce(tensor, op, group_id)
+            expected_tensor = _build_tensor(src + 1, expected_value, dtype=dtype)
+            self.assertEqual(tensor, expected_tensor)
 
         self._barrier()
 
@@ -1291,6 +1288,47 @@ def test_all_reduce_sum_cuda(self):
             rank_to_GPU,
         )
 
+    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
+    def test_all_reduce_sum_complex(self):
+        group, group_id, rank = self._init_global_test()
+        self._test_all_reduce_helper(
+            group,
+            group_id,
+            rank,
+            dist.ReduceOp.SUM,
+            complex(2, 3),
+            complex(10, 11),
+            complex(2, 3) + (complex(10, 11) * (len(group) - 1)),
+            dtype=torch.cfloat,
+        )
+
+    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
+    def test_all_reduce_max_complex_unsupported(self):
+        group, group_id, rank = self._init_global_test()
+        with self.assertRaisesRegex(RuntimeError, "all_reduce does not support"):
+            dist.all_reduce(_build_tensor(1, dtype=torch.cfloat), dist.ReduceOp.MAX, group_id)
+
+    @unittest.skipIf(
+        BACKEND != "gloo",
+        "Only Gloo backend will have CUDA allReduce tested",
+    )
+    @skip_if_no_gpu
+    def test_all_reduce_sum_cuda_complex(self):
+        group, group_id, rank = self._init_global_test()
+        rank_to_GPU = self._init_multigpu_helper()
+        self._test_all_reduce_helper(
+            group,
+            group_id,
+            rank,
+            dist.ReduceOp.SUM,
+            complex(2, 3),
+            complex(10, 11),
+            complex(2, 3) + (complex(10, 11) * (len(group) - 1)),
+            True,
+            rank_to_GPU,
+            dtype=torch.cfloat,
+        )
+
     @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
     def test_all_reduce_product(self):
         group, group_id, rank = self._init_global_test()
@@ -1429,9 +1467,10 @@ def test_sparse_all_reduce_sum_cuda(self):
     @staticmethod
     def _all_reduce_coalesced_sum_test_cases(group_size):
         return (
-            [2, 3],
-            [10, 11],
-            [2 + 10 * (group_size - 1), 3 + 11 * (group_size - 1)]
+            [2, 3, complex(2, 3)],
+            [10, 11, complex(10, 11)],
+            [2 + 10 * (group_size - 1), 3 + 11 * (group_size - 1), complex(2, 3) + complex(10, 11) * (group_size - 1)],
+            [torch.float, torch.float, torch.cfloat],
         )
 
     @staticmethod
@@ -1439,7 +1478,8 @@ def _all_reduce_coalesced_product_test_cases(group_size):
         return (
             [1, 2],
             [3, 4],
-            [1 * 3 ** (group_size - 1), 2 * 4 ** (group_size - 1)]
+            [1 * 3 ** (group_size - 1), 2 * 4 ** (group_size - 1)],
+            [torch.float, torch.float],
         )
 
     @staticmethod
@@ -1447,7 +1487,8 @@ def _all_reduce_coalesced_min_test_cases(group_size):
         return (
             [1, 4],
             [2, 3],
-            [1, 3]
+            [1, 3],
+            [torch.float, torch.float],
         )
 
     @staticmethod
@@ -1455,9 +1496,16 @@ def _all_reduce_coalesced_max_test_cases(group_size):
         return (
             [1, 4],
             [2, 3],
-            [2, 4]
+            [2, 4],
+            [torch.float, torch.float],
         )
 
+    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
+    def test_all_reduce_coalesced_max_complex_unsupported(self):
+        group, group_id, rank = self._init_global_test()
+        with self.assertRaisesRegex(RuntimeError, "all_reduce does not support"):
+            dist.all_reduce_coalesced([_build_tensor(1, dtype=torch.cfloat)], dist.ReduceOp.MAX, group_id)
+
     def _test_all_reduce_coalesced_helper(
         self,
         group,
@@ -1474,22 +1522,24 @@ def _test_all_reduce_coalesced_helper(
             dist.ReduceOp.MAX: self._all_reduce_coalesced_max_test_cases
         }[op]
 
-        master_values, worker_values, expected_values = test_case_func(len(group))
+        master_values, worker_values, expected_values, dtypes = test_case_func(len(group))
         for src in group:
+            curr_values = master_values if rank == src else worker_values
             tensors = [
-                _build_tensor(src + 1, val)
-                for val in (master_values if rank == src else worker_values)
+                _build_tensor(src + 1, val, dtype=dtype)
+                for dtype, val in zip(dtypes, curr_values)
            ]
             if cuda:
                 tensors = list(map(tensors, lambda t: t.cuda(rank_to_GPU[rank][0])))
             dist.all_reduce_coalesced(tensors, op, group_id)
+            expected_tensors = [
+                _build_tensor(src + 1, expected_value, dtype=dtype)
+                for dtype, expected_value in zip(dtypes, expected_values)
+            ]
             self.assertEqual(
                 tensors,
-                [
-                    _build_tensor(src + 1, expected_value)
-                    for expected_value in expected_values
-                ]
+                expected_tensors
             )
 
         self._barrier()
 
@@ -1750,17 +1800,18 @@ def test_gather_full_group(self):
     # ALL GATHER
     def _test_all_gather_helper(
-        self, group, group_id, rank, cuda=False, rank_to_GPU=None
+        self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
     ):
+
         for dest in group:
-            tensor = _build_tensor(dest + 1, rank)
-            tensors = [_build_tensor(dest + 1, -1) for i in group]
+            tensor = _build_tensor(dest + 1, rank, dtype=dtype)
+            tensors = [_build_tensor(dest + 1, -1, dtype=dtype) for i in group]
             if cuda:
                 tensor = tensor.cuda(rank_to_GPU[rank][0])
                 tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
             dist.all_gather(tensors, tensor, group_id)
 
-            expected_tensors = [_build_tensor(dest + 1, i) for i in group]
+            expected_tensors = [_build_tensor(dest + 1, i, dtype=dtype) for i in group]
             for t1, t2 in zip(tensors, expected_tensors):
                 self.assertEqual(t1, t2)
@@ -1779,6 +1830,19 @@ def test_all_gather_cuda(self):
         rank_to_GPU = self._init_multigpu_helper()
         self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU)
 
+    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
+    def test_all_gather_complex(self):
+        group, group_id, rank = self._init_global_test()
+        self._test_all_gather_helper(group, group_id, rank, dtype=torch.cfloat)
+
+    @unittest.skipIf(BACKEND != "nccl", "Only Nccl supports CUDA all gather")
+    @unittest.skipIf(BACKEND == "nccl", "CUDA all gather skipped for NCCL")
+    @skip_if_no_gpu
+    def test_all_gather_cuda_complex(self):
+        group, group_id, rank = self._init_global_test()
+        rank_to_GPU = self._init_multigpu_helper()
+        self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat)
+
     @skip_if_small_worldsize
     @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
     def test_all_gather_group(self):
@@ -1807,7 +1871,7 @@ def _run_all_gather_coalesced_and_verify(
         return True
 
     def _test_all_gather_coalesced_helper(
-            self, group, group_id, rank
+            self, group, group_id, rank, dtype=torch.float
     ):
         # TODO: Instead we should probably go through _rank_not_in_group
         # mechanism to disable sending tensors
@@ -1817,13 +1881,16 @@ def _test_all_gather_coalesced_helper(
             # [1], [2x2], [3x3x3] ... to be sent in one batch
             input_tensors = [
                 _build_multidim_tensor(
-                    tensor_id, tensor_id, rank + tensor_id) for tensor_id in range(
+                    tensor_id,
+                    tensor_id,
+                    rank + tensor_id,
+                    dtype=dtype) for tensor_id in range(
                         1, test_case_id)
             ]
             output_tensor_lists = [
                 [
                     _build_multidim_tensor(
-                        tensor_id, tensor_id, -1) for tensor_id in range(
+                        tensor_id, tensor_id, -1, dtype=dtype) for tensor_id in range(
                             1, test_case_id)
                 ] for _ in group
             ]
@@ -1832,7 +1899,8 @@ def _test_all_gather_coalesced_helper(
                     _build_multidim_tensor(
                         tensor_id,
                         tensor_id,
-                        rank_iter + tensor_id) for tensor_id in range(
+                        rank_iter + tensor_id,
+                        dtype=dtype) for tensor_id in range(
                             1, test_case_id)
                 ] for rank_iter in group
             ]
@@ -1849,6 +1917,12 @@ def test_all_gather_coalesced_simple(self):
         group, group_id, rank = self._init_global_test()
         self._test_all_gather_coalesced_helper(group, group_id, rank)
 
+    @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL")
+    @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI")
+    def test_all_gather_coalesced_complex(self):
+        group, group_id, rank = self._init_global_test()
+        self._test_all_gather_coalesced_helper(group, group_id, rank, dtype=torch.cfloat)
+
     @skip_if_small_worldsize
     @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL")
     @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI")
@@ -2222,21 +2296,16 @@ def _test_all_reduce_multigpu_helper(
         master_value,
         worker_value,
         expected_value,
+        dtype=torch.float,
     ):
         for src in group:
-            if rank == src:
-                tensors = [
-                    _build_tensor(src + 1, master_value).cuda(device=i)
-                    for i in rank_to_GPU[rank]
-                ]
-            else:
-                tensors = [
-                    _build_tensor(src + 1, worker_value).cuda(device=i)
-                    for i in rank_to_GPU[rank]
-                ]
-
+            curr_value = master_value if rank == src else worker_value
+            tensors = [
+                _build_tensor(src + 1, curr_value, dtype=dtype).cuda(device=i)
+                for i in rank_to_GPU[rank]
+            ]
             dist.all_reduce_multigpu(tensors, op, group_id)
 
-            expected_tensor = _build_tensor(src + 1, expected_value)
+            expected_tensor = _build_tensor(src + 1, expected_value, dtype=dtype)
             for tensor in tensors:
                 self.assertEqual(tensor, expected_tensor)
@@ -2259,6 +2328,24 @@ def test_all_reduce_multigpu(self):
             (2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]),
         )
 
+    @unittest.skipIf(BACKEND == "mpi", "MPI doesn't support broadcast multigpu")
+    @unittest.skipIf(BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL")
+    @skip_if_no_gpu
+    def test_all_reduce_multigpu_complex(self):
+        group, group_id, rank = self._init_global_test()
+        rank_to_GPU = self._init_multigpu_helper()
+        self._test_all_reduce_multigpu_helper(
+            group,
+            group_id,
+            rank,
+            rank_to_GPU,
+            dist.ReduceOp.SUM,
+            complex(2, 3),
+            complex(10, 11),
+            (complex(2, 3) + complex(10, 11) * (len(group) - 1)) * len(rank_to_GPU[0]),
+            dtype=torch.cfloat,
+        )
+
     def _test_reduce_multigpu_helper(
         self,
         group,
@@ -2305,10 +2392,10 @@ def test_reduce_multigpu(self):
             (2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]),
         )
 
-    def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU):
+    def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU, dtype=torch.float):
         for dest in group:
             tensors = [
-                _build_tensor(dest + 1).cuda(device=i) for i in rank_to_GPU[rank]
+                _build_tensor(dest + 1, dtype=dtype).cuda(device=i) for i in rank_to_GPU[rank]
             ]
 
             # construct expected output along with
@@ -2316,10 +2403,10 @@ def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU):
rank_to_GPU): output_tensors = [] expected_output = [] output_per_gpu = ( - [_build_tensor(dest + 1, -1)] * len(rank_to_GPU[0]) * len(group) + [_build_tensor(dest + 1, -1, dtype=dtype)] * len(rank_to_GPU[0]) * len(group) ) expected_per_gpu = ( - [_build_tensor(dest + 1)] * len(rank_to_GPU[0]) * len(group) + [_build_tensor(dest + 1, dtype=dtype)] * len(rank_to_GPU[0]) * len(group) ) for gpu in rank_to_GPU[rank]: output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu]) @@ -2337,6 +2424,13 @@ def test_all_gather_multigpu(self): rank_to_GPU = self._init_multigpu_helper() self._test_all_gather_multigpu_helper(group, group_id, rank, rank_to_GPU) + @unittest.skipIf(BACKEND != "nccl", "Only Nccl backend supports allgather multigpu") + @skip_if_no_gpu + def test_all_gather_multigpu_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_all_gather_multigpu_helper(group, group_id, rank, rank_to_GPU, dtype=torch.cfloat) + def _model_step(self, model): for param in model.parameters(): if param.grad is not None:
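
Editor's note (not part of the patch): the approach above hinges on the fact that, for ReduceOp.SUM, reducing elementwise over torch.view_as_real(t) is equivalent to reducing over the complex tensor t itself, while MAX, MIN and PRODUCT compare or multiply real and imaginary parts independently and return garbage (e.g. max(2+3i, 3+2i) = 3+3i), which is why supports_complex() rejects them. A minimal standalone sketch of that invariant, assuming only stock PyTorch; the helper name sum_via_real_view is hypothetical and not an API introduced by this patch:

# Editor's sketch: SUM over the real view of complex tensors matches SUM over
# the complex tensors themselves, making torch.view_as_real a safe bridge for
# ReduceOp.SUM (but not for MAX/MIN/PRODUCT).
import torch

def sum_via_real_view(tensors):
    # Hypothetical helper mimicking what a SUM all_reduce effectively does per rank:
    # view each complex tensor as a (..., 2) real tensor, accumulate elementwise,
    # then view the accumulated result back as complex.
    acc = torch.zeros_like(torch.view_as_real(tensors[0]))
    for t in tensors:
        acc = acc + torch.view_as_real(t)
    return torch.view_as_complex(acc)

if __name__ == "__main__":
    a = torch.tensor([1 + 1j, 2 + 2j], dtype=torch.cfloat)  # "rank 0" contribution
    b = torch.tensor([3 + 3j, 4 + 4j], dtype=torch.cfloat)  # "rank 1" contribution
    # SUM commutes with view_as_real, so the round-trip reproduces a + b exactly.
    assert torch.equal(sum_via_real_view([a, b]), a + b)
    # MAX over the real view would pick real and imaginary parts independently,
    # e.g. max(2+3j, 3+2j) -> 3+3j, which is the garbage case supports_complex() guards against.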