Port min kernel to structured kernels.
Tracking issue: #55070

ghstack-source-id: 3334511e5378826cebaed6bffdd00bd0e71d5058
Pull Request resolved: #61450
ysiraichi committed Jul 14, 2021
1 parent 997b4e6 commit 6ed293f
Showing 7 changed files with 114 additions and 80 deletions.
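
Note: a structured kernel splits an operator into a meta function, which validates the inputs and fixes the sizes and dtypes of the outputs, and an impl function, which writes into outputs the framework has already allocated. The port below must keep the user-visible behavior of min.dim unchanged; the following minimal sketch of that contract assumes a working libtorch build and is not part of the commit.

#include <ATen/ATen.h>
#include <iostream>
#include <tuple>

int main() {
  at::Tensor t = at::randn({2, 3});

  // Functional variant: the meta function checks the input and sizes the
  // outputs; the impl function then computes into them.
  at::Tensor values, indices;
  std::tie(values, indices) = at::min(t, /*dim=*/1, /*keepdim=*/false);
  std::cout << values << "\n" << indices << "\n";

  // Out-variant: outputs may start empty; the meta function resizes them.
  at::Tensor v = at::empty({0}, t.options());
  at::Tensor i = at::empty({0}, t.options().dtype(at::kLong));
  at::min_outf(t, /*dim=*/1, /*keepdim=*/false, v, i);
  return 0;
}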
122 changes: 55 additions & 67 deletions aten/src/ATen/native/TensorCompare.cpp
@@ -60,27 +60,42 @@ TORCH_META_FUNC(isneginf) (const Tensor& self) {
   build_unary_force_boolean_op(maybe_get_output(), self);
 }
 
-TORCH_META_FUNC2(max, dim)(const Tensor& self, int64_t dim, bool keepdim) {
+void check_minmax_for_meta(
+    impl::MetaBase& meta,
+    const char* name,
+    const Tensor& self,
+    int64_t dim,
+    bool keepdim) {
   TORCH_CHECK(
       self.layout() == Layout::Strided,
-      "max(): only supports strided layout, got: ", self.layout());
+      name, ": only supports strided layout, got: ", self.layout());
   TORCH_CHECK(
       !self.is_complex(),
-      "max(): does not support complex input");
+      name, ": does not support complex input");
 
   dim = maybe_wrap_dim(dim, self.dim());
 
   DimVector sizes(self.sizes());
   if (self.numel() == 0) {
-    sizes = at::native::get_zero_numel_tensor_size(self, dim, keepdim, "max()");
+    sizes = at::native::get_zero_numel_tensor_size(self, dim, keepdim, name);
   } else {
     sizes = get_reduction_shape(self, dim, keepdim);
   }
 
-  set_output(0, sizes, self.options());
-  set_output(1, sizes, self.options().dtype(kLong));
-  namedinference::propagate_names_for_reduction(maybe_get_output(0), self, dim, keepdim);
-  namedinference::propagate_names_for_reduction(maybe_get_output(1), self, dim, keepdim);
+  meta.set_output(0, sizes, self.options());
+  meta.set_output(1, sizes, self.options().dtype(kLong));
+  namedinference::propagate_names_for_reduction(
+      meta.maybe_get_output(0), self, dim, keepdim);
+  namedinference::propagate_names_for_reduction(
+      meta.maybe_get_output(1), self, dim, keepdim);
 }
 
+TORCH_META_FUNC2(max, dim)(const Tensor& self, int64_t dim, bool keepdim) {
+  check_minmax_for_meta(*this, "max()", self, dim, keepdim);
+}
+
+TORCH_META_FUNC2(min, dim)(const Tensor& self, int64_t dim, bool keepdim) {
+  check_minmax_for_meta(*this, "min()", self, dim, keepdim);
+}
+
 } // namespace meta
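
The shared meta function above rejects non-strided layouts and complex inputs, then derives the reduction shape. The shape semantics it encodes, sketched under the same libtorch assumption (not part of the commit):

#include <ATen/ATen.h>
#include <cassert>
#include <tuple>

int main() {
  at::Tensor t = at::randn({2, 3, 4});
  at::Tensor v, i;
  // Reducing over dim drops that dimension...
  std::tie(v, i) = at::min(t, /*dim=*/1, /*keepdim=*/false);
  assert(v.sizes() == at::IntArrayRef({2, 4}));
  // ...or keeps it with size 1 when keepdim is true.
  std::tie(v, i) = at::min(t, /*dim=*/1, /*keepdim=*/true);
  assert(v.sizes() == at::IntArrayRef({2, 1, 4}));
  // The indices output is always int64, per options().dtype(kLong) above.
  assert(i.scalar_type() == at::kLong);
  return 0;
}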
@@ -420,23 +435,43 @@ std::tuple<Tensor &,Tensor &> mode_out(const Tensor& self, int64_t dim, bool keepdim,
   }
 }
 
-TORCH_IMPL_FUNC(max_out)
-(const Tensor& self,
- int64_t dim,
- bool keepdim,
- const Tensor& values,
- const Tensor& indices) {
+template <class Stub>
+void minmax_out_impl(
+    const Tensor& self,
+    int64_t dim,
+    bool keepdim,
+    const Tensor& values,
+    const Tensor& indices,
+    Stub& stub) {
   NoNamesGuard guard;
   if (self.numel() > 0) {
     if (self.numel() == 1 && self.dim() == 0) {
       values.fill_(self);
       indices.fill_(0);
     } else {
-      max_stub(self.device().type(), values, indices, self, dim, keepdim);
+      stub(self.device().type(), values, indices, self, dim, keepdim);
     }
   }
 }
 
+TORCH_IMPL_FUNC(max_out)
+(const Tensor& self,
+ int64_t dim,
+ bool keepdim,
+ const Tensor& values,
+ const Tensor& indices) {
+  minmax_out_impl(self, dim, keepdim, values, indices, max_stub);
+}
+
+TORCH_IMPL_FUNC(min_out)
+(const Tensor& self,
+ int64_t dim,
+ bool keepdim,
+ const Tensor& values,
+ const Tensor& indices) {
+  minmax_out_impl(self, dim, keepdim, values, indices, min_stub);
+}
+
 std::tuple<Tensor, Tensor> qmax(const Tensor& self, int64_t dim, bool keepdim) {
   Tensor max_indices = at::empty({0}, self.options().dtype(kLong));
   Tensor max = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
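
Factoring the shared body into minmax_out_impl, parameterized on the dispatch stub, keeps the empty- and zero-dim fast paths in one place; the stub is taken by reference because DispatchStub objects are not copyable. A self-contained sketch of the same pattern (illustrative only, with hypothetical names, not the commit's code):

#include <iostream>

// One template body, parameterized on a callable "stub", shared by two
// thin wrappers; only the final dispatch differs per op.
template <class Stub>
void reduce_impl(int input, Stub& stub) {
  if (input != 0) {  // mirrors the numel() > 0 check above
    stub(input);
  }
}

void min_stub_fn(int x) { std::cout << "min over " << x << " elements\n"; }
void max_stub_fn(int x) { std::cout << "max over " << x << " elements\n"; }

int main() {
  reduce_impl(5, min_stub_fn);
  reduce_impl(5, max_stub_fn);
  return 0;
}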
@@ -446,16 +481,12 @@ std::tuple<Tensor, Tensor> qmax(const Tensor& self, int64_t dim, bool keepdim) {
       at::_make_per_tensor_quantized_tensor(max, self.q_scale(), self.q_zero_point()), max_indices);
 }
 
-std::tuple<Tensor, Tensor> min(const Tensor& self, int64_t dim, bool keepdim) {
+std::tuple<Tensor, Tensor> qmin(const Tensor& self, int64_t dim, bool keepdim) {
   Tensor min_indices = at::empty({0}, self.options().dtype(kLong));
-  if (self.is_quantized()) {
-    Tensor min = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
-    at::native::min_out(self.int_repr(), dim, keepdim, min, min_indices);
-    return std::tuple<Tensor, Tensor>(at::_make_per_tensor_quantized_tensor(min, self.q_scale(), self.q_zero_point()), min_indices);
-  } else {
-    Tensor min = at::empty({0}, self.options());
-    return at::native::min_out(self, dim, keepdim, min, min_indices);
-  }
+  Tensor min = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
+  at::min_outf(self.int_repr(), dim, keepdim, min, min_indices);
+  return std::tuple<Tensor, Tensor>(
+      at::_make_per_tensor_quantized_tensor(min, self.q_scale(), self.q_zero_point()), min_indices);
 }
 
 static std::tuple<Tensor &, Tensor &> _aminmax_out_impl(Tensor& min, Tensor& max,
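
qmin reduces over the raw integer representation and re-wraps the values with the original quantization parameters, while the indices come back as an ordinary int64 tensor. A sketch of the user-visible effect, assuming libtorch with quantized CPU support (not part of the commit):

#include <ATen/ATen.h>
#include <tuple>

int main() {
  at::Tensor t = at::randn({2, 3});
  at::Tensor q = at::quantize_per_tensor(
      t, /*scale=*/0.1, /*zero_point=*/10, at::kQUInt8);
  at::Tensor values, indices;
  // Dispatches to qmin for quantized inputs.
  std::tie(values, indices) = at::min(q, /*dim=*/1, /*keepdim=*/false);
  // values keeps scale 0.1 and zero point 10; indices is a plain tensor.
  return 0;
}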
@@ -491,49 +522,6 @@ std::tuple<Tensor, Tensor> _aminmax(const Tensor& self, int64_t dim, bool keepdim) {
   return result;
 }
 
-static std::tuple<Tensor &,Tensor &> min_out_impl(Tensor& min, Tensor& min_indices,
-                                                  const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
-              "min only supports CPU AND CUDA device type, got: ", self.device().type());
-  TORCH_CHECK(self.layout() == Layout::Strided,
-              "min only supports strided layout, got: ", self.layout());
-  TORCH_CHECK(self.device() == min.device(),
-              "expected device ", self.device(), " but got ",
-              min.device(), " for min values output");
-  TORCH_CHECK(self.device() == min_indices.device(),
-              "expected device ", self.device(), " but got ",
-              min_indices.device(), " for indices output");
-  dim = maybe_wrap_dim(dim, self.dim());
-  if (self.numel() == 0) {
-    zero_numel_tensor_resize(min, min_indices, self, dim, keepdim, "min()");
-    return std::tie(min, min_indices);
-  }
-  else if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min")) {
-    TORCH_CHECK(!self.is_complex(), "min does not support complex inputs.");
-    AT_ASSERT(min.dim() == 0);
-    min_indices.resize_({}).fill_(0);
-    return std::forward_as_tuple(min, min_indices);
-  } else {
-    min_stub(self.device().type(), min, min_indices, self, dim, keepdim);
-    return std::tuple<Tensor &,Tensor &>{min, min_indices};
-  }
-}
-
-std::tuple<Tensor&, Tensor&> min_out(
-    const Tensor& self,
-    int64_t dim,
-    bool keepdim,
-    Tensor& min,
-    Tensor& min_indices) {
-  auto result = [&]() {
-    NoNamesGuard guard;
-    return min_out_impl(min, min_indices, self, dim, keepdim);
-  }();
-  namedinference::propagate_names_for_reduction(min, self, dim, keepdim);
-  namedinference::propagate_names_for_reduction(min_indices, self, dim, keepdim);
-  return result;
-}
-
 Tensor& clamp_out(const Tensor& self, const c10::optional<Scalar>& min, const c10::optional<Scalar>& max, Tensor& result) {
   if (min && max) {
     auto iter = TensorIterator::unary_op(result, self);
6 changes: 3 additions & 3 deletions aten/src/ATen/native/TensorCompare.h
@@ -9,11 +9,11 @@ namespace at { namespace native {
 
 using reduce_minmax_fn =
     void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool);
-using reduce_max_fn =
+using structured_reduce_minmax_fn =
    void (*)(const Tensor&, const Tensor&, const Tensor&, int64_t, bool);
 
-DECLARE_DISPATCH(reduce_max_fn, max_stub);
-DECLARE_DISPATCH(reduce_minmax_fn, min_stub);
+DECLARE_DISPATCH(structured_reduce_minmax_fn, max_stub);
+DECLARE_DISPATCH(structured_reduce_minmax_fn, min_stub);
 DECLARE_DISPATCH(reduce_minmax_fn, _aminmax_stub);
 
 using where_fn = void (*)(TensorIterator &, ScalarType);
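
The structured signature takes its outputs as const Tensor& because they arrive preallocated; const binds the handle, not the underlying storage, so kernels can still write through it. A self-contained sketch of that C++ distinction (assuming libtorch):

#include <ATen/ATen.h>

// fill_ is a const member function on the Tensor handle: it mutates the
// data the handle points at, not the handle itself.
void fill_out(const at::Tensor& out) {
  out.fill_(42);
}

int main() {
  at::Tensor t = at::zeros({3});
  fill_out(t);  // t now holds [42, 42, 42]
  return 0;
}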
7 changes: 2 additions & 5 deletions aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -88,17 +88,14 @@ static inline void compare_base_kernel(const Tensor& result1, const Tensor& result2,
 }
 
 static void min_kernel_impl(
-    Tensor& result,
-    Tensor& indice,
+    const Tensor& result,
+    const Tensor& indice,
     const Tensor& self,
     int64_t dim,
     bool keepdim) {
   auto wrap_dim = maybe_wrap_dim(dim, self.dim());
   int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);
 
-  TORCH_CHECK(result.scalar_type() == self.scalar_type() && indice.scalar_type() == kLong,
-    "Expect dtype ", self.scalar_type(), "and torch.long, but got ", result.scalar_type(), "and", indice.scalar_type());
-
   AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "min_cpu", [&] {
     compare_base_kernel<scalar_t>(result, indice, self, wrap_dim, keepdim, [&] (
         scalar_t* result_data, int64_t* indice_data,
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu
@@ -97,8 +97,8 @@ void argmin_kernel_cuda(TensorIterator& iter) {
   }
 }
 
-static void min_kernel_impl(Tensor& result, Tensor& indice, const Tensor& self, int64_t dim, bool keepdim) {
-  at::TensorIterator iter = make_reduction("min", result, indice, self, dim, keepdim, self.scalar_type(), kLong);
+static void min_kernel_impl(const Tensor& result, const Tensor& indice, const Tensor& self, int64_t dim, bool keepdim) {
+  auto iter = meta::make_reduction(self, result, indice, dim, keepdim, self.scalar_type(), kLong);
   AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(2), "min_cuda", [&]() {
     gpu_reduce_kernel<scalar_t, scalar_t>(
         iter,
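
Note on the CUDA side: the outputs now arrive preallocated, so the kernel builds its TensorIterator with the meta::make_reduction overload that takes the existing result and indice tensors, rather than the older make_reduction that resized outputs itself. iter.dtype(2) is unchanged: with two outputs, operand 2 is the input, so dispatch is still keyed on self's dtype.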
4 changes: 3 additions & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -2855,12 +2855,14 @@
 - func: nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
 
 - func: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  structured_delegate: min.dim_min
   device_check: NoCheck   # TensorIterator
   variants: function, method
   dispatch:
-    CPU, CUDA, QuantizedCPU, QuantizedCUDA: min
+    QuantizedCPU, QuantizedCUDA: qmin
 
 - func: min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
+  structured: True
   device_check: NoCheck   # TensorIterator
   dispatch:
     CPU, CUDA: min_out
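
With structured: True on the out-variant and structured_delegate: min.dim_min on the functional variant, the code generator derives the functional and method forms from the single structured out-kernel, so the hand-written CPU/CUDA entry for min.dim can be dropped; only the quantized backends keep an explicit kernel, now named qmin.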
Expand Down
1 change: 1 addition & 0 deletions test/test_autograd.py
@@ -7595,6 +7595,7 @@ def fn(v):
             nnz = 0 if empty_nnz else 5
             _test(sparse_size + dense_size, len(sparse_size), nnz, device)
 
+    @skipMeta
     @dtypes(torch.double, torch.cdouble)
     def test_sparse_backward(self, device, dtype):
         class FixedGradientFunction(Function):
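
The new @skipMeta is presumably needed because the structured port gives min.dim a meta-backend implementation, so this test would otherwise be picked up on the meta device, where sparse tensors are unsupported.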