From 1f791c06f0d61f25aa2273ccf15cc65c3073d51c Mon Sep 17 00:00:00 2001
From: Brian Hirsh
Date: Wed, 14 Oct 2020 08:46:10 -0700
Subject: [PATCH] adding BAND/BOR/BXOR reduce ops to unsupported list for
 complex numbers. added tests (#46270)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46270

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D24284702

Pulled By: bdhirsh

fbshipit-source-id: 7e6c3fce83a4367808a638f0400999399b2c35b0
---
 torch/distributed/distributed_c10d.py                   | 3 ++-
 torch/testing/_internal/distributed/distributed_test.py | 9 ++++++---
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index e4f08b6d697b..db81a0685788 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -52,7 +52,8 @@
 # We'd like calls to unsupported ops to error out accordingly,
 # rather than returning garbage values.
 def supports_complex(reduceOp: ReduceOp) -> bool:
-    denyList = [ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PRODUCT]
+    denyList = [ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PRODUCT,
+                ReduceOp.BAND, ReduceOp.BOR, ReduceOp.BXOR]
     return reduceOp not in denyList
 
 
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index f1c0e2a3a4a8..c10a834b4ca0 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -1303,10 +1303,13 @@ def test_all_reduce_sum_complex(self):
         )
 
     @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
-    def test_all_reduce_max_complex_unsupported(self):
+    def test_all_reduce_complex_unsupported_ops(self):
+        unsupported_ops = [dist.ReduceOp.MAX, dist.ReduceOp.MIN, dist.ReduceOp.PRODUCT,
+                           dist.ReduceOp.BAND, dist.ReduceOp.BOR, dist.ReduceOp.BXOR]
         group, group_id, rank = self._init_global_test()
-        with self.assertRaisesRegex(RuntimeError, "all_reduce does not support"):
-            dist.all_reduce(_build_tensor(1, dtype=torch.cfloat), dist.ReduceOp.MAX, group_id)
+        for unsupported_op in unsupported_ops:
+            with self.assertRaisesRegex(RuntimeError, "all_reduce does not support"):
+                dist.all_reduce(_build_tensor(1, dtype=torch.cfloat), unsupported_op, group_id)
 
     @unittest.skipIf(
         BACKEND != "gloo",