diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 620908b5b79b..c706f2831e7d 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -92,13 +92,7 @@ ScalarType get_result_or_bytebool_dtype(const Tensor& self, const Tensor& result
   }
 }
 
-void check_all_any(const char* name, const Tensor& self, const Tensor& result) {
-  // Refer [all, any : uint8 compatibility]
-  TORCH_CHECK(
-      self.layout() == Layout::Strided,
-      name, " only supports strided layout, got: ",
-      self.layout());
-
+void check_result_is_bytebool(const char* name, const Tensor& self, const Tensor& result) {
   if (result.defined()) {
     // Refer [all, any : uint8 compatibility]
     TORCH_CHECK(
@@ -109,20 +103,36 @@ void check_all_any(const char* name, const Tensor& self, const Tensor& result) {
   }
 }
 
+void allany_meta(
+    impl::MetaBase& meta,
+    const char* name,
+    const Tensor& self,
+    IntArrayRef dims,
+    bool keepdim) {
+  const auto& result = meta.maybe_get_output();
+  check_result_is_bytebool(name, self, result);
+  auto out_dtype = get_result_or_bytebool_dtype(self, result);
+  resize_reduction(meta, self, dims, keepdim, out_dtype);
+}
+
 TORCH_PRECOMPUTE_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
-  check_all_any("all", self, maybe_get_output());
-  auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
-  resize_reduction(*this, self, dim, keepdim, out_dtype);
+  allany_meta(*this, "all", self, dim, keepdim);
   return TORCH_PRECOMPUTE_STRUCT2(all, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
 }
 
+TORCH_META_FUNC(all)(const Tensor& self) {
+  allany_meta(*this, "all", self, {}, false);
+}
+
 TORCH_PRECOMPUTE_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
-  check_all_any("any", self, maybe_get_output());
-  auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
-  resize_reduction(*this, self, dim, keepdim, out_dtype);
+  allany_meta(*this, "any", self, dim, keepdim);
   return TORCH_PRECOMPUTE_STRUCT2(any, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
 }
 
+TORCH_META_FUNC(any)(const Tensor& self) {
+  allany_meta(*this, "any", self, {}, false);
+}
+
 void check_argmax_argmin(
     const char* name,
     const Tensor& self,
@@ -1325,26 +1335,15 @@ inline TensorIterator get_allany_iter(
       self, result, dims, keepdim, result.scalar_type());
 }
 
-Tensor all(const Tensor& self) {
-  Tensor result;
-
-  meta::check_all_any("all", self, result);
-  auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
-  auto shape = meta::get_reduction_shape(self, {}, false);
-
-  result = at::empty(shape, self.options().dtype(out_dtype));
-  auto iter = get_allany_iter(self, result, {}, false);
-
-  return _all(result, iter);
-}
-
 TORCH_IMPL_FUNC(all_out)
 (const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
   auto iter = get_allany_iter(self, result, dim, keepdim);
-  auto mut_result = const_cast<Tensor&>(result);
-  if (!_dimreduce_return_trivial(mut_result, self, 1, dim, keepdim)) {
-    _all(mut_result, iter);
-  }
+  _all(result, iter);
+}
+
+TORCH_IMPL_FUNC(all_full_out)(const Tensor& self, const Tensor& result) {
+  auto iter = get_allany_iter(self, result, {}, false);
+  _all(result, iter);
 }
 
 inline const Tensor & _any(const Tensor & result, TensorIterator & iter) {
@@ -1357,29 +1356,18 @@ inline const Tensor & _any(const Tensor & result, TensorIterator & iter) {
   return result;
 }
 
-Tensor any(const Tensor& self) {
-  Tensor result;
-
-  meta::check_all_any("any", self, result);
-  auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
-  auto shape = meta::get_reduction_shape(self, {}, false);
-
-  result = at::empty(shape, self.options().dtype(out_dtype));
-  auto iter = get_allany_iter(self, result, {}, false);
-
-  return _any(result, iter);
-}
-
 TORCH_IMPL_FUNC(any_out)
 (const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
   auto iter = get_allany_iter(self, result, dim, keepdim);
-  auto mut_result = const_cast<Tensor&>(result);
-  if (!_dimreduce_return_trivial(mut_result, self, 0, dim, keepdim)) {
-    _any(mut_result, iter);
-  }
+  _any(result, iter);
+}
+
+TORCH_IMPL_FUNC(any_full_out)(const Tensor& self, const Tensor& result) {
+  auto iter = get_allany_iter(self, result, {}, false);
+  _any(result, iter);
 }
 
 Tensor &amin_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tensor& result) {
diff --git a/aten/src/ATen/native/ReduceOpsUtils.h b/aten/src/ATen/native/ReduceOpsUtils.h
index e1777184b90d..a357174d4b56 100644
--- a/aten/src/ATen/native/ReduceOpsUtils.h
+++ b/aten/src/ATen/native/ReduceOpsUtils.h
@@ -51,17 +51,16 @@ static inline Tensor restride_dim(
   return src.as_strided(replacement_shape, strides);
 }
 
-inline Tensor &_dimreduce_setup(Tensor &result, const Tensor &self,
+inline void _dimreduce_setup(const Tensor &result, const Tensor &self,
                                 int64_t dim) {
   IntArrayRef self_sizes = self.sizes();
   std::vector<int64_t> result_sizes;
   result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
   result_sizes[dim] = 1;
   result.resize_(result_sizes);
-  return result;
 }
 
-inline bool _dimreduce_return_trivial(Tensor &result, const Tensor &self,
+inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self,
                                       const Scalar& ident, int64_t dim, bool keepdim) {
   if (self.numel() == 1 && self.ndimension() == 0) {
     result.resize_({});
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 3a1f75c588a8..a16d1ecddb98 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -7360,17 +7360,28 @@
 
 - func: all(Tensor self) -> Tensor
   device_check: NoCheck   # TensorIterator
+  structured_delegate: all.full_out
   variants: method, function
+
+- func: all.full_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  structured: True
   dispatch:
-    CPU, CUDA: all
+    CPU, CUDA: all_full_out
 
 - func: any(Tensor self) -> Tensor
   device_check: NoCheck   # TensorIterator
+  structured_delegate: any.full_out
   variants: method, function
   dispatch:
-    CPU, CUDA: any
     SparseCPU, SparseCUDA: any_sparse
 
+- func: any.full_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  structured: True
+  dispatch:
+    CPU, CUDA: any_full_out
+
 - func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck   # TensorIterator
   structured: True
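
For reference, a minimal sketch of the user-visible behavior this patch preserves (the file name and build command below are illustrative, not part of the patch; it assumes a libtorch build). The full reductions `at::all(self)` / `at::any(self)` now dispatch through the structured `all_full_out` / `any_full_out` kernels, and the `[all, any : uint8 compatibility]` rule still applies: uint8 inputs keep a uint8 result dtype, everything else reduces to bool.

```cpp
// check_allany.cpp -- illustrative sketch only, not part of this patch.
// Build against libtorch, e.g.:
//   g++ check_allany.cpp -I<libtorch>/include -L<libtorch>/lib -ltorch_cpu -lc10
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // uint8 input: per the [all, any : uint8 compatibility] note, the full
  // reduction keeps the legacy uint8 result dtype.
  at::Tensor u8 = at::arange(3).to(at::kByte);  // values [0, 1, 2]
  at::Tensor a = at::all(u8);                   // 0-dim uint8 tensor, value 0
  at::Tensor b = at::any(u8);                   // 0-dim uint8 tensor, value 1

  // bool input: the result dtype is bool.
  at::Tensor c = at::any(u8.to(at::kBool));     // 0-dim bool tensor, value true

  std::cout << a.scalar_type() << " " << int(a.item<uint8_t>()) << "\n";
  std::cout << b.scalar_type() << " " << int(b.item<uint8_t>()) << "\n";
  std::cout << c.scalar_type() << " " << c.item<bool>() << "\n";
  return 0;
}
```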