Port min kernel to structured kernels.
Tracking issue: #55070

ghstack-source-id: aa16ba15f10264fba0b9d3f6d5d857a61d821745
Pull Request resolved: #61450
ysiraichi committed Jul 9, 2021
1 parent b3e76bf commit 2b4ecd6
Showing 2 changed files with 59 additions and 67 deletions.
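Background (editorial, not part of the commit): in PyTorch's structured-kernels design, an operator is split into a meta function that validates inputs, computes output shapes, and allocates outputs, and an impl function that computes into those pre-sized outputs. This commit moves min.dim onto that split, sharing its meta logic with max.dim. Below is a toy, ATen-free sketch of the split under stated assumptions — MinResult, min_meta, and min_impl are invented names for illustration; the real code uses the TORCH_META_FUNC2/TORCH_IMPL_FUNC macros visible in the diff.

// Toy illustration (not PyTorch code): the structured-kernel split into a
// "meta" step that validates and allocates, and an "impl" step that computes.
#include <cstdint>
#include <stdexcept>
#include <vector>

struct MinResult {            // hypothetical stand-in for (values, indices)
  std::vector<float> values;
  std::vector<int64_t> indices;
};

// "meta": shape checks and output allocation only, no computation — the role
// TORCH_META_FUNC2(min, dim) plays via check_minmax_for_meta in the diff.
MinResult min_meta(const std::vector<float>& self) {
  if (self.empty()) throw std::invalid_argument("min(): empty input");
  return MinResult{std::vector<float>(1), std::vector<int64_t>(1)};
}

// "impl": compute into outputs the meta step already sized — the role of
// TORCH_IMPL_FUNC(min_out), which delegates to a shared minmax_out_impl.
void min_impl(const std::vector<float>& self, MinResult& out) {
  out.values[0] = self[0];
  out.indices[0] = 0;
  for (int64_t i = 1; i < static_cast<int64_t>(self.size()); ++i) {
    if (self[i] < out.values[0]) {
      out.values[0] = self[i];
      out.indices[0] = i;
    }
  }
}

int main() {
  std::vector<float> t{3.f, 1.f, 2.f};
  MinResult out = min_meta(t); // validate + allocate
  min_impl(t, out);            // fill pre-sized outputs
  return out.indices[0] == 1 ? 0 : 1; // min is 1.f at index 1
}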
122 changes: 56 additions & 66 deletions aten/src/ATen/native/TensorCompare.cpp
@@ -60,27 +60,42 @@ TORCH_META_FUNC(isneginf) (const Tensor& self) {
   build_unary_force_boolean_op(maybe_get_output(), self);
 }
 
-TORCH_META_FUNC2(max, dim)(const Tensor& self, int64_t dim, bool keepdim) {
+void check_minmax_for_meta(
+    impl::MetaBase& meta,
+    const char* name,
+    const Tensor& self,
+    int64_t dim,
+    bool keepdim) {
   TORCH_CHECK(
       self.layout() == Layout::Strided,
-      "max(): only supports strided layout, got: ", self.layout());
+      name, ": only supports strided layout, got: ", self.layout());
   TORCH_CHECK(
       !self.is_complex(),
-      "max(): does not support complex input");
+      name, ": does not support complex input");
 
   dim = maybe_wrap_dim(dim, self.dim());
 
   DimVector sizes(self.sizes());
   if (self.numel() == 0) {
-    sizes = at::native::get_zero_numel_tensor_size(self, dim, keepdim, "max()");
+    sizes = at::native::get_zero_numel_tensor_size(self, dim, keepdim, name);
   } else {
     sizes = get_reduction_shape(self, dim, keepdim);
   }
 
-  set_output(0, sizes, self.options());
-  set_output(1, sizes, self.options().dtype(kLong));
-  namedinference::propagate_names_for_reduction(maybe_get_output(0), self, dim, keepdim);
-  namedinference::propagate_names_for_reduction(maybe_get_output(1), self, dim, keepdim);
+  meta.set_output(0, sizes, self.options());
+  meta.set_output(1, sizes, self.options().dtype(kLong));
+  namedinference::propagate_names_for_reduction(
+      meta.maybe_get_output(0), self, dim, keepdim);
+  namedinference::propagate_names_for_reduction(
+      meta.maybe_get_output(1), self, dim, keepdim);
 }
+
+TORCH_META_FUNC2(max, dim)(const Tensor& self, int64_t dim, bool keepdim) {
+  check_minmax_for_meta(*this, "max()", self, dim, keepdim);
+}
+
+TORCH_META_FUNC2(min, dim)(const Tensor& self, int64_t dim, bool keepdim) {
+  check_minmax_for_meta(*this, "min()", self, dim, keepdim);
+}
 
 } // namespace meta
@@ -420,12 +435,14 @@ std::tuple<Tensor &,Tensor &> mode_out(const Tensor& self, int64_t dim, bool kee
   }
 }
 
-TORCH_IMPL_FUNC(max_out)
-(const Tensor& self,
- int64_t dim,
- bool keepdim,
- const Tensor& values,
- const Tensor& indices) {
+template <class Stub>
+void minmax_out_impl(
+    const Tensor& self,
+    int64_t dim,
+    bool keepdim,
+    const Tensor& values,
+    const Tensor& indices,
+    Stub& stub) {
   NoNamesGuard guard;
   if (self.numel() > 0) {
     if (self.numel() == 1 && self.dim() == 0) {
@@ -434,11 +451,29 @@ TORCH_IMPL_FUNC(max_out)
     } else {
       auto& values_mut = const_cast<Tensor&>(values);
       auto& indices_mut = const_cast<Tensor&>(indices);
-      max_stub(self.device().type(), values_mut, indices_mut, self, dim, keepdim);
+      stub(self.device().type(), values_mut, indices_mut, self, dim, keepdim);
     }
   }
 }
 
+TORCH_IMPL_FUNC(max_out)
+(const Tensor& self,
+ int64_t dim,
+ bool keepdim,
+ const Tensor& values,
+ const Tensor& indices) {
+  minmax_out_impl(self, dim, keepdim, values, indices, max_stub);
+}
+
+TORCH_IMPL_FUNC(min_out)
+(const Tensor& self,
+ int64_t dim,
+ bool keepdim,
+ const Tensor& values,
+ const Tensor& indices) {
+  minmax_out_impl(self, dim, keepdim, values, indices, min_stub);
+}
+
 std::tuple<Tensor, Tensor> max(const Tensor& self, int64_t dim, bool keepdim) {
   // CUDA and CPU dispatch keys are handled by the structured kernel implementation.
   TORCH_INTERNAL_ASSERT(self.is_quantized());
@@ -451,15 +486,13 @@ std::tuple<Tensor, Tensor> max(const Tensor& self, int64_t dim, bool keepdim) {
 }
 
 std::tuple<Tensor, Tensor> min(const Tensor& self, int64_t dim, bool keepdim) {
+  // CUDA and CPU dispatch keys are handled by the structured kernel implementation.
+  TORCH_INTERNAL_ASSERT(self.is_quantized());
   Tensor min_indices = at::empty({0}, self.options().dtype(kLong));
-  if (self.is_quantized()) {
-    Tensor min = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
-    at::native::min_out(self.int_repr(), dim, keepdim, min, min_indices);
-    return std::tuple<Tensor, Tensor>(at::_make_per_tensor_quantized_tensor(min, self.q_scale(), self.q_zero_point()), min_indices);
-  } else {
-    Tensor min = at::empty({0}, self.options());
-    return at::native::min_out(self, dim, keepdim, min, min_indices);
-  }
+  Tensor min = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
+  at::min_outf(self.int_repr(), dim, keepdim, min, min_indices);
+  return std::tuple<Tensor, Tensor>(
+      at::_make_per_tensor_quantized_tensor(min, self.q_scale(), self.q_zero_point()), min_indices);
 }
 
 static std::tuple<Tensor &, Tensor &> _aminmax_out_impl(Tensor& min, Tensor& max,
@@ -495,49 +528,6 @@ std::tuple<Tensor, Tensor> _aminmax(const Tensor& self, int64_t dim, bool keepdi
   return result;
 }
 
-static std::tuple<Tensor &,Tensor &> min_out_impl(Tensor& min, Tensor& min_indices,
-                                                  const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
-              "min only supports CPU AND CUDA device type, got: ", self.device().type());
-  TORCH_CHECK(self.layout() == Layout::Strided,
-              "min only supports strided layout, got: ", self.layout());
-  TORCH_CHECK(self.device() == min.device(),
-              "expected device ", self.device(), " but got ",
-              min.device(), " for min values output");
-  TORCH_CHECK(self.device() == min_indices.device(),
-              "expected device ", self.device(), " but got ",
-              min_indices.device(), " for indices output");
-  dim = maybe_wrap_dim(dim, self.dim());
-  if (self.numel() == 0) {
-    zero_numel_tensor_resize(min, min_indices, self, dim, keepdim, "min()");
-    return std::tie(min, min_indices);
-  }
-  else if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min")) {
-    TORCH_CHECK(!self.is_complex(), "min does not support complex inputs.");
-    AT_ASSERT(min.dim() == 0);
-    min_indices.resize_({}).fill_(0);
-    return std::forward_as_tuple(min, min_indices);
-  } else {
-    min_stub(self.device().type(), min, min_indices, self, dim, keepdim);
-    return std::tuple<Tensor &,Tensor &>{min, min_indices};
-  }
-}
-
-std::tuple<Tensor&, Tensor&> min_out(
-    const Tensor& self,
-    int64_t dim,
-    bool keepdim,
-    Tensor& min,
-    Tensor& min_indices) {
-  auto result = [&]() {
-    NoNamesGuard guard;
-    return min_out_impl(min, min_indices, self, dim, keepdim);
-  }();
-  namedinference::propagate_names_for_reduction(min, self, dim, keepdim);
-  namedinference::propagate_names_for_reduction(min_indices, self, dim, keepdim);
-  return result;
-}
-
 Tensor& clamp_out(const Tensor& self, const c10::optional<Scalar>& min, const c10::optional<Scalar>& max, Tensor& result) {
   if (min && max) {
     auto iter = TensorIterator::unary_op(result, self);
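For orientation, a caller-side sketch (editorial, not part of the commit) of the out-variant that the quantized path above now routes through; it assumes ATen headers and uses at::min_outf with the same argument order as in the diff (outputs last):

#include <ATen/ATen.h>

int main() {
  at::Tensor t = at::rand({3, 4});
  // Outputs may start empty: the structured meta function performs the
  // layout/complex checks and resizes them before the kernel runs.
  at::Tensor values = at::empty({0}, t.options());
  at::Tensor indices = at::empty({0}, t.options().dtype(at::kLong));
  at::min_outf(t, /*dim=*/1, /*keepdim=*/false, values, indices);
  return values.numel() == 3 ? 0 : 1; // one minimum per row of the 3x4 input
}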
4 changes: 3 additions & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -2847,12 +2847,14 @@
 - func: nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
 
 - func: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  structured_delegate: min.dim_min
   device_check: NoCheck # TensorIterator
   variants: function, method
   dispatch:
-    CPU, CUDA, QuantizedCPU, QuantizedCUDA: min
+    QuantizedCPU, QuantizedCUDA: min
 
 - func: min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
+  structured: True
   device_check: NoCheck # TensorIterator
   dispatch:
    CPU, CUDA: min_out
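Reading the two new YAML keys (editorial note, following the documented native_functions.yaml conventions): structured_delegate: min.dim_min tells the code generator to implement the functional min.dim in terms of the structured out variant, which is why its CPU and CUDA dispatch entries disappear and only the hand-written quantized kernels remain; structured: True marks min.dim_min itself as structured, so its CPU/CUDA entry points are generated around the TORCH_META_FUNC2(min, dim) / TORCH_IMPL_FUNC(min_out) pair in TensorCompare.cpp.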

