Port min kernel to structured kernels. #61450

Closed · wants to merge 15 commits
Changes from 5 commits
122 changes: 56 additions & 66 deletions aten/src/ATen/native/TensorCompare.cpp
@@ -60,27 +60,42 @@ TORCH_META_FUNC(isneginf) (const Tensor& self) {
build_unary_force_boolean_op(maybe_get_output(), self);
}

TORCH_META_FUNC2(max, dim)(const Tensor& self, int64_t dim, bool keepdim) {
void check_minmax_for_meta(
impl::MetaBase& meta,
const char* name,
const Tensor& self,
int64_t dim,
bool keepdim) {
TORCH_CHECK(
self.layout() == Layout::Strided,
"max(): only supports strided layout, got: ", self.layout());
name, ": only supports strided layout, got: ", self.layout());
TORCH_CHECK(
!self.is_complex(),
"max(): does not support complex input");
name, ": does not support complex input");

dim = maybe_wrap_dim(dim, self.dim());

DimVector sizes(self.sizes());
if (self.numel() == 0) {
sizes = at::native::get_zero_numel_tensor_size(self, dim, keepdim, "max()");
sizes = at::native::get_zero_numel_tensor_size(self, dim, keepdim, name);
} else {
sizes = get_reduction_shape(self, dim, keepdim);
}

set_output(0, sizes, self.options());
set_output(1, sizes, self.options().dtype(kLong));
namedinference::propagate_names_for_reduction(maybe_get_output(0), self, dim, keepdim);
namedinference::propagate_names_for_reduction(maybe_get_output(1), self, dim, keepdim);
meta.set_output(0, sizes, self.options());
meta.set_output(1, sizes, self.options().dtype(kLong));
namedinference::propagate_names_for_reduction(
meta.maybe_get_output(0), self, dim, keepdim);
namedinference::propagate_names_for_reduction(
meta.maybe_get_output(1), self, dim, keepdim);
}

TORCH_META_FUNC2(max, dim)(const Tensor& self, int64_t dim, bool keepdim) {
check_minmax_for_meta(*this, "max()", self, dim, keepdim);
}

TORCH_META_FUNC2(min, dim)(const Tensor& self, int64_t dim, bool keepdim) {
check_minmax_for_meta(*this, "min()", self, dim, keepdim);
}

} // namespace meta
@@ -420,12 +435,14 @@ std::tuple<Tensor &,Tensor &> mode_out(const Tensor& self, int64_t dim, bool kee
}
}

TORCH_IMPL_FUNC(max_out)
(const Tensor& self,
int64_t dim,
bool keepdim,
const Tensor& values,
const Tensor& indices) {
template <class Stub>
void minmax_out_impl(
const Tensor& self,
int64_t dim,
bool keepdim,
const Tensor& values,
const Tensor& indices,
Stub& stub) {
NoNamesGuard guard;
if (self.numel() > 0) {
if (self.numel() == 1 && self.dim() == 0) {
@@ -434,11 +451,29 @@ TORCH_IMPL_FUNC(max_out)
} else {
auto& values_mut = const_cast<Tensor&>(values);
auto& indices_mut = const_cast<Tensor&>(indices);
max_stub(self.device().type(), values_mut, indices_mut, self, dim, keepdim);
stub(self.device().type(), values_mut, indices_mut, self, dim, keepdim);
}
}
}

TORCH_IMPL_FUNC(max_out)
(const Tensor& self,
int64_t dim,
bool keepdim,
const Tensor& values,
const Tensor& indices) {
minmax_out_impl(self, dim, keepdim, values, indices, max_stub);
}

TORCH_IMPL_FUNC(min_out)
(const Tensor& self,
int64_t dim,
bool keepdim,
const Tensor& values,
const Tensor& indices) {
minmax_out_impl(self, dim, keepdim, values, indices, min_stub);
}

std::tuple<Tensor, Tensor> max(const Tensor& self, int64_t dim, bool keepdim) {
// CUDA and CPU dispatch keys are handled by the structured kernel implementation.
TORCH_INTERNAL_ASSERT(self.is_quantized());
@@ -451,15 +486,13 @@ std::tuple<Tensor, Tensor> max(const Tensor& self, int64_t dim, bool keepdim) {
}

std::tuple<Tensor, Tensor> min(const Tensor& self, int64_t dim, bool keepdim) {
// CUDA and CPU dispatch keys are handled by the structured kernel implementation.
TORCH_INTERNAL_ASSERT(self.is_quantized());
Tensor min_indices = at::empty({0}, self.options().dtype(kLong));
if (self.is_quantized()) {
Tensor min = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
at::native::min_out(self.int_repr(), dim, keepdim, min, min_indices);
return std::tuple<Tensor, Tensor>(at::_make_per_tensor_quantized_tensor(min, self.q_scale(), self.q_zero_point()), min_indices);
} else {
Tensor min = at::empty({0}, self.options());
return at::native::min_out(self, dim, keepdim, min, min_indices);
}
Tensor min = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
at::min_outf(self.int_repr(), dim, keepdim, min, min_indices);
return std::tuple<Tensor, Tensor>(
at::_make_per_tensor_quantized_tensor(min, self.q_scale(), self.q_zero_point()), min_indices);
}

static std::tuple<Tensor &, Tensor &> _aminmax_out_impl(Tensor& min, Tensor& max,
@@ -495,49 +528,6 @@ std::tuple<Tensor, Tensor> _aminmax(const Tensor& self, int64_t dim, bool keepdi
return result;
}

static std::tuple<Tensor &,Tensor &> min_out_impl(Tensor& min, Tensor& min_indices,
const Tensor& self, int64_t dim, bool keepdim) {
TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
"min only supports CPU AND CUDA device type, got: ", self.device().type());
TORCH_CHECK(self.layout() == Layout::Strided,
"min only supports strided layout, got: ", self.layout());
TORCH_CHECK(self.device() == min.device(),
"expected device ", self.device(), " but got ",
min.device(), " for min values output");
TORCH_CHECK(self.device() == min_indices.device(),
"expected device ", self.device(), " but got ",
min_indices.device(), " for indices output");
dim = maybe_wrap_dim(dim, self.dim());
if (self.numel() == 0) {
zero_numel_tensor_resize(min, min_indices, self, dim, keepdim, "min()");
return std::tie(min, min_indices);
}
else if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min")) {
TORCH_CHECK(!self.is_complex(), "min does not support complex inputs.");
AT_ASSERT(min.dim() == 0);
min_indices.resize_({}).fill_(0);
return std::forward_as_tuple(min, min_indices);
} else {
min_stub(self.device().type(), min, min_indices, self, dim, keepdim);
return std::tuple<Tensor &,Tensor &>{min, min_indices};
}
}

std::tuple<Tensor&, Tensor&> min_out(
const Tensor& self,
int64_t dim,
bool keepdim,
Tensor& min,
Tensor& min_indices) {
auto result = [&]() {
NoNamesGuard guard;
return min_out_impl(min, min_indices, self, dim, keepdim);
}();
namedinference::propagate_names_for_reduction(min, self, dim, keepdim);
namedinference::propagate_names_for_reduction(min_indices, self, dim, keepdim);
return result;
}

Tensor& clamp_out(const Tensor& self, const c10::optional<Scalar>& min, const c10::optional<Scalar>& max, Tensor& result) {
if (min && max) {
auto iter = TensorIterator::unary_op(result, self);
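With the hand-written min_out_impl removed above, the layout/complex checks and name propagation now live in the shared meta function (check_minmax_for_meta), while output allocation and resizing are handled by the generated structured wrapper. As a quick illustration of the output contract that wrapper enforces — both outputs take the reduced shape, values keep the input dtype, indices are int64 — here is a small standalone sketch against libtorch; it is an illustrative example for orientation, not code from this PR:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor x = at::randn({4, 5});

  // Start from empty outputs; the structured kernel's meta step resizes
  // them to the reduced shape {4} (dim 1 reduced, keepdim=false) before
  // the impl function (min_out above) fills them.
  at::Tensor values = at::empty({0}, x.options());
  at::Tensor indices = at::empty({0}, x.options().dtype(at::kLong));

  at::min_outf(x, /*dim=*/1, /*keepdim=*/false, values, indices);

  std::cout << values.sizes() << " " << indices.scalar_type() << "\n";  // [4] Long
  return 0;
}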
4 changes: 3 additions & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -2847,12 +2847,14 @@
- func: nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)

- func: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
structured_delegate: min.dim_min
device_check: NoCheck # TensorIterator
variants: function, method
dispatch:
CPU, CUDA, QuantizedCPU, QuantizedCUDA: min
QuantizedCPU, QuantizedCUDA: min

- func: min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
structured: True
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: min_out
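Reading the yaml change above: structured_delegate: min.dim_min turns the functional min.dim into a wrapper generated from the structured out variant, so it no longer needs its own CPU/CUDA dispatch entries, and only the quantized backends keep the hand-written min. A hedged sketch of that remaining quantized path, using the standard ATen quantization API (illustrative only, not part of the PR):

#include <ATen/ATen.h>
#include <tuple>

int main() {
  at::Tensor x = at::rand({2, 3});
  at::Tensor qx = at::quantize_per_tensor(x, /*scale=*/0.1, /*zero_point=*/0, at::kQUInt8);

  // Dispatches to the QuantizedCPU entry kept in native_functions.yaml; per
  // the TensorCompare.cpp change, it runs the structured kernel on
  // qx.int_repr() and re-wraps the values with qx's scale/zero_point.
  std::tuple<at::Tensor, at::Tensor> out = at::min(qx, /*dim=*/1);

  at::Tensor values = std::get<0>(out);   // quantized values
  at::Tensor indices = std::get<1>(out);  // int64 indices
  return 0;
}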
1 change: 1 addition & 0 deletions test/test_autograd.py
@@ -7595,6 +7595,7 @@ def fn(v):
nnz = 0 if empty_nnz else 5
_test(sparse_size + dense_size, len(sparse_size), nnz, device)

@skipMeta
@dtypes(torch.double, torch.cdouble)
def test_sparse_backward(self, device, dtype):
class FixedGradientFunction(Function):