diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
index 10437f51d4b4..5f96e01ab319 100644
--- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
@@ -158,13 +158,23 @@ static void std_var_kernel_impl(TensorIterator &iter, bool unbiased, bool take_s
 }
 
 static void prod_kernel_impl(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cpu", [&] {
+  // Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context]
+  if (iter.dtype() == ScalarType::Bool) {
+    using scalar_t = bool;
     binary_kernel_reduce_vec(
       iter,
-      [=](scalar_t a, scalar_t b) -> scalar_t { return a * b; },
-      [=](Vec256<scalar_t> a, Vec256<scalar_t> b) { return a * b; },
+      [=](scalar_t a, scalar_t b) -> scalar_t { return a && b; },
+      [=](Vec256<scalar_t> a, Vec256<scalar_t> b) { return a && b; },
       /*identity=*/1);
-  });
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cpu", [&] {
+      binary_kernel_reduce_vec(
+        iter,
+        [=](scalar_t a, scalar_t b) -> scalar_t { return a * b; },
+        [=](Vec256<scalar_t> a, Vec256<scalar_t> b) { return a * b; },
+        /*identity=*/1);
+    });
+  }
 }
 
 static void norm_kernel_tensor_iterator_impl(
diff --git a/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu b/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu
index 1732a2ad73d7..9919b0f0eac4 100644
--- a/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu
+++ b/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu
@@ -35,6 +35,17 @@ struct prod_functor {
   }
 };
 
+// Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context]
+template <>
+struct prod_functor<bool> {
+  void operator()(TensorIterator& iter) {
+    gpu_reduce_kernel<bool, bool>(
+        iter, func_wrapper<bool>([] GPU_LAMBDA(bool a, bool b) -> bool {
+          return a && b;
+        }), 1);
+  }
+};
+
 // The function `reduce_dispatch` below dispatches to the kernel based
 // on the type of `iter`. It takes care of the common logic
 // for handling Half-Precision floating types.
@@ -88,7 +99,7 @@ static void nansum_kernel_cuda(TensorIterator& iter) {
 
 static void prod_kernel_cuda(TensorIterator& iter) {
   auto general_dispatcher = [](TensorIterator& iter) {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cuda", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(ScalarType::Bool, iter.dtype(), "prod_cuda", [&]() {
       prod_functor<scalar_t>{}(iter);
     });
   };
diff --git a/test/test_reductions.py b/test/test_reductions.py
index 7c877d822142..917d469c5ee6 100644
--- a/test/test_reductions.py
+++ b/test/test_reductions.py
@@ -814,6 +814,17 @@ def test_prod(self, device, dtype):
         torch.prod(x, 1, out=res2)
         self.assertEqual(res1, res2)
 
+    def test_prod_bool(self, device):
+        vals = [[True, True], [True, False], [False, False], []]
+        for val in vals:
+            result = torch.prod(torch.tensor(val, device=device), dtype=torch.bool).item()
+            expect = np.prod(np.array(val), dtype=bool)
+            self.assertEqual(result, expect)
+
+            result = torch.prod(torch.tensor(val, device=device)).item()
+            expect = np.prod(np.array(val))
+            self.assertEqual(result, expect)
+
     @onlyCPU
     def test_max_mixed_devices(self, device):
         a = torch.randn(10, device=device)
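
Note (not part of the patch): a minimal usage sketch of the behavior this change enables. With the bool paths above, torch.prod reduces boolean tensors with logical AND, and the identity element 1 means the product over an empty bool tensor is True:

    import torch

    t = torch.tensor([True, True, False])
    torch.prod(t, dtype=torch.bool)      # tensor(False) -- one element is False
    torch.prod(t[:2], dtype=torch.bool)  # tensor(True)  -- all elements are True
    # The empty product falls back to the identity (1 -> True).
    torch.prod(torch.empty(0, dtype=torch.bool), dtype=torch.bool)  # tensor(True)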