[numpy] torch.{all, any} : Extend Dtype Support #44790

Closed
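In short, the change lifts the uint8/bool-only restriction on torch.all and torch.any so that other dtypes reduce over element truthiness as well, matching NumPy. A minimal sketch of the intended parity (illustrative only, output dtypes not shown):

import numpy as np
import torch

# Dtypes beyond torch.uint8/torch.bool now reduce over element truthiness,
# mirroring np.all / np.any.
for values in ([0.5, 1.0, 0.0], [1, 2, 3], [True, False, True]):
    t = torch.tensor(values)
    a = np.array(values)
    assert bool(torch.all(t)) == bool(np.all(a))
    assert bool(torch.any(t)) == bool(np.any(a))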
aten/src/ATen/native/ReduceOps.cpp (10 changes: 2 additions & 8 deletions)
@@ -695,8 +695,6 @@ Tensor all(const Tensor& self) {
"all only supports CPU AND CUDA device type, got: ", self.device().type());
TORCH_CHECK(self.layout() == Layout::Strided,
"all only supports strided layout, got: ", self.layout());
TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool,
"all only supports torch.uint8 and torch.bool dtypes");

Tensor result = at::empty({0}, self.options());
auto iter = make_reduction(
@@ -714,8 +712,7 @@ Tensor &all_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
"all only supports CPU AND CUDA device type, got: ", self.device().type());
TORCH_CHECK(self.layout() == Layout::Strided,
"all only supports strided layout, got: ", self.layout());
TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool,
"all only supports torch.uint8 and torch.bool dtypes");

dim = maybe_wrap_dim(dim, self.dim());
if (_dimreduce_return_trivial(result, self, 1, dim, keepdim)) {
return result;
@@ -741,8 +738,6 @@ Tensor any(const Tensor& self) {
"any only supports CPU AND CUDA device type, got: ", self.device().type());
TORCH_CHECK(self.layout() == Layout::Strided || self.layout() == Layout::Sparse,
"any only supports strided AND sparse layout, got: ", self.layout());
TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool,
"all only supports torch.uint8 and torch.bool dtypes");

Tensor result = at::empty({0}, self.options());
auto iter = make_reduction(
@@ -760,8 +755,7 @@ Tensor &any_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
"any only supports CPU AND CUDA device type, got: ", self.device().type());
TORCH_CHECK(self.layout() == Layout::Strided,
"any only supports strided layout, got: ", self.layout());
TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool,
"all only supports torch.uint8 and torch.bool dtypes");

dim = maybe_wrap_dim(dim, self.dim());
if (_dimreduce_return_trivial(result, self, 0, dim, keepdim)) {
return result;
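The four TORCH_CHECKs deleted above are what previously rejected every dtype other than torch.uint8 and torch.bool, with the error message quoted in the removed lines. With the checks gone, both the full reductions and the dim/keepdim overloads accept other dtypes; roughly:

import torch

x = torch.tensor([[0.5, 1.0], [0.0, 2.0]])

# Before this change the removed check raised:
#   RuntimeError: all only supports torch.uint8 and torch.bool dtypes
print(bool(torch.all(x)))    # False, since 0.0 is falsy
print(bool(torch.any(x)))    # True
print(torch.all(x, dim=1))   # the dim overload now accepts float input as well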
aten/src/ATen/native/SharedReduceOps.h (50 changes: 50 additions & 0 deletions)
@@ -340,6 +340,56 @@ struct NanSumOps {
#endif
};

template <typename acc_t>
struct AndOps {
inline C10_DEVICE acc_t reduce(acc_t a, acc_t b, int64_t /*idx*/) const {
return static_cast<bool>(a) && static_cast<bool>(b);
}

inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
return static_cast<bool>(a) && static_cast<bool>(b);
}

inline C10_DEVICE acc_t project(acc_t a) const {
return a;
}

static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
return acc;
}

#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
return WARP_SHFL_DOWN(data, offset);
}
#endif
};

template <typename acc_t>
struct OrOps {
inline C10_DEVICE acc_t reduce(acc_t a, acc_t b, int64_t /*idx*/) const {
return static_cast<bool>(a) || static_cast<bool>(b);
}

inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
return static_cast<bool>(a) || static_cast<bool>(b);
}

inline C10_DEVICE acc_t project(acc_t a) const {
return a;
}

static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
return acc;
}

#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
return WARP_SHFL_DOWN(data, offset);
}
#endif
};

namespace detail {

template <typename scalar_t>
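AndOps and OrOps follow the same functor protocol as the other reduction ops in this header: reduce folds one element into the accumulator, combine merges two partial accumulators, project finalizes the result, and the identity elements passed by the callers below are true and false respectively. The fold they express is an ordinary logical reduction; a Python sketch of the same semantics (illustrative only, not how the kernels actually execute):

from functools import reduce

def all_like(values):
    # AndOps: fold with logical AND, starting from the identity True.
    return reduce(lambda acc, v: bool(acc) and bool(v), values, True)

def any_like(values):
    # OrOps: fold with logical OR, starting from the identity False.
    return reduce(lambda acc, v: bool(acc) or bool(v), values, False)

print(all_like([1.5, 2.0, 0.0]))  # False
print(any_like([0.0, 0.0, 3.0]))  # True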
aten/src/ATen/native/cpu/ReduceOpsKernel.cpp (78 changes: 46 additions & 32 deletions)
@@ -232,41 +232,55 @@ static void norm_kernel_tensor_iterator_impl(
}

static void and_kernel_impl(TensorIterator& iter) {
binary_kernel_reduce_vec(
iter,
[=](uint8_t a, uint8_t b) -> uint8_t { return a && b; },
[=](Vec256<uint8_t> a, Vec256<uint8_t> b) {
// Adding the implementation here instead of in vec256_base to avoid
// return value inconsistency. Other comparison operators in vec256_base
// return -1/0 (all bit 1 / all bit 0) as true/false to follow the AVX2
// convention. This would be convenient when combined with other
// vectorized operations. For example, one can use the logical operation
// results as a mask for a bit operation to retrieve/reset multiple
// elements in a vector.
//
// In this method, users would expect, e.g., all(), to return 1/0 as
// true/false.
Vec256<uint8_t> c = Vec256<uint8_t>();
for (int i = 0; i != Vec256<uint8_t>::size(); i++) {
c[i] = a[i] && b[i];
}
return c;
},
/*ident=*/true);
if (c10::isIntegralType(iter.dtype(), /*includeBool=*/true)) {
binary_kernel_reduce_vec(
iter,
[=](uint8_t a, uint8_t b) -> uint8_t { return (a && b) ? 1 : 0; },
[=](Vec256<uint8_t> a, Vec256<uint8_t> b) {
// Adding the implementation here instead of in vec256_base to avoid
// return value inconsistency. Other comparison operators in
// vec256_base return -1/0 (all bit 1 / all bit 0) as true/false to
// follow the AVX2 convention. This would be convenient when combined
// with other vectorized operations. For example, one can use the
// logical operation results as a mask for a bit operation to
// retrieve/reset multiple elements in a vector.
//
// In this method, users would expect, e.g., all(), to return 1/0 as
// true/false.
Vec256<uint8_t> c = Vec256<uint8_t>();
for (int i = 0; i != Vec256<uint8_t>::size(); i++) {
c[i] = (a[i] && b[i]) ? 1 : 0;
}
return c;
},
/*ident=*/true);
} else {
AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.dtype(), "and_kernel", [&]() {
binary_kernel_reduce(
iter, AndOps<scalar_t>(), static_cast<scalar_t>(true));
});
}
}

static void or_kernel_impl(TensorIterator& iter) {
binary_kernel_reduce_vec(
iter,
[=](uint8_t a, uint8_t b) -> uint8_t { return a || b; },
[=](Vec256<uint8_t> a, Vec256<uint8_t> b) {
Vec256<uint8_t> c = Vec256<uint8_t>();
for (int i = 0; i != Vec256<uint8_t>::size(); i++) {
c[i] = a[i] || b[i];
}
return c;
},
/*ident=*/false);
if (c10::isIntegralType(iter.dtype(), /*includeBool=*/true)) {
binary_kernel_reduce_vec(
iter,
[=](uint8_t a, uint8_t b) -> uint8_t { return (a || b) ? 1 : 0; },
[=](Vec256<uint8_t> a, Vec256<uint8_t> b) {
Vec256<uint8_t> c = Vec256<uint8_t>();
for (int i = 0; i != Vec256<uint8_t>::size(); i++) {
c[i] = (a[i] || b[i]) ? 1 : 0;
}
return c;
},
/*ident=*/false);
} else {
AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.dtype(), "or_kernel", [&]() {
binary_kernel_reduce(
iter, OrOps<scalar_t>(), static_cast<scalar_t>(false));
});
}
}

template<typename scalar_t>
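The CPU kernels now branch on dtype: integral inputs (including bool) stay on the existing vectorized uint8 path, while floating-point inputs (including half) go through binary_kernel_reduce with the new AndOps/OrOps functors. Both paths reduce over element truthiness, so NaN counts as true, which matches NumPy; a small sketch of the observable behaviour:

import torch

x = torch.tensor([float('nan'), 0.0])
print(bool(torch.any(x)))   # True, NaN is nonzero and therefore truthy
print(bool(torch.all(x)))   # False, 0.0 is falsy

h = torch.ones(4, dtype=torch.float16)
print(bool(torch.all(h)))   # True, half now goes through the floating-point branch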
aten/src/ATen/native/cuda/ReduceLogicKernel.cu (24 changes: 16 additions & 8 deletions)
@@ -8,17 +8,25 @@
namespace at { namespace native {

void and_kernel_cuda(TensorIterator& iter) {
gpu_reduce_kernel<uint8_t, uint8_t>(
iter, func_wrapper<uint8_t> ([]GPU_LAMBDA(uint8_t a, uint8_t b) -> uint8_t {
return a && b;
}), true);
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(), "and_kernel", [&]() {
gpu_reduce_kernel<scalar_t, scalar_t>(
iter,
func_wrapper<scalar_t>([] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return static_cast<scalar_t>(static_cast<bool>(a) && static_cast<bool>(b));
}),
static_cast<scalar_t>(true));
});
}

void or_kernel_cuda(TensorIterator& iter) {
gpu_reduce_kernel<uint8_t, uint8_t>(
iter, func_wrapper<uint8_t> ([]GPU_LAMBDA(uint8_t a, uint8_t b) -> uint8_t {
return a || b;
}), false);
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(), "or_kernel", [&]() {
gpu_reduce_kernel<scalar_t, scalar_t>(
iter,
func_wrapper<scalar_t>([] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return static_cast<scalar_t>(static_cast<bool>(a) || static_cast<bool>(b));
}),
static_cast<scalar_t>(false));
});
}

REGISTER_DISPATCH(and_stub, &and_kernel_cuda);
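On CUDA the kernels dispatch over all standard types plus half and bool, so CPU and GPU results should agree for every dtype both backends support. A hypothetical parity check, guarded because it is illustrative only:

import torch

if torch.cuda.is_available():
    for dtype in (torch.bool, torch.uint8, torch.int64, torch.float16, torch.float32):
        x = (torch.arange(6) % 3).to(dtype)
        assert bool(torch.all(x)) == bool(torch.all(x.cuda()))
        assert bool(torch.any(x)) == bool(torch.any(x.cuda()))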
test/test_torch.py (44 changes: 44 additions & 0 deletions)
@@ -19486,6 +19486,50 @@ def test_dstack(self, device, dtype):
expected = np.dstack(np_input)
self.assertEqual(actual, expected)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@dtypes(*(torch.testing.get_all_dtypes(include_half=True, include_bfloat16=False,
include_bool=True, include_complex=False)))
def test_all_any_vs_numpy(self, device, dtype):
def _test_all_any(x):
self.compare_with_numpy(torch.all, np.all, x)
self.compare_with_numpy(torch.any, np.any, x)

def _test_all_any_with_dim(x, dim):
torch_fn = partial(torch.all, dim=dim)
np_fn = partial(np.all, axis=dim)
self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=False)

torch_fn = partial(torch.any, dim=dim)
np_fn = partial(np.any, axis=dim)
self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=False)

for ndim in range(5):
shape = self._rand_shape(ndim, 1, 5)
x = self._generate_input(shape, dtype, device, with_extremal=False)
_test_all_any(x)

x = self._generate_input(shape, dtype, device, with_extremal=True)
_test_all_any(x)

x = torch.zeros_like(x)
_test_all_any(x)

x = torch.ones_like(x)
_test_all_any(x)

for dim in range(ndim):
x = self._generate_input(shape, dtype, device, with_extremal=False)
_test_all_any_with_dim(x, dim)

x = self._generate_input(shape, dtype, device, with_extremal=True)
_test_all_any_with_dim(x, dim)

x = torch.zeros_like(x)
_test_all_any_with_dim(x, dim)

x = torch.ones_like(x)
_test_all_any_with_dim(x, dim)

@onlyOnCPUAndCUDA
def test_repeated_dim(self, device):
ops = [torch.mean, torch.sum, torch.nansum, torch.std, torch.logsumexp, torch.std, torch.var,