Skip to content

Commit

Permalink
Support out overload on mps on "[inductor] Decompose boolean min/max …
Browse files Browse the repository at this point in the history
…into all/any"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
  • Loading branch information
peterbell10 committed Oct 4, 2023
2 parents c65dfec + 4d3b496 commit 588b841
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
14 changes: 14 additions & 0 deletions aten/src/ATen/native/ReduceOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1598,6 +1598,20 @@ Tensor any_dims_default(const Tensor &self, OptionalIntArrayRef dim, bool keepdi
return allany_dims_default<false>(self, dim, keepdim);
}

Tensor& all_dims_out_default(
const Tensor &self, OptionalIntArrayRef dim, bool keepdim, Tensor &result) {
auto tmp = self.all(dim, keepdim);
at::native::resize_output(result, tmp.sizes());
return result.copy_(tmp);
}

Tensor& any_dims_out_default(
const Tensor &self, OptionalIntArrayRef dim, bool keepdim, Tensor &result) {
auto tmp = self.any(dim, keepdim);
at::native::resize_output(result, tmp.sizes());
return result.copy_(tmp);
}

TORCH_IMPL_FUNC(amin_out) (const Tensor& self, IntArrayRef dim, bool keepdim, const Tensor& result) {
auto iter =
meta::make_reduction(self, result, dim, keepdim, self.scalar_type());
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@
structured: True
dispatch:
CPU, CUDA: all_dims_out
CompositeExplicitAutograd: all_dims_out_default
cpp_no_default_args: ['dim']

- func: all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
Expand Down Expand Up @@ -743,6 +744,7 @@
structured: True
dispatch:
CPU, CUDA: any_dims_out
CompositeExplicitAutograd: any_dims_out_default
cpp_no_default_args: ['dim']

- func: any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
Expand Down

0 comments on commit 588b841

Please sign in to comment.