From 0c156987d87b47ebafe39e81f320bd04006e8ff8 Mon Sep 17 00:00:00 2001
From: Brian Hirsh
Date: Mon, 5 Oct 2020 17:09:24 -0700
Subject: [PATCH 01/12] adding complex support for distributed functions and .
 fix #45760

[ghstack-poisoned]
---
 torch/distributed/distributed_c10d.py         |  15 ++
 .../_internal/distributed/distributed_test.py | 184 +++++++++++++-----
 2 files changed, 148 insertions(+), 51 deletions(-)

diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index ae4338cd28fc..e34fadcdca4d 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -892,6 +892,8 @@ def all_reduce_multigpu(tensor_list,
     if _rank_not_in_group(group):
         return

+    tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
+
     opts = AllreduceOptions()
     opts.reduceOp = op
     if group == GroupMember.WORLD:
@@ -934,6 +936,8 @@ def all_reduce(tensor,
     if _rank_not_in_group(group):
         return

+    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)
+
     opts = AllreduceOptions()
     opts.reduceOp = op
     if group == GroupMember.WORLD:
@@ -985,6 +989,8 @@ def all_reduce_coalesced(tensors,
     if _rank_not_in_group(group):
         return

+    tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors]
+
     opts = AllreduceCoalescedOptions()
     opts.reduceOp = op
     if group == GroupMember.WORLD:
@@ -1149,6 +1155,9 @@ def all_gather_multigpu(output_tensor_lists,
     if _rank_not_in_group(group):
         return

+    output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists]
+    input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list]
+
     if group == GroupMember.WORLD:
         _check_default_pg()
         work = _default_pg.allgather(output_tensor_lists, input_tensor_list)
@@ -1414,6 +1423,9 @@ def all_gather(tensor_list,
     if _rank_not_in_group(group):
         return

+    tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
+    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)
+
     if group == GroupMember.WORLD:
         _check_default_pg()
         work = _default_pg.allgather([tensor_list], [tensor])
@@ -1480,6 +1492,9 @@ def all_gather_coalesced(output_tensor_lists,
     for output_tensor_list in output_tensor_lists:
         _check_tensor_list(output_tensor_list, "output_tensor_lists")

+    output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists]
+    input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list]
+
     if group == GroupMember.WORLD:
         _check_default_pg()
         work = _default_pg.allgather_coalesced(
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 01cddee92365..de011763487e 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -216,10 +216,10 @@ def _build_tensor(size, value=None, dtype=torch.float):
     return torch.empty(size, size, size, dtype=dtype).fill_(value)


-def _build_multidim_tensor(dim, dim_size, value=None):
+def _build_multidim_tensor(dim, dim_size, value=None, dtype=torch.float):
     if value is None:
         value = size
-    return torch.FloatTensor(size=[dim_size for _ in range(dim)]).fill_(value)
+    return torch.empty(size=[dim_size for _ in range(dim)], dtype=dtype).fill_(value)


 class Barrier(object):
@@ -1010,20 +1010,17 @@ def _test_all_reduce_helper(
         expected_value,
         cuda=False,
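# Illustrative sketch, not part of the patch above: the complex support added
# to distributed_c10d.py relies on torch.view_as_real, which reinterprets a
# complex tensor of shape (N,) as a real tensor of shape (N, 2) sharing the
# same storage, so an in-place SUM reduction on the real view is equivalent to
# an elementwise complex sum. Single-process demonstration (no process group
# needed; the second tensor stands in for another rank's contribution):
import torch

t = torch.tensor([1 + 1j, 2 + 2j], dtype=torch.cfloat)
r = torch.view_as_real(t)          # tensor([[1., 1.], [2., 2.]]), shares memory with t
r += torch.view_as_real(torch.tensor([3 + 3j, 4 + 4j], dtype=torch.cfloat))
print(t)                           # tensor([4.+4.j, 6.+6.j]) -- t sees the in-place update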
rank_to_GPU=None, + dtype=torch.float, ): for src in group: - if rank == src: - tensor = _build_tensor(src + 1).fill_(master_value) - if cuda: - tensor = tensor.cuda(rank_to_GPU[rank][0]) - dist.all_reduce(tensor, op, group_id) - self.assertEqual(tensor, _build_tensor(src + 1, expected_value)) - else: - tensor = _build_tensor(src + 1).fill_(worker_value) - if cuda: - tensor = tensor.cuda(rank_to_GPU[rank][0]) - dist.all_reduce(tensor, op, group_id) - self.assertEqual(tensor, _build_tensor(src + 1, expected_value)) + curr_value = master_value if rank == src else worker_value + + tensor = _build_tensor(src + 1, dtype=dtype).fill_(curr_value) + if cuda: + tensor = tensor.cuda(rank_to_GPU[rank][0]) + dist.all_reduce(tensor, op, group_id) + expected_tensor = _build_tensor(src + 1, expected_value, dtype=dtype) + self.assertEqual(tensor, expected_tensor) self._barrier() @@ -1060,6 +1057,41 @@ def test_all_reduce_sum_cuda(self): rank_to_GPU, ) + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") + def test_all_reduce_sum_complex(self): + group, group_id, rank = self._init_global_test() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + complex(2, 3), + complex(10, 11), + complex(2, 3) + (complex(10, 11) * (len(group) - 1)), + dtype=torch.cfloat, + ) + + @unittest.skipIf( + BACKEND != "gloo", + "Only Gloo backend will have CUDA allReduce tested", + ) + @skip_if_no_gpu + def test_all_reduce_sum_cuda_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_all_reduce_helper( + group, + group_id, + rank, + dist.ReduceOp.SUM, + complex(2, 3), + complex(10, 11), + complex(2, 3) + (complex(10, 11) * (len(group) - 1)), + True, + rank_to_GPU, + dtype=torch.cfloat, + ) + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") def test_all_reduce_product(self): group, group_id, rank = self._init_global_test() @@ -1198,9 +1230,10 @@ def test_sparse_all_reduce_sum_cuda(self): @staticmethod def _all_reduce_coalesced_sum_test_cases(group_size): return ( - [2, 3], - [10, 11], - [2 + 10 * (group_size - 1), 3 + 11 * (group_size - 1)] + [2, 3, complex(2, 3)], + [10, 11, complex(10, 11)], + [2 + 10 * (group_size - 1), 3 + 11 * (group_size - 1), complex(2, 3) + complex(10, 11) * (group_size - 1)], + [torch.float, torch.float, torch.cfloat], ) @staticmethod @@ -1208,7 +1241,8 @@ def _all_reduce_coalesced_product_test_cases(group_size): return ( [1, 2], [3, 4], - [1 * 3 ** (group_size - 1), 2 * 4 ** (group_size - 1)] + [1 * 3 ** (group_size - 1), 2 * 4 ** (group_size - 1)], + [torch.float, torch.float], ) @staticmethod @@ -1216,7 +1250,8 @@ def _all_reduce_coalesced_min_test_cases(group_size): return ( [1, 4], [2, 3], - [1, 3] + [1, 3], + [torch.float, torch.float], ) @staticmethod @@ -1224,7 +1259,8 @@ def _all_reduce_coalesced_max_test_cases(group_size): return ( [1, 4], [2, 3], - [2, 4] + [2, 4], + [torch.float, torch.float], ) def _test_all_reduce_coalesced_helper( @@ -1243,22 +1279,24 @@ def _test_all_reduce_coalesced_helper( dist.ReduceOp.MAX: self._all_reduce_coalesced_max_test_cases }[op] - master_values, worker_values, expected_values = test_case_func(len(group)) + master_values, worker_values, expected_values, dtypes = test_case_func(len(group)) for src in group: + curr_values = master_values if rank == src else worker_values tensors = [ - _build_tensor(src + 1, val) - for val in (master_values if rank == src else worker_values) + _build_tensor(src + 1, val, dtype=dtype) + 
for dtype, val in zip(dtypes, curr_values) ] if cuda: tensors = list(map(tensors, lambda t: t.cuda(rank_to_GPU[rank][0]))) dist.all_reduce_coalesced(tensors, op, group_id) + expected_tensors = [ + _build_tensor(src + 1, expected_value, dtype=dtype) + for dtype, expected_value in zip(dtypes, expected_values) + ] self.assertEqual( tensors, - [ - _build_tensor(src + 1, expected_value) - for expected_value in expected_values - ] + expected_tensors ) self._barrier() @@ -1519,17 +1557,18 @@ def test_gather_full_group(self): # ALL GATHER def _test_all_gather_helper( - self, group, group_id, rank, cuda=False, rank_to_GPU=None + self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float ): + for dest in group: - tensor = _build_tensor(dest + 1, rank) - tensors = [_build_tensor(dest + 1, -1) for i in group] + tensor = _build_tensor(dest + 1, rank, dtype=dtype) + tensors = [_build_tensor(dest + 1, -1, dtype=dtype) for i in group] if cuda: tensor = tensor.cuda(rank_to_GPU[rank][0]) tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors] dist.all_gather(tensors, tensor, group_id) - expected_tensors = [_build_tensor(dest + 1, i) for i in group] + expected_tensors = [_build_tensor(dest + 1, i, dtype=dtype) for i in group] for t1, t2 in zip(tensors, expected_tensors): self.assertEqual(t1, t2) @@ -1548,6 +1587,19 @@ def test_all_gather_cuda(self): rank_to_GPU = self._init_multigpu_helper() self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU) + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") + def test_all_gather_complex(self): + group, group_id, rank = self._init_global_test() + self._test_all_gather_helper(group, group_id, rank, dtype=torch.cfloat) + + @unittest.skipIf(BACKEND != "nccl", "Only Nccl supports CUDA all gather") + @unittest.skipIf(BACKEND == "nccl", "CUDA all gather skipped for NCCL") + @skip_if_no_gpu + def test_all_gather_cuda_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat) + @skip_if_small_worldsize @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") def test_all_gather_group(self): @@ -1576,7 +1628,7 @@ def _run_all_gather_coalesced_and_verify( return True def _test_all_gather_coalesced_helper( - self, group, group_id, rank + self, group, group_id, rank, dtype=torch.float ): # TODO: Instead we should probably go through _rank_not_in_group # mechanism to disable sending tensors @@ -1586,13 +1638,16 @@ def _test_all_gather_coalesced_helper( # [1], [2x2], [3x3x3] ... 
to be sent in one batch input_tensors = [ _build_multidim_tensor( - tensor_id, tensor_id, rank + tensor_id) for tensor_id in range( + tensor_id, + tensor_id, + rank + tensor_id, + dtype=dtype) for tensor_id in range( 1, test_case_id) ] output_tensor_lists = [ [ _build_multidim_tensor( - tensor_id, tensor_id, -1) for tensor_id in range( + tensor_id, tensor_id, -1, dtype=dtype) for tensor_id in range( 1, test_case_id) ] for _ in group ] @@ -1601,7 +1656,8 @@ def _test_all_gather_coalesced_helper( _build_multidim_tensor( tensor_id, tensor_id, - rank_iter + tensor_id) for tensor_id in range( + rank_iter + tensor_id, + dtype=dtype) for tensor_id in range( 1, test_case_id) ] for rank_iter in group ] @@ -1618,6 +1674,12 @@ def test_all_gather_coalesced_simple(self): group, group_id, rank = self._init_global_test() self._test_all_gather_coalesced_helper(group, group_id, rank) + @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL") + @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI") + def test_all_gather_coalesced_complex(self): + group, group_id, rank = self._init_global_test() + self._test_all_gather_coalesced_helper(group, group_id, rank, dtype=torch.cfloat) + @skip_if_small_worldsize @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL") @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI") @@ -1991,21 +2053,16 @@ def _test_all_reduce_multigpu_helper( master_value, worker_value, expected_value, + dtype=torch.float, ): for src in group: - if rank == src: - tensors = [ - _build_tensor(src + 1, master_value).cuda(device=i) - for i in rank_to_GPU[rank] - ] - else: - tensors = [ - _build_tensor(src + 1, worker_value).cuda(device=i) - for i in rank_to_GPU[rank] - ] - + curr_value = master_value if rank == src else worker_value + tensors = [ + _build_tensor(src + 1, curr_value, dtype=dtype).cuda(device=i) + for i in rank_to_GPU[rank] + ] dist.all_reduce_multigpu(tensors, op, group_id) - expected_tensor = _build_tensor(src + 1, expected_value) + expected_tensor = _build_tensor(src + 1, expected_value, dtype=dtype) for tensor in tensors: self.assertEqual(tensor, expected_tensor) @@ -2028,6 +2085,24 @@ def test_all_reduce_multigpu(self): (2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]), ) + @unittest.skipIf(BACKEND == "mpi", "MPI doesn't support broadcast multigpu") + @unittest.skipIf(BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL") + @skip_if_no_gpu + def test_all_reduce_multigpu(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_all_reduce_multigpu_helper( + group, + group_id, + rank, + rank_to_GPU, + dist.ReduceOp.SUM, + complex(2,3), + complex(10,11), + (complex(2,3) + complex(10,11) * (len(group) - 1)) * len(rank_to_GPU[0]), + dtype=torch.cfloat, + ) + def _test_reduce_multigpu_helper( self, group, @@ -2074,10 +2149,10 @@ def test_reduce_multigpu(self): (2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]), ) - def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU): + def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU, dtype=torch.float): for dest in group: tensors = [ - _build_tensor(dest + 1).cuda(device=i) for i in rank_to_GPU[rank] + _build_tensor(dest + 1, dtype=dtype).cuda(device=i) for i in rank_to_GPU[rank] ] # construct expected output along with @@ -2085,10 +2160,10 @@ def _test_all_gather_multigpu_helper(self, group, group_id, rank, rank_to_GPU): 
output_tensors = [] expected_output = [] output_per_gpu = ( - [_build_tensor(dest + 1, -1)] * len(rank_to_GPU[0]) * len(group) + [_build_tensor(dest + 1, -1, dtype=dtype)] * len(rank_to_GPU[0]) * len(group) ) expected_per_gpu = ( - [_build_tensor(dest + 1)] * len(rank_to_GPU[0]) * len(group) + [_build_tensor(dest + 1, dtype=dtype)] * len(rank_to_GPU[0]) * len(group) ) for gpu in rank_to_GPU[rank]: output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu]) @@ -2106,6 +2181,13 @@ def test_all_gather_multigpu(self): rank_to_GPU = self._init_multigpu_helper() self._test_all_gather_multigpu_helper(group, group_id, rank, rank_to_GPU) + @unittest.skipIf(BACKEND != "nccl", "Only Nccl backend supports allgather multigpu") + @skip_if_no_gpu + def test_all_gather_multigpu_complex(self): + group, group_id, rank = self._init_global_test() + rank_to_GPU = self._init_multigpu_helper() + self._test_all_gather_multigpu_helper(group, group_id, rank, rank_to_GPU, dtype=torch.cfloat) + def _model_step(self, model): for param in model.parameters(): if param.grad is not None: From d590fe660e4233fa201afbe8ffa12fb73246356c Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Mon, 5 Oct 2020 17:16:04 -0700 Subject: [PATCH 02/12] Update on "adding complex support for distributed functions and . fix #45760" [ghstack-poisoned] --- .../_internal/distributed/distributed_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index de011763487e..57e5b8367556 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -1291,8 +1291,8 @@ def _test_all_reduce_coalesced_helper( tensors = list(map(tensors, lambda t: t.cuda(rank_to_GPU[rank][0]))) dist.all_reduce_coalesced(tensors, op, group_id) expected_tensors = [ - _build_tensor(src + 1, expected_value, dtype=dtype) - for dtype, expected_value in zip(dtypes, expected_values) + _build_tensor(src + 1, expected_value, dtype=dtype) + for dtype, expected_value in zip(dtypes, expected_values) ] self.assertEqual( tensors, @@ -2088,7 +2088,7 @@ def test_all_reduce_multigpu(self): @unittest.skipIf(BACKEND == "mpi", "MPI doesn't support broadcast multigpu") @unittest.skipIf(BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL") @skip_if_no_gpu - def test_all_reduce_multigpu(self): + def test_all_reduce_multigpu_complex(self): group, group_id, rank = self._init_global_test() rank_to_GPU = self._init_multigpu_helper() self._test_all_reduce_multigpu_helper( @@ -2097,9 +2097,9 @@ def test_all_reduce_multigpu(self): rank, rank_to_GPU, dist.ReduceOp.SUM, - complex(2,3), - complex(10,11), - (complex(2,3) + complex(10,11) * (len(group) - 1)) * len(rank_to_GPU[0]), + complex(2, 3), + complex(10, 11), + (complex(2, 3) + complex(10, 11) * (len(group) - 1)) * len(rank_to_GPU[0]), dtype=torch.cfloat, ) From 517f458064a01700f736e736042760d923e48859 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Tue, 6 Oct 2020 07:38:16 -0700 Subject: [PATCH 03/12] Update on "adding complex support for distributed functions and . 
fix #45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned] --- torch/distributed/distributed_c10d.py | 66 +++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index e34fadcdca4d..c6886bd62c9c 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -869,6 +869,8 @@ def all_reduce_multigpu(tensor_list, After the call, all ``tensor`` in ``tensor_list`` is going to be bitwise identical in all processes. + Complex tensors are supported. + Only nccl and gloo backend is currently supported tensors should only be GPU tensors @@ -918,6 +920,8 @@ def all_reduce(tensor, After the call ``tensor`` is going to be bitwise identical in all processes. + Complex tensors are supported. + Arguments: tensor (Tensor): Input and output of the collective. The function operates in-place. @@ -931,6 +935,25 @@ def all_reduce(tensor, Async work handle, if async_op is set to True. None, if not async_op or if not part of the group + Example: + Tensors are all of dtype torch.int64. + We have 2 process groups, 2 ranks. + rank 0 passes: + tensor = [[1, 1], [2, 2]] + rank 1 passes: + tensor = [[3, 3], [4, 4]] + both rank 0 and 1 get: + tensor = [[4, 4], [6, 6]] + + Tensors are all of dtype torch.complex64. + We have 2 process groups, 2 ranks. + rank 0 passes: + tensor = [[1+i, 1+i], [2+2i, 2+2i]] + rank 1 passes: + tensor = [[3+3i, 3+3i], [4+4i, 4+4i]] + both rank 0 and 1 get: + tensor = [[4+4i, 4+4i], [6+6i, 6+6i]] + """ _check_single_tensor(tensor, "tensor") if _rank_not_in_group(group): @@ -971,6 +994,8 @@ def all_reduce_coalesced(tensors, After the call each tensor in tensors is going to bitwise identical in all processes. + Complex tensors are supported. + Arguments: tensors (List[Tensor]): Input and output of the collective. The function operates in-place. @@ -1120,6 +1145,8 @@ def all_gather_multigpu(output_tensor_lists, Only nccl backend is currently supported tensors should only be GPU tensors + Complex tensors are supported. + Arguments: output_tensor_lists (List[List[Tensor]]): Output lists. It should contain correctly-sized tensors on each GPU to be used for output @@ -1406,6 +1433,8 @@ def all_gather(tensor_list, """ Gathers tensors from the whole group in a list. + Complex tensors are supported. + Arguments: tensor_list (list[Tensor]): Output list. It should contain correctly-sized tensors to be used for output of the collective. @@ -1417,6 +1446,41 @@ def all_gather(tensor_list, Async work handle, if async_op is set to True. None, if not async_op or if not part of the group + Example: + Tensors are all of dtype torch.int64. + We have 2 process groups, 2 ranks. + rank 0 passes: + tensor_list = + [[[-1, -1], [-1, -1]], + [[-1, -1], [-1, -1]]] + tensor = [[1, 1], [2, 2]] + rank 1 passes: + tensor_list = + [[[-1, -1], [-1, -1]], + [[-1, -1], [-1, -1]]] + tensor = [[3, 3], [4, 4]] + both rank 0 and 1 get: + tensor_list = + [[[1, 1], [2, 2]], + [[3, 3], [4, 4]]] + + Tensors are all of dtype torch.complex64. + We have 2 process groups, 2 ranks. 
+ rank 0 passes: + tensor_list = + [[[0+0i, 0+0i], [0+0i, 0+0i]], + [[0+0i, 0+0i], [0+0i, 0+0i]]] + tensor = [[1+i, 1+i], [2+2i, 2+2i]] + rank 1 passes: + tensor_list = + [[[0+0i, 0+0i], [0+0i, 0+0i]], + [[0+0i, 0+0i], [0+0i, 0+0i]]] + tensor = [[3+3i, 3+3i], [4+4i, 4+4i]] + both rank 0 and 1 get: + tensor_list = + [[[1+i, 1+i], [2+2i, 2+2i]], + [[3+3i, 3+3i], [4+4i, 4+4i]]] + """ _check_tensor_list(tensor_list, "tensor_list") _check_single_tensor(tensor, "tensor") @@ -1444,6 +1508,8 @@ def all_gather_coalesced(output_tensor_lists, """ Gathers input tensors from the whole group in a list in a coalesced manner. + Complex tensors are supported. + Arguments: output_tensor_lists (list[list[Tensor]]): Output list. It should contain correctly-sized tensors to be used for output of the collective. From 8c0578d22785d548dffb88ebb0dd28619ce28a61 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Wed, 7 Oct 2020 11:41:40 -0700 Subject: [PATCH 04/12] Update on "adding complex support for distributed functions and . fix #45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned] --- torch/distributed/distributed_c10d.py | 102 +++++++++++++------------- 1 file changed, 50 insertions(+), 52 deletions(-) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index c6886bd62c9c..66ec944f1e05 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -935,24 +935,28 @@ def all_reduce(tensor, Async work handle, if async_op is set to True. None, if not async_op or if not part of the group - Example: - Tensors are all of dtype torch.int64. - We have 2 process groups, 2 ranks. - rank 0 passes: - tensor = [[1, 1], [2, 2]] - rank 1 passes: - tensor = [[3, 3], [4, 4]] - both rank 0 and 1 get: - tensor = [[4, 4], [6, 6]] - - Tensors are all of dtype torch.complex64. - We have 2 process groups, 2 ranks. - rank 0 passes: - tensor = [[1+i, 1+i], [2+2i, 2+2i]] - rank 1 passes: - tensor = [[3+3i, 3+3i], [4+4i, 4+4i]] - both rank 0 and 1 get: - tensor = [[4+4i, 4+4i], [6+6i, 6+6i]] + Examples: + >>> # Tensors are all of dtype torch.int64. + >>> # We have 2 process groups, 2 ranks. + >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank + >>> tensor + tensor([1, 2]) # Rank 0 + tensor([3, 4]) # Rank 1 + >>> dist.all_reduce(tensor, op=ReduceOp.SUM) + >>> tensor + tensor([4, 6]) # Rank 0 + tensor([4, 6]) # Rank 1 + + >>> # Tensors are all of dtype torch.complex64. + >>> # We have 2 process groups, 2 ranks. + >>> tensor = torch.tensor([complex(1, 1), complex(2, 2)], dtype=torch.complex64) + 2 * complex(rank, rank) + >>> tensor + tensor([1.+1.j, 2.+2.j]) # Rank 0 + tensor([3.+3.j, 4.+4.j]) # Rank 1 + >>> dist.all_reduce(tensor, op=ReduceOp.SUM) + >>> tensor + tensor([4.+4.j, 6.+6.j]) # Rank 0 + tensor([4.+4.j, 6.+6.j]) # Rank 1 """ _check_single_tensor(tensor, "tensor") @@ -1446,40 +1450,34 @@ def all_gather(tensor_list, Async work handle, if async_op is set to True. None, if not async_op or if not part of the group - Example: - Tensors are all of dtype torch.int64. - We have 2 process groups, 2 ranks. - rank 0 passes: - tensor_list = - [[[-1, -1], [-1, -1]], - [[-1, -1], [-1, -1]]] - tensor = [[1, 1], [2, 2]] - rank 1 passes: - tensor_list = - [[[-1, -1], [-1, -1]], - [[-1, -1], [-1, -1]]] - tensor = [[3, 3], [4, 4]] - both rank 0 and 1 get: - tensor_list = - [[[1, 1], [2, 2]], - [[3, 3], [4, 4]]] - - Tensors are all of dtype torch.complex64. - We have 2 process groups, 2 ranks. 
- rank 0 passes: - tensor_list = - [[[0+0i, 0+0i], [0+0i, 0+0i]], - [[0+0i, 0+0i], [0+0i, 0+0i]]] - tensor = [[1+i, 1+i], [2+2i, 2+2i]] - rank 1 passes: - tensor_list = - [[[0+0i, 0+0i], [0+0i, 0+0i]], - [[0+0i, 0+0i], [0+0i, 0+0i]]] - tensor = [[3+3i, 3+3i], [4+4i, 4+4i]] - both rank 0 and 1 get: - tensor_list = - [[[1+i, 1+i], [2+2i, 2+2i]], - [[3+3i, 3+3i], [4+4i, 4+4i]]] + Examples: + >>> # Tensors are all of dtype torch.int64. + >>> # We have 2 process groups, 2 ranks. + >>> tensor_list = [torch.zero(2, dtype=torch.int64) for _ in range(2)] + >>> tensor_list + [tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1 + >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank + >>> tensor + tensor([1, 2]) # Rank 0 + tensor([3, 4]) # Rank 1 + >>> dist.all_gather(tensor_list, tensor) + >>> tensor_list + [tensor([1, 2]), tensor([3, 4])] # Rank 0 + [tensor([1, 2]), tensor([3, 4])] # Rank 1 + + >>> # Tensors are all of dtype torch.complex64. + >>> # We have 2 process groups, 2 ranks. + >>> tensor_list = [torch.zero(2, dtype=torch.complex64) for _ in range(2)] + >>> tensor_list + [tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1 + >>> tensor = torch.tensor([complex(1, 1), complex(2, 2)], dtype=torch.complex64) + 2 * complex(rank, rank) + >>> tensor + tensor([1.+1.j, 2.+2.j]) # Rank 0 + tensor([3.+3.j, 4.+4.j]) # Rank 1 + >>> dist.all_gather(tensor_list, tensor) + >>> tensor_list + [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0 + [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1 """ _check_tensor_list(tensor_list, "tensor_list") From 2c5e8729bf652c10f27ef038037f204854fa1412 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Wed, 7 Oct 2020 14:22:24 -0700 Subject: [PATCH 05/12] Update on "adding complex support for distributed functions and . fix #45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned] --- torch/distributed/distributed_c10d.py | 34 ++++++++++++++----- .../_internal/distributed/distributed_test.py | 12 +++++++ 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 66ec944f1e05..2fe526b41314 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -43,6 +43,18 @@ except ImportError: _GLOO_AVAILABLE = False +# Some reduce ops are not supported by complex numbers. +# We currently provide complex support to the distributed API by viewing +# complex tensors as real (torch.view_as_real), meaning that calling +# these unsupported ops will return garbage values rather than error out. +# (e.g. max(2+3i, 3+2i) = 3+3i) +# We'd like calls to unsupported ops to error out accordingly, +# rather than returning garbage values. +def supports_complex(reduceOp: ReduceOp) -> bool: + if reduceOp == ReduceOp.MAX or reduceOp == ReduceOp.MIN: + return False + return True + class Backend(object): """ @@ -936,7 +948,7 @@ def all_reduce(tensor, None, if not async_op or if not part of the group Examples: - >>> # Tensors are all of dtype torch.int64. + >>> # All tensors below are of torch.int64 type. >>> # We have 2 process groups, 2 ranks. >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank >>> tensor @@ -947,9 +959,9 @@ def all_reduce(tensor, tensor([4, 6]) # Rank 0 tensor([4, 6]) # Rank 1 - >>> # Tensors are all of dtype torch.complex64. + >>> # All tensors below are of torch.cdouble type. >>> # We have 2 process groups, 2 ranks. 
- >>> tensor = torch.tensor([complex(1, 1), complex(2, 2)], dtype=torch.complex64) + 2 * complex(rank, rank) + >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cdouble) + 2 * rank * (1+1j) >>> tensor tensor([1.+1.j, 2.+2.j]) # Rank 0 tensor([3.+3.j, 4.+4.j]) # Rank 1 @@ -963,7 +975,10 @@ def all_reduce(tensor, if _rank_not_in_group(group): return - tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) + if tensor.is_complex(): + if not supports_complex(op): + raise RuntimeError(f"{op} is unsupported on complex tensors") + tensor = torch.view_as_real(tensor) opts = AllreduceOptions() opts.reduceOp = op @@ -1018,6 +1033,9 @@ def all_reduce_coalesced(tensors, if _rank_not_in_group(group): return + if any([t.is_complex() for t in tensors]) and not supports_complex(op): + raise RuntimeError(f"{op} is unsupported on complex tensors") + tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors] opts = AllreduceCoalescedOptions() @@ -1451,7 +1469,7 @@ def all_gather(tensor_list, None, if not async_op or if not part of the group Examples: - >>> # Tensors are all of dtype torch.int64. + >>> # All tensors below are of torch.int64 dtype. >>> # We have 2 process groups, 2 ranks. >>> tensor_list = [torch.zero(2, dtype=torch.int64) for _ in range(2)] >>> tensor_list @@ -1465,12 +1483,12 @@ def all_gather(tensor_list, [tensor([1, 2]), tensor([3, 4])] # Rank 0 [tensor([1, 2]), tensor([3, 4])] # Rank 1 - >>> # Tensors are all of dtype torch.complex64. + >>> # All tensors below are of torch.cdouble dtype. >>> # We have 2 process groups, 2 ranks. - >>> tensor_list = [torch.zero(2, dtype=torch.complex64) for _ in range(2)] + >>> tensor_list = [torch.zero(2, dtype=torch.cdouble) for _ in range(2)] >>> tensor_list [tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1 - >>> tensor = torch.tensor([complex(1, 1), complex(2, 2)], dtype=torch.complex64) + 2 * complex(rank, rank) + >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cdouble) + 2 * rank * (1+1j) >>> tensor tensor([1.+1.j, 2.+2.j]) # Rank 0 tensor([3.+3.j, 4.+4.j]) # Rank 1 diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 57e5b8367556..d98e629b4b98 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -1071,6 +1071,12 @@ def test_all_reduce_sum_complex(self): dtype=torch.cfloat, ) + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") + def test_all_reduce_max_complex_unsupported(self): + group, group_id, rank = self._init_global_test() + with self.assertRaisesRegex(RuntimeError, "is unsupported on complex tensors"): + dist.all_reduce(_build_tensor(1, dtype=torch.cfloat), dist.ReduceOp.MAX, group_id) + @unittest.skipIf( BACKEND != "gloo", "Only Gloo backend will have CUDA allReduce tested", @@ -1263,6 +1269,12 @@ def _all_reduce_coalesced_max_test_cases(group_size): [torch.float, torch.float], ) + @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") + def test_all_reduce_coalesced_max_complex_unsupported(self): + group, group_id, rank = self._init_global_test() + with self.assertRaisesRegex(RuntimeError, "is unsupported on complex tensors"): + dist.all_reduce_coalesced([_build_tensor(1, dtype=torch.cfloat)], dist.ReduceOp.MAX, group_id) + def _test_all_reduce_coalesced_helper( self, group, From 74cb182b75cabb3f1a955d33eb6747845c5a5132 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Wed, 7 Oct 
2020 14:24:46 -0700 Subject: [PATCH 06/12] Update on "adding complex support for distributed functions and . fix #45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned] --- torch/distributed/distributed_c10d.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 2fe526b41314..da2701d7e5b5 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -933,6 +933,7 @@ def all_reduce(tensor, After the call ``tensor`` is going to be bitwise identical in all processes. Complex tensors are supported. + They are not supported for the Max and Min ReduceOps, however. Arguments: tensor (Tensor): Input and output of the collective. The function @@ -1014,6 +1015,7 @@ def all_reduce_coalesced(tensors, in all processes. Complex tensors are supported. + They are not supported for the Max and Min ReduceOps, however. Arguments: tensors (List[Tensor]): Input and output of the collective. The function From cb38d26c10c2e3112231010ad6633cdba82fd9c7 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Thu, 8 Oct 2020 10:10:11 -0700 Subject: [PATCH 07/12] Update on "adding complex support for distributed functions and . fix #45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned] --- torch/csrc/distributed/c10d/init.cpp | 2 ++ torch/distributed/distributed_c10d.py | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 38a1811692c2..362c3f50a2ee 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -220,6 +220,8 @@ An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``, Note that ``BAND``, ``BOR``, and ``BXOR`` reductions are not available when using the ``NCCL`` backend. +Additionally, ``MAX`` and ``MIN`` are both not supported for complex tensors. + The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``. They are used in specifying strategies for reduction collectives, e.g., :func:`reduce`, :func:`all_reduce_multigpu`, etc.)") diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index da2701d7e5b5..2fe526b41314 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -933,7 +933,6 @@ def all_reduce(tensor, After the call ``tensor`` is going to be bitwise identical in all processes. Complex tensors are supported. - They are not supported for the Max and Min ReduceOps, however. Arguments: tensor (Tensor): Input and output of the collective. The function @@ -1015,7 +1014,6 @@ def all_reduce_coalesced(tensors, in all processes. Complex tensors are supported. - They are not supported for the Max and Min ReduceOps, however. Arguments: tensors (List[Tensor]): Input and output of the collective. The function From ef8c0312958bfdc5e34386dbd7e59dd34114e1cb Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Thu, 8 Oct 2020 11:03:43 -0700 Subject: [PATCH 08/12] Update on "adding complex support for distributed functions and . 
fix #45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned] --- torch/distributed/distributed_c10d.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 2fe526b41314..982f63b37688 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -959,9 +959,9 @@ def all_reduce(tensor, tensor([4, 6]) # Rank 0 tensor([4, 6]) # Rank 1 - >>> # All tensors below are of torch.cdouble type. + >>> # All tensors below are of torch.cfloat type. >>> # We have 2 process groups, 2 ranks. - >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cdouble) + 2 * rank * (1+1j) + >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j) >>> tensor tensor([1.+1.j, 2.+2.j]) # Rank 0 tensor([3.+3.j, 4.+4.j]) # Rank 1 @@ -1483,12 +1483,12 @@ def all_gather(tensor_list, [tensor([1, 2]), tensor([3, 4])] # Rank 0 [tensor([1, 2]), tensor([3, 4])] # Rank 1 - >>> # All tensors below are of torch.cdouble dtype. + >>> # All tensors below are of torch.cfloat dtype. >>> # We have 2 process groups, 2 ranks. - >>> tensor_list = [torch.zero(2, dtype=torch.cdouble) for _ in range(2)] + >>> tensor_list = [torch.zero(2, dtype=torch.float) for _ in range(2)] >>> tensor_list [tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1 - >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cdouble) + 2 * rank * (1+1j) + >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j) >>> tensor tensor([1.+1.j, 2.+2.j]) # Rank 0 tensor([3.+3.j, 4.+4.j]) # Rank 1 From e8fcc0bf9cfbc13b6796aa1ba0a2fef59109f464 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Thu, 8 Oct 2020 11:41:10 -0700 Subject: [PATCH 09/12] Update on "adding complex support for distributed functions and . fix #45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned] --- torch/distributed/distributed_c10d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 982f63b37688..de1934003a54 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1485,7 +1485,7 @@ def all_gather(tensor_list, >>> # All tensors below are of torch.cfloat dtype. >>> # We have 2 process groups, 2 ranks. - >>> tensor_list = [torch.zero(2, dtype=torch.float) for _ in range(2)] + >>> tensor_list = [torch.zero(2, dtype=torch.cfloat) for _ in range(2)] >>> tensor_list [tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1 >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j) From e048c7560f326b8904fe956bfeb6068672f9d4c0 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Thu, 8 Oct 2020 12:26:45 -0700 Subject: [PATCH 10/12] Update on "adding complex support for distributed functions and . 
fix #45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned] --- torch/csrc/distributed/c10d/init.cpp | 2 +- torch/distributed/distributed_c10d.py | 4 ++-- torch/testing/_internal/distributed/distributed_test.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 362c3f50a2ee..b1d90dc558d8 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -220,7 +220,7 @@ An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``, Note that ``BAND``, ``BOR``, and ``BXOR`` reductions are not available when using the ``NCCL`` backend. -Additionally, ``MAX`` and ``MIN`` are both not supported for complex tensors. +Additionally, ``MAX``, ``MIN`` and ``PRODUCT`` are not supported for complex tensors. The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``. They are used in specifying strategies for reduction collectives, e.g., diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index de1934003a54..ad63de4ed0b5 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -51,7 +51,7 @@ # We'd like calls to unsupported ops to error out accordingly, # rather than returning garbage values. def supports_complex(reduceOp: ReduceOp) -> bool: - if reduceOp == ReduceOp.MAX or reduceOp == ReduceOp.MIN: + if reduceOp == ReduceOp.MAX or reduceOp == ReduceOp.MIN or reduceOp == ReduceOp.PRODUCT: return False return True @@ -1034,7 +1034,7 @@ def all_reduce_coalesced(tensors, return if any([t.is_complex() for t in tensors]) and not supports_complex(op): - raise RuntimeError(f"{op} is unsupported on complex tensors") + raise RuntimeError(f"all_reduce does not support {op} on complex tensors") tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors] diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index d98e629b4b98..7cf439e9c310 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -1074,7 +1074,7 @@ def test_all_reduce_sum_complex(self): @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors") def test_all_reduce_max_complex_unsupported(self): group, group_id, rank = self._init_global_test() - with self.assertRaisesRegex(RuntimeError, "is unsupported on complex tensors"): + with self.assertRaisesRegex(RuntimeError, "all_reduce does not support"): dist.all_reduce(_build_tensor(1, dtype=torch.cfloat), dist.ReduceOp.MAX, group_id) @unittest.skipIf( From 3d6a6929b57c40d69b40080513b1166886667d4e Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Fri, 9 Oct 2020 06:57:58 -0700 Subject: [PATCH 11/12] Update on "adding complex support for distributed functions and . fix #45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned] --- torch/distributed/distributed_c10d.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 197c52e6d635..0285a5ed1563 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -52,9 +52,8 @@ # We'd like calls to unsupported ops to error out accordingly, # rather than returning garbage values. 
 def supports_complex(reduceOp: ReduceOp) -> bool:
-    if reduceOp == ReduceOp.MAX or reduceOp == ReduceOp.MIN or reduceOp == ReduceOp.PRODUCT:
-        return False
-    return True
+    denyList = [ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PRODUCT]
+    return reduceOp not in denyList


 class Backend(object):

From 8ddb50aa4d55f684be6917aa818bf5e3565689ee Mon Sep 17 00:00:00 2001
From: Brian Hirsh
Date: Fri, 9 Oct 2020 07:23:18 -0700
Subject: [PATCH 12/12] Update on "adding complex support for distributed
 functions and . fix #45760"

Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949)

[ghstack-poisoned]
---
 torch/distributed/distributed_c10d.py                   | 2 +-
 torch/testing/_internal/distributed/distributed_test.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 0285a5ed1563..1269d57721fd 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -1077,7 +1077,7 @@ def all_reduce(tensor,

     if tensor.is_complex():
         if not supports_complex(op):
-            raise RuntimeError(f"{op} is unsupported on complex tensors")
+            raise RuntimeError(f"all_reduce does not support {op} on complex tensors")
         tensor = torch.view_as_real(tensor)

     opts = AllreduceOptions()
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 5ce29021153a..2cbf0c81f24c 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -1509,7 +1509,7 @@ def _all_reduce_coalesced_max_test_cases(group_size):
     @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
     def test_all_reduce_coalesced_max_complex_unsupported(self):
         group, group_id, rank = self._init_global_test()
-        with self.assertRaisesRegex(RuntimeError, "is unsupported on complex tensors"):
+        with self.assertRaisesRegex(RuntimeError, "all_reduce does not support"):
             dist.all_reduce_coalesced([_build_tensor(1, dtype=torch.cfloat)], dist.ReduceOp.MAX, group_id)

     def _test_all_reduce_coalesced_helper(