114 changes: 46 additions & 68 deletions aten/src/ATen/native/ReduceOps.cpp
@@ -31,20 +31,18 @@
namespace at {
namespace meta {

void check_all_any(
impl::MetaBase& meta,
ScalarType check_allany_and_get_output_dtype(
const char* name,
const Tensor& self,
int64_t raw_dim,
const Tensor& result,
IntArrayRef dims,
bool keepdim) {
auto dim = at::maybe_wrap_dim(raw_dim, self.dim());
// Refer [all, any : uint8 compatibility]
TORCH_CHECK(
self.layout() == Layout::Strided,
name, " only supports strided layout, got: ",
self.layout());

const auto& result = meta.maybe_get_output();
ScalarType out_dtype;

if (result.defined()) {
@@ -63,17 +61,29 @@ void check_all_any(
}
}

return out_dtype;
}
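The helper above encodes the dtype rule from Note [all, any : uint8 compatibility]: the reduction result is kBool unless the input is kByte, in which case the output stays kByte. A minimal sketch of that rule through the public ATen API (illustrative only, not part of this diff; the demo function name is made up):

```cpp
#include <ATen/ATen.h>

// Hypothetical demo, not part of the PR: default output dtypes of all/any.
void allany_dtype_rule_demo() {
  auto b = at::zeros({4}, at::kByte);  // uint8 input
  auto m = at::ones({4}, at::kBool);   // bool input

  // uint8 in -> uint8 out (legacy compatibility); everything else -> bool.
  TORCH_INTERNAL_ASSERT(at::all(b).scalar_type() == at::kByte);
  TORCH_INTERNAL_ASSERT(at::any(m).scalar_type() == at::kBool);
}
```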

void check_allany_for_meta(
impl::MetaBase& meta,
const char* name,
const Tensor& self,
int64_t dim,
bool keepdim) {
dim = maybe_wrap_dim(dim, self.dim());
const auto& result = meta.maybe_get_output();
auto out_dtype = check_allany_and_get_output_dtype(name, self, result, dim, keepdim);
auto shape = get_reduction_shape(self, dim, keepdim);
meta.set_output(shape, self.options().dtype(out_dtype));
namedinference::propagate_names_for_reduction(result, self, dim, keepdim);
}
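check_allany_for_meta is the structured-kernel meta step: wrap the dim, derive the output dtype with the helper above, compute the reduced shape, then set_output and propagate names. As a standalone sketch (plain C++, not ATen's actual get_reduction_shape), the shape rule for a single reduced dimension looks like this:

```cpp
#include <cstdint>
#include <vector>

// Illustrative stand-in for the shape rule: keepdim keeps the reduced
// dimension with size 1, otherwise it is dropped entirely.
std::vector<int64_t> reduction_shape(
    std::vector<int64_t> sizes, int64_t dim, bool keepdim) {
  if (keepdim) {
    sizes[dim] = 1;
  } else {
    sizes.erase(sizes.begin() + dim);
  }
  return sizes;
}
// e.g. {2, 3, 4} reduced over dim = 1:
//   keepdim = true  -> {2, 1, 4}
//   keepdim = false -> {2, 4}
```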

TORCH_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
check_all_any(*this, "all", self, dim, keepdim);
check_allany_for_meta(*this, "all", self, dim, keepdim);
}

TORCH_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
check_all_any(*this, "any", self, dim, keepdim);
check_allany_for_meta(*this, "any", self, dim, keepdim);
}
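These meta functions give the structured all.dim/any.dim overloads their output shape and dtype before any kernel runs. A usage sketch of the resulting operator behavior (illustrative only; the wrapper function is made up):

```cpp
#include <ATen/ATen.h>

// Hypothetical demo of the dim overloads whose meta functions are defined above.
void allany_dim_demo() {
  auto x = at::rand({2, 3, 4}).gt(0.5);                // bool tensor, shape {2, 3, 4}
  auto r1 = at::all(x, /*dim=*/1, /*keepdim=*/false);  // shape {2, 4},    dtype bool
  auto r2 = at::any(x, /*dim=*/1, /*keepdim=*/true);   // shape {2, 1, 4}, dtype bool
}
```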

} // namespace meta
@@ -1150,18 +1160,6 @@ Tensor norm(const Tensor& self, const Scalar& p) {
return at::native::_norm(self, p);
}

inline TensorIterator get_reduction_iter(
const Tensor& self,
const Tensor& result,
int64_t dim,
bool keepdim) {
if (self.is_cuda()) {
return meta::make_reduction(self, result, dim, keepdim, self.scalar_type());
}
return meta::make_reduction_from_out_ty(
self, result, dim, keepdim, result.scalar_type());
}

// Note [all, any : uint8 compatibility]:
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// For NumPy compatibility, `all` and `any` return
@@ -1178,40 +1176,38 @@ inline const Tensor & _all(const Tensor & result, TensorIterator & iter) {
return result;
}

Tensor all(const Tensor& self) {
TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
"all only supports CPU AND CUDA device type, got: ", self.device().type());
TORCH_CHECK(self.layout() == Layout::Strided,
"all only supports strided layout, got: ", self.layout());

// Refer [all, any : uint8 compatibility]
Tensor result;
ScalarType out_dtype;
if (self.scalar_type() == ScalarType::Byte){
result = at::empty({0}, self.options());
out_dtype = self.scalar_type();
} else {
result = at::empty({0}, self.options().dtype(kBool));
out_dtype = ScalarType::Bool;
}

inline TensorIterator get_allany_iter(
const Tensor& self,
const Tensor& result,
IntArrayRef dims,
bool keepdim) {
if (self.is_cuda()) {
// As CUDA supports dynamic type casting, we use this overload of
    // `make_reduction`, which doesn't cast the input to the result type (i.e. kBool);
    // otherwise we use the overload below, which casts the input to kBool
    // (an extra operation).
auto iter = make_reduction(
"all", result, self, {}, false, self.scalar_type(), out_dtype);
return _all(result, iter);
return meta::make_reduction(self, result, dims, keepdim, self.scalar_type());
}
auto iter =
make_reduction("all", result, self, {}, false, /*out_dtype=*/out_dtype);
return meta::make_reduction_from_out_ty(
self, result, dims, keepdim, result.scalar_type());
}
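get_allany_iter chooses between two iterator builders: on CUDA the input keeps its own dtype and the kernel casts each element as it reads it, while elsewhere the iterator is built against the result dtype, so the input is cast to kBool up front. A standalone sketch of that trade-off (plain C++, not ATen code; both function names are made up):

```cpp
#include <cstdint>
#include <vector>

// (a) Cast-on-read: convert each element while reducing; no intermediate buffer.
bool all_cast_on_read(const std::vector<uint8_t>& in) {
  for (uint8_t v : in) {
    if (!static_cast<bool>(v)) return false;
  }
  return true;
}

// (b) Cast-then-reduce: materialize a converted copy first, then reduce it.
//     The conversion pass is the "extra operation" mentioned in the comment above.
bool all_cast_then_reduce(const std::vector<uint8_t>& in) {
  std::vector<bool> tmp(in.begin(), in.end());
  for (bool v : tmp) {
    if (!v) return false;
  }
  return true;
}
```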

Tensor all(const Tensor& self) {
Tensor result;

auto out_dtype =
meta::check_allany_and_get_output_dtype("all", self, result, {}, false);
auto shape = meta::get_reduction_shape(self, {}, false);

result = at::empty(shape, self.options().dtype(out_dtype));
auto iter = get_allany_iter(self, result, {}, false);

return _all(result, iter);
}
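The full reduction Tensor all(const Tensor&) (and any() below) now reuses the same meta helpers to pick the output dtype and shape before building the iterator. A usage sketch of the behavior (illustrative only; the wrapper function is made up):

```cpp
#include <ATen/ATen.h>

// Hypothetical demo of the full (all-dims) reductions refactored above.
void allany_full_reduction_demo() {
  auto x = at::ones({3}, at::kByte);
  auto r = at::all(x);        // 0-dim result, dtype uint8, value 1
  auto s = at::any(x.eq(0));  // 0-dim result, dtype bool, value false
}
```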

TORCH_IMPL_FUNC(all_out)
(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
auto iter = get_reduction_iter(self, result, dim, keepdim);
auto iter = get_allany_iter(self, result, dim, keepdim);
auto mut_result = const_cast<Tensor&>(result);
if (!_dimreduce_return_trivial(mut_result, self, 1, dim, keepdim)) {
_all(mut_result, iter);
@@ -1229,39 +1225,21 @@ inline const Tensor & _any(const Tensor & result, TensorIterator & iter) {
}

Tensor any(const Tensor& self) {
TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
"any only supports CPU AND CUDA device type, got: ", self.device().type());
TORCH_CHECK(self.layout() == Layout::Strided || self.layout() == Layout::Sparse,
"any only supports strided AND sparse layout, got: ", self.layout());

// Refer [all, any : uint8 compatibility]
Tensor result;
ScalarType out_dtype;
if (self.scalar_type() == ScalarType::Byte){
result = at::empty({0}, self.options());
out_dtype = self.scalar_type();
} else {
result = at::empty({0}, self.options().dtype(kBool));
out_dtype = ScalarType::Bool;
}

if (self.is_cuda()) {
// As CUDA supports dynamic type casting, we use this overload of
// `make_reduction`, which doesn't cast input to the result type i.e. kBool.,
// otherwise we use the overload below which casts the input to kBool (which is
// an extra operation).
auto iter = make_reduction(
"any", result, self, {}, false, self.scalar_type(), out_dtype);
return _any(result, iter);
}
auto iter =
make_reduction("any", result, self, {}, false, /*out_dtype=*/out_dtype);
auto out_dtype =
meta::check_allany_and_get_output_dtype("any", self, result, {}, false);
auto shape = meta::get_reduction_shape(self, {}, false);

result = at::empty(shape, self.options().dtype(out_dtype));
auto iter = get_allany_iter(self, result, {}, false);

return _any(result, iter);
}

TORCH_IMPL_FUNC(any_out)
(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
auto iter = get_reduction_iter(self, result, dim, keepdim);
auto iter = get_allany_iter(self, result, dim, keepdim);
auto mut_result = const_cast<Tensor&>(result);
if (!_dimreduce_return_trivial(mut_result, self, 0, dim, keepdim)) {
_any(mut_result, iter);