Enable product for bool tensor (#48637)
Summary:
Fixes #48351

Pull Request resolved: #48637

Reviewed By: mrshenli

Differential Revision: D25658596

Pulled By: mruberry

fbshipit-source-id: ff3ada74b6d281c8e4753ed38339a1c036f722ee
Kiyosora authored and facebook-github-bot committed Dec 21, 2020
1 parent 49c9994 commit 983bfc7
Showing 3 changed files with 37 additions and 5 deletions.
18 changes: 14 additions & 4 deletions aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
@@ -158,13 +158,23 @@ static void std_var_kernel_impl(TensorIterator &iter, bool unbiased, bool take_s
 }
 
 static void prod_kernel_impl(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cpu", [&] {
+  // Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context]
+  if (iter.dtype() == ScalarType::Bool) {
+    using scalar_t = bool;
     binary_kernel_reduce_vec(
       iter,
-      [=](scalar_t a, scalar_t b) -> scalar_t { return a * b; },
-      [=](Vec256<scalar_t> a, Vec256<scalar_t> b) { return a * b; },
+      [=](scalar_t a, scalar_t b) -> scalar_t { return a && b; },
+      [=](Vec256<scalar_t> a, Vec256<scalar_t> b) { return a && b; },
       /*identity=*/1);
-  });
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cpu", [&] {
+      binary_kernel_reduce_vec(
+        iter,
+        [=](scalar_t a, scalar_t b) -> scalar_t { return a * b; },
+        [=](Vec256<scalar_t> a, Vec256<scalar_t> b) { return a * b; },
+        /*identity=*/1);
+    });
+  }
 }
 
 static void norm_kernel_tensor_iterator_impl(
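A product reduction over a bool tensor is just a logical AND over its elements, which is why the bool branch above substitutes && for *. A minimal sketch of the resulting behavior, assuming a PyTorch build that includes this change:

import torch

# Product over a bool tensor behaves like Tensor.all(): False as soon as any
# element is False, and the empty tensor reduces to the multiplicative
# identity, i.e. True. (Sketch only; assumes this change is in the build.)
for vals in ([True, True], [True, False], [False, False], []):
    t = torch.tensor(vals, dtype=torch.bool)
    assert torch.prod(t, dtype=torch.bool).item() == t.all().item()
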
13 changes: 12 additions & 1 deletion aten/src/ATen/native/cuda/ReduceSumProdKernel.cu
@@ -35,6 +35,17 @@ struct prod_functor {
   }
 };
 
+// Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context]
+template <>
+struct prod_functor<bool> {
+  void operator()(TensorIterator& iter) {
+    gpu_reduce_kernel<bool, bool>(
+        iter, func_wrapper<bool>([] GPU_LAMBDA(bool a, bool b) -> bool {
+          return a && b;
+        }), 1);
+  }
+};
+
 // The function `reduce_dispatch` below dispatches to the kernel based
 // on the type of `iter`. It takes care of the common logic
 // for handling Half-Precision floating types.
@@ -88,7 +99,7 @@ static void nansum_kernel_cuda(TensorIterator& iter) {
 
 static void prod_kernel_cuda(TensorIterator& iter) {
   auto general_dispatcher = [](TensorIterator& iter) {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cuda", [&]() {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(ScalarType::Bool, iter.dtype(), "prod_cuda", [&]() {
       prod_functor<scalar_t>{}(iter);
     });
   };
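The CUDA path can be exercised the same way from Python; a minimal sketch, assuming a CUDA-enabled build that includes this change:

import torch

# Exercises the prod_functor<bool> specialization above; requires a CUDA device.
if torch.cuda.is_available():
    t = torch.tensor([True, True, False], device="cuda")
    print(torch.prod(t, dtype=torch.bool))  # expected: tensor(False, device='cuda:0')
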
11 changes: 11 additions & 0 deletions test/test_reductions.py
@@ -814,6 +814,17 @@ def test_prod(self, device, dtype):
         torch.prod(x, 1, out=res2)
         self.assertEqual(res1, res2)
 
+    def test_prod_bool(self, device):
+        vals = [[True, True], [True, False], [False, False], []]
+        for val in vals:
+            result = torch.prod(torch.tensor(val, device=device), dtype=torch.bool).item()
+            expect = np.prod(np.array(val), dtype=np.bool)
+            self.assertEqual(result, expect)
+
+            result = torch.prod(torch.tensor(val, device=device)).item()
+            expect = np.prod(np.array(val))
+            self.assertEqual(result, expect)
+
     @onlyCPU
     def test_max_mixed_devices(self, device):
         a = torch.randn(10, device=device)
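As a usage note, the dimensional form of prod works on bool tensors as well; a small sketch (again assuming a build with this change), where the result dtype follows the usual integer promotion for reductions unless dtype=torch.bool is passed explicitly:

import torch

# Row-wise product over a bool tensor.
x = torch.tensor([[True, False], [True, True]])
print(torch.prod(x, dim=1, dtype=torch.bool))  # expected: tensor([False, True])
print(torch.prod(x, dim=1))                    # integer-promoted result, e.g. tensor([0, 1])
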
