[ATen] Support multi dim any and all reductions
ghstack-source-id: f10563e5c7e840a15c7b205d2e6466c5c00d9dd0
Pull Request resolved: #110310
peterbell10 committed Oct 11, 2023
1 parent f3aba45 commit 723a397
Showing 13 changed files with 171 additions and 70 deletions.
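In short: this commit adds `all.dims` and `any.dims` overloads that take an optional list of dimensions, so a single call can reduce over several axes at once. A rough usage sketch (shapes are my own example, not taken from the PR's tests):

    import torch

    x = torch.randn(2, 3, 4) > 0
    torch.any(x, dim=(0, 2))                # reduce dims 0 and 2 -> shape (3,)
    torch.all(x, dim=(0, 2), keepdim=True)  # keep reduced axes -> shape (1, 3, 1)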
2 changes: 2 additions & 0 deletions aten/src/ATen/functorch/BatchRulesReduceOps.cpp
@@ -458,8 +458,10 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   REDUCTION_WITH_KEEPDIM_ARG(aminmax);
   m.impl("all", all_decomp);
   REDUCTION_WITH_KEEPDIM_ARG(all.dim);
+  REDUCTION_WITH_KEEPDIM_ARG(all.dims);
   m.impl("any", any_decomp);
   REDUCTION_WITH_KEEPDIM_ARG(any.dim);
+  REDUCTION_WITH_KEEPDIM_ARG(any.dims);
   REDUCTION_WITH_KEEPDIM_ARG(argmax);
   REDUCTION_WITH_KEEPDIM_ARG(argmin);
   m.impl("bucketize.Tensor", bucketize_decomp_Tensor);
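The two new `REDUCTION_WITH_KEEPDIM_ARG` lines register batching rules, so `vmap` composes with the new overloads. A quick check along these lines (hypothetical example, not from this diff):

    import torch

    x = torch.randn(8, 2, 3) > 0
    out = torch.vmap(lambda t: torch.any(t, dim=(0, 1)))(x)
    print(out.shape)  # torch.Size([8]) -- one scalar reduction per batch element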
99 changes: 90 additions & 9 deletions aten/src/ATen/native/ReduceOps.cpp
@@ -28,9 +28,10 @@
 #include <ATen/ops/_is_any_true_native.h>
 #include <ATen/ops/_logcumsumexp.h>
 #include <ATen/ops/_logcumsumexp_native.h>
+#include <ATen/ops/_sparse_csr_sum.h>
 #include <ATen/ops/_sparse_sum.h>
 #include <ATen/ops/_sparse_sum_native.h>
-#include <ATen/ops/_sparse_csr_sum.h>
+#include <ATen/ops/_to_copy.h>
 #include <ATen/ops/add.h>
 #include <ATen/ops/all_meta.h>
 #include <ATen/ops/all_native.h>
@@ -85,6 +86,7 @@
 #include <ATen/ops/nansum_native.h>
 #include <ATen/ops/narrow.h>
 #include <ATen/ops/native_norm.h>
+#include <ATen/ops/ne.h>
 #include <ATen/ops/norm.h>
 #include <ATen/ops/norm_meta.h>
 #include <ATen/ops/norm_native.h>
@@ -184,26 +186,32 @@ static void allany_meta(
     impl::MetaBase& meta,
     const char* name,
     const Tensor& self,
-    IntArrayRef dims,
+    OptionalIntArrayRef dims,
     bool keepdim) {
   const auto& result = meta.maybe_get_output();
   check_result_is_bytebool(name, self, result);
   auto out_dtype = get_result_or_bytebool_dtype(self, result);
-  resize_reduction(meta, self, dims, keepdim, out_dtype);
+  resize_reduction(meta, self, dims, keepdim, out_dtype, /*allow_empty_dims=*/true);
 }
 
-TORCH_PRECOMPUTE_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
+TORCH_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
+  allany_meta(*this, "all", self, dim, keepdim);
+}
+
+TORCH_META_FUNC2(all, dims)(const Tensor& self, OptionalIntArrayRef dim, bool keepdim) {
   allany_meta(*this, "all", self, dim, keepdim);
-  return TORCH_PRECOMPUTE_STRUCT2(all, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
 }
 
 TORCH_META_FUNC(all)(const Tensor& self) {
   allany_meta(*this, "all", self, {}, false);
 }
 
-TORCH_PRECOMPUTE_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
+TORCH_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
   allany_meta(*this, "any", self, dim, keepdim);
-  return TORCH_PRECOMPUTE_STRUCT2(any, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
+}
+
+TORCH_META_FUNC2(any, dims)(const Tensor& self, OptionalIntArrayRef dim, bool keepdim) {
+  allany_meta(*this, "any", self, dim, keepdim);
 }
 
 TORCH_META_FUNC(any)(const Tensor& self) {
@@ -1497,7 +1505,7 @@ Tensor norm(const Tensor& self, const Scalar& p) {
 inline TensorIterator get_allany_iter(
     const Tensor& self,
     const Tensor& result,
-    IntArrayRef dims,
+    OptionalIntArrayRef dims,
     bool keepdim) {
   if (self.is_cuda()) {
     // As CUDA supports dynamic type casting, we use this overload of
@@ -1514,7 +1522,7 @@ template <int identity, typename Stub>
 inline void allany_impl(
     const Tensor& self,
     const Tensor& result,
-    IntArrayRef dims,
+    OptionalIntArrayRef dims,
     bool keepdim,
     Stub& stub) {
   if (self.numel() == 0) {
@@ -1532,6 +1540,11 @@ TORCH_IMPL_FUNC(all_out)
   allany_impl<1>(self, result, dim, keepdim, and_stub);
 }
 
+TORCH_IMPL_FUNC(all_dims_out)
+(const Tensor& self, OptionalIntArrayRef dim, bool keepdim, const Tensor& result) {
+  allany_impl<1>(self, result, dim, keepdim, and_stub);
+}
+
 TORCH_IMPL_FUNC(all_all_out)(const Tensor& self, const Tensor& result) {
   allany_impl<1>(self, result, {}, false, and_stub);
 }
@@ -1541,10 +1554,78 @@ TORCH_IMPL_FUNC(any_out)
   allany_impl<0>(self, result, dim, keepdim, or_stub);
 }
 
+TORCH_IMPL_FUNC(any_dims_out)
+(const Tensor& self, OptionalIntArrayRef dim, bool keepdim, const Tensor& result) {
+  allany_impl<0>(self, result, dim, keepdim, or_stub);
+}
+
 TORCH_IMPL_FUNC(any_all_out)(const Tensor& self, const Tensor& result) {
   allany_impl<0>(self, result, {}, false, or_stub);
 }
 
+template <bool is_all>
+Tensor allany_dims_default(const Tensor &self, OptionalIntArrayRef dim, bool keepdim) {
+  // Default implementation in terms of all-reduce or single dim reduce
+  if (!dim) {
+    Tensor out;
+    if constexpr (is_all) {
+      out = self.all();
+    } else {
+      out = self.any();
+    }
+
+    if (keepdim) {
+      DimVector out_shape(self.dim(), 1);
+      return out.expand(out_shape);
+    }
+    return out;
+  }
+
+  if (dim->size() == 0) {
+    if (self.scalar_type() == kByte) {
+      // Convert to a 1 or 0 mask
+      auto out = at::empty_like(self);
+      return at::ne_outf(self, 0, out);
+    } else {
+      return at::_to_copy(self, kBool);
+    }
+  }
+
+  Tensor out = self;
+  for (auto d : *dim) {
+    if constexpr (is_all) {
+      out = out.all(d, /*keepdim=*/true);
+    } else {
+      out = out.any(d, /*keepdim=*/true);
+    }
+  }
+  return keepdim ? out : out.squeeze(*dim);
+}
+
+Tensor all_dims_default(const Tensor &self, OptionalIntArrayRef dim, bool keepdim) {
+  return allany_dims_default<true>(self, dim, keepdim);
+}
+
+Tensor any_dims_default(const Tensor &self, OptionalIntArrayRef dim, bool keepdim) {
+  return allany_dims_default<false>(self, dim, keepdim);
+}
+
+Tensor& all_dims_out_default(
+    const Tensor &self, OptionalIntArrayRef dim, bool keepdim, Tensor &result) {
+  TORCH_CHECK(self.device() == result.device(), "all: Output must be on the same device as input");
+  auto tmp = self.all(dim, keepdim);
+  at::native::resize_output(result, tmp.sizes());
+  return result.copy_(tmp);
+}
+
+Tensor& any_dims_out_default(
+    const Tensor &self, OptionalIntArrayRef dim, bool keepdim, Tensor &result) {
+  TORCH_CHECK(self.device() == result.device(), "any: Output must be on the same device as input");
+  auto tmp = self.any(dim, keepdim);
+  at::native::resize_output(result, tmp.sizes());
+  return result.copy_(tmp);
+}
+
 TORCH_IMPL_FUNC(amin_out) (const Tensor& self, IntArrayRef dim, bool keepdim, const Tensor& result) {
   auto iter =
       meta::make_reduction(self, result, dim, keepdim, self.scalar_type());
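The `allany_dims_default` fallback above is the interesting piece: backends without a dedicated multi-dim kernel get a decomposition into existing ops. A minimal Python sketch of the same logic (illustrative only, not the shipped implementation):

    import torch

    def any_dims_default(self, dim=None, keepdim=False):
        # dim=None: full reduction, optionally re-expanded to keep the rank.
        if dim is None:
            out = self.any()
            return out.expand((1,) * self.dim()) if keepdim else out
        # dim=(): no reduction, just conversion to a boolean-style mask.
        if len(dim) == 0:
            if self.dtype == torch.uint8:
                return (self != 0).to(torch.uint8)  # 1/0 mask, mirrors at::ne_outf
            return self.to(torch.bool)
        # Otherwise reduce one dimension at a time with keepdim=True so the
        # remaining dim indices stay valid, then squeeze at the end.
        out = self
        for d in dim:
            out = out.any(d, keepdim=True)
        return out if keepdim else out.squeeze(dim)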
14 changes: 8 additions & 6 deletions aten/src/ATen/native/ReduceOpsUtils.h
@@ -120,11 +120,11 @@ static DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
   }
 }
 
-static DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim) {
+static DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim, bool allow_empty_dims=false) {
   DimMask mask;
   if (opt_dims.has_value()) {
     auto dims = opt_dims.value();
-    if (dims.empty()) {
+    if (dims.empty() && !allow_empty_dims) {
       mask = DimMask().flip();
     } else {
       mask = at::dim_list_to_bitset(dims, ndim);
@@ -351,8 +351,9 @@ namespace at::meta {
 static C10_UNUSED DimVector get_reduction_shape(
     const Tensor& self,
     IntArrayRef dims,
-    bool keepdim) {
-  auto mask = native::make_dim_mask(dims, self.dim());
+    bool keepdim,
+    bool allow_empty_dims=false) {
+  auto mask = native::make_dim_mask(dims, self.dim(), allow_empty_dims);
   return native::shape_from_dim_mask(self, mask, keepdim);
 }
 
@@ -361,10 +362,11 @@ static void resize_reduction(
     const Tensor& self,
     OptionalIntArrayRef opt_dims,
     bool keepdim,
-    ScalarType out_dtype) {
+    ScalarType out_dtype,
+    bool allow_empty_dims=false) {
   DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim());
   maybe_wrap_dims(dims_, self.dim());
-  auto shape = get_reduction_shape(self, dims_, keepdim);
+  auto shape = get_reduction_shape(self, dims_, keepdim, allow_empty_dims);
   meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
   namedinference::propagate_names_for_reduction(
       meta.maybe_get_output(), self, dims_, keepdim);
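The new `allow_empty_dims` flag exists because `all.dims`/`any.dims` give an empty dim list a different meaning than older reductions: instead of flipping the mask to "reduce everything", an empty list reduces nothing and the op degenerates to an elementwise boolean conversion. If I read the meta function and the default decomposition correctly, the semantics are:

    import torch

    x = torch.randn(2, 3)
    torch.all(x, dim=()).shape  # torch.Size([2, 3]) -- no reduction, bool output
    torch.all(x).shape          # torch.Size([])     -- dim=None still fully reduces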
37 changes: 33 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -681,15 +681,29 @@
   structured_delegate: all.out
   variants: function, method
 
+- func: all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: all.dims_out
+  variants: function, method
+  cpp_no_default_args: ['dim']
+  dispatch:
+    CompositeExplicitAutograd: all_dims_default
+
 - func: all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck   # TensorIterator
   structured: True
-  precomputed:
-  - dim -> int dim
   dispatch:
     CPU, CUDA: all_out
     MPS: all_out_mps
 
+- func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  dispatch:
+    CPU, CUDA: all_dims_out
+    CompositeExplicitAutograd: all_dims_out_default
+  cpp_no_default_args: ['dim']
+
 - func: all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
   device_check: NoCheck   # TensorIterator
   variants: function, method
@@ -709,15 +723,30 @@
   variants: function, method
   tags: core
 
+- func: any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: any.dims_out
+  variants: function, method
+  cpp_no_default_args: ['dim']
+  tags: core
+  dispatch:
+    CompositeExplicitAutograd: any_dims_default
+
 - func: any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck   # TensorIterator
   structured: True
-  precomputed:
-  - dim -> int dim
   dispatch:
     CPU, CUDA: any_out
     MPS: any_out_mps
 
+- func: any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  dispatch:
+    CPU, CUDA: any_dims_out
+    CompositeExplicitAutograd: any_dims_out_default
+  cpp_no_default_args: ['dim']
+
 - func: any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
   device_check: NoCheck   # TensorIterator
   variants: function, method
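Dispatch-wise, CPU and CUDA get the structured `all_dims_out`/`any_dims_out` kernels, and every other backend falls through to the `CompositeExplicitAutograd` defaults from ReduceOps.cpp. Both paths resize a mismatched `out=` tensor; a hedged sketch of what the new schema accepts from Python (example mine):

    import torch

    x = torch.randn(4, 5) > 0
    out = torch.empty(0, dtype=torch.bool)  # deliberately the wrong shape
    torch.any(x, dim=(0, 1), keepdim=True, out=out)
    print(out.shape)  # torch.Size([1, 1]) -- resized to the reduction shape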
2 changes: 2 additions & 0 deletions test/expect/HasDecompTest.test_aten_core_operators.expect
@@ -49,6 +49,8 @@ aten::amin.out
 aten::any
 aten::any.all_out
 aten::any.dim
+aten::any.dims
+aten::any.dims_out
 aten::any.out
 aten::arange.out
 aten::arange.start_out
10 changes: 10 additions & 0 deletions test/onnx/test_fx_op_consistency.py
@@ -244,6 +244,11 @@ def skip_torchlib_forward_compatibility(
         "addmm", dtypes=onnx_test_common.BOOL_TYPES,
         reason=onnx_test_common.reason_onnx_does_not_support("Addmm")
     ),
+    xfail_torchlib_forward_compatibility(
+        "all",
+        reason=onnx_test_common.reason_onnx_script_does_not_support("aten.all.dims"),
+        github_issue="https://github.com/microsoft/onnxscript/pull/1084"
+    ),
     xfail(
         "allclose", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES + onnx_test_common.FLOAT_TYPES,
         reason=onnx_test_common.reason_dynamo_does_not_support("Allclose")
@@ -257,6 +262,11 @@ def skip_torchlib_forward_compatibility(
         "amin", dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES),
         reason=onnx_test_common.reason_dynamo_does_not_support("ReduceMin", "bool, int16")
     ),
+    xfail_torchlib_forward_compatibility(
+        "any",
+        reason=onnx_test_common.reason_onnx_script_does_not_support("aten.any.dims"),
+        github_issue="https://github.com/microsoft/onnxscript/pull/1084"
+    ),
     xfail(
         "arange",
         dtypes=(torch.uint8,),
7 changes: 2 additions & 5 deletions test/onnx/test_pytorch_onnx_onnxruntime_cuda.py
@@ -8,17 +8,14 @@
 import parameterized
 
 import torch
+from onnx_test_common import MAX_ONNX_OPSET_VERSION, MIN_ONNX_OPSET_VERSION
 from pytorch_test_common import (
     skipIfNoBFloat16Cuda,
     skipIfNoCuda,
     skipIfUnsupportedMinOpsetVersion,
     skipScriptTest,
 )
-from test_pytorch_onnx_onnxruntime import (
-    _parameterized_class_attrs_and_values,
-    MAX_ONNX_OPSET_VERSION,
-    MIN_ONNX_OPSET_VERSION,
-)
+from test_pytorch_onnx_onnxruntime import _parameterized_class_attrs_and_values
 from torch.cuda.amp import autocast
 from torch.testing._internal import common_utils
 
8 changes: 5 additions & 3 deletions test/test_mps.py
@@ -168,7 +168,7 @@ def mps_ops_grad_modifier(ops):
         'cumprod': [torch.float32],
     }
 
-    XPASSLIST_GRAD = {
+    SKIPLIST_GRAD = {
         'nn.functional.pairwise_distance': [torch.float16],
         # failed assertion `destination datatype must be fp32'
         'nn.functional.conv1d': [torch.float16],
@@ -177,6 +177,8 @@ def mps_ops_grad_modifier(ops):
         'nn.functional.conv_transpose1d': [torch.float16],
         'nn.functional.conv_transpose2d': [torch.float16],
         'nn.functional.conv_transpose3d': [torch.float16],
+        # Segfaults
+        'all': [torch.float16],
     }
 
     MACOS_13_3_XFAILLIST_GRAD = {
@@ -198,10 +200,10 @@ def addDecorator(op, d) -> None:
                     unittest.expectedFailure,
                     dtypes=XFAILLIST_GRAD[key]))
 
-            if key in XPASSLIST_GRAD:
+            if key in SKIPLIST_GRAD:
                 addDecorator(op, DecorateInfo(
                     unittest.skip,
-                    dtypes=XPASSLIST_GRAD[key]))
+                    dtypes=SKIPLIST_GRAD[key]))
 
             if key in MACOS_12_3_XFAILLIST_GRAD and (not torch.backends.mps.is_macos13_or_newer()):
                 addDecorator(op, DecorateInfo(
6 changes: 6 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -295,6 +295,9 @@
 - name: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
   output_differentiability: [False]
 
+- name: any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
+  output_differentiability: [False]
+
 - name: _is_all_true(Tensor self) -> Tensor
   self: non_differentiable
 
@@ -307,6 +310,9 @@
 - name: all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
   output_differentiability: [False]
 
+- name: all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
+  output_differentiability: [False]
+
 - name: acosh(Tensor self) -> Tensor
   # Save one rsqrt in the real case by using that for x real and positive sqrt(x*y) = sqrt(x)*sqrt(y) (not true in the complex case)
   self: "self.is_complex() ? grad * ((self + 1).rsqrt() * (self - 1).rsqrt()).conj() : grad * (self * self - 1).rsqrt()"
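As with the existing `any.dim`/`all.dim` entries, the new overloads are marked non-differentiable, which follows from their boolean outputs. An illustrative check (my example, not from the PR):

    import torch

    x = torch.randn(2, 3, requires_grad=True)
    mask = torch.any(x > 0, dim=(0,))
    print(mask.requires_grad)  # False -- boolean results never carry gradients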
[Diffs for the remaining 4 of the 13 changed files are not shown here.]
