
fix aminmax output resize issue when input is a zero dimension tensor #96171

Closed
wants to merge 7 commits
34 changes: 1 addition & 33 deletions aten/src/ATen/functorch/BatchRulesReduceOps.cpp
@@ -316,38 +316,6 @@ std::tuple<Tensor,optional<int64_t>> _log_softmax_backward_batch_rule(
return std::make_tuple(at::_log_softmax_backward_data(grad_output_, output_, dim, input_dtype), 0);
}

// aminmax has divergent behavior for 0-d tensors.
// reference: https://github.com/pytorch/pytorch/issues/64008
// TODO: Once the divergent behavior for 0-d scalar is fixed, we should use REDUCTION_BOXED_ARGS
std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> aminmax_batching_rule(
const Tensor &self, optional<int64_t> self_bdim, optional<int64_t> dim, bool keep_dim)
{
auto self_ = moveBatchDimToFront(self, self_bdim);
auto logical_rank = rankWithoutBatchDim(self_, self_bdim);
if (logical_rank == 0) {
self_ = self_.unsqueeze(-1);
}

if (dim.has_value()) {
dim = maybe_wrap_dim(dim.value(), logical_rank) + 1;
} else {
// flatten the input except for batch-dim
auto bsize = self_.size(0);
self_ = self_.view({bsize, -1});
dim = 1;
}

Tensor min, max;
std::tie(min, max) = at::aminmax(self_, dim, keep_dim);

if (logical_rank == 0 && self_.device().is_cuda()) {
// behaviour diverges between cpu and cuda
min = min.squeeze(-1);
max = max.squeeze(-1);
}
return std::make_tuple(min, 0, max, 0);
}

std::tuple<Tensor,optional<int64_t>> searchsorted_batch_rule(
const Tensor& sorted_sequence,
optional<int64_t> sorted_sequence_bdim,
@@ -466,6 +434,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
REDUCTION_NO_KEEPDIM_ARG(_fft_c2c);
REDUCTION_WITH_KEEPDIM_ARG(amax);
REDUCTION_WITH_KEEPDIM_ARG(amin);
REDUCTION_WITH_KEEPDIM_ARG(aminmax);
m.impl("all", all_decomp);
REDUCTION_WITH_KEEPDIM_ARG(all.dim);
m.impl("any", any_decomp);
@@ -511,7 +480,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
REDUCTION_BOXED_ARGS(var_mean.correction, 1, KEEPDIM_CASE_VARIABLE, 3);
REDUCTION_NO_KEEPDIM_ARG(_log_softmax);
REDUCTION_BOXED_ARGS(rot90, 2, KEEPDIM_CASE_TRUE, -1);
VMAP_SUPPORT(aminmax, aminmax_batching_rule);
VMAP_SUPPORT(_log_softmax_backward_data, _log_softmax_backward_batch_rule);
VMAP_SUPPORT(_softmax_backward_data, _softmax_backward_batch_rule);
VMAP_SUPPORT(_is_all_true, _is_all_true_batch_rule);
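For context, the deleted rule existed only because aminmax handled 0-d inputs differently across devices (see https://github.com/pytorch/pytorch/issues/64008); once that divergence is fixed, the generic REDUCTION_WITH_KEEPDIM_ARG registration above is enough. A minimal sketch of the kind of call the rule had to support, assuming the torch.func.vmap API:

import torch
from torch.func import vmap

# vmapping over a 1-d tensor makes every per-sample input a 0-d tensor,
# which is exactly the logical-rank-0 case the old rule special-cased.
x = torch.randn(3)
mins, maxs = vmap(torch.aminmax)(x)
# Each per-sample min and max should equal the sample itself.
print(torch.equal(mins, x), torch.equal(maxs, x))
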
8 changes: 8 additions & 0 deletions aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -177,6 +177,14 @@ static void aminmax_kernel(
"Expect min and max dtype ", self.scalar_type(),
" but got ", min_result.scalar_type(), " and ", max_result.scalar_type());

if (self.numel() == 1 && self.ndimension() == 0) {
min_result.resize_({});
max_result.resize_({});
min_result.fill_(self);
max_result.fill_(self);
return;
}

AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "aminmax_cpu", [&] {
compare_base_kernel<scalar_t, scalar_t>(min_result, max_result, self, wrap_dim, keepdim, [&] (
scalar_t* min_result_data, scalar_t* max_result_data,
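The added early return above makes the CPU kernel keep 0-d outputs for 0-d inputs, matching amin/amax and the CUDA path (https://github.com/pytorch/pytorch/issues/96042). An illustrative check of the expected post-fix behavior, assuming the fix as written:

import torch

t = torch.tensor(3.0)  # 0-d CPU tensor
mn, mx = torch.aminmax(t, dim=0, keepdim=True)
# With the resize_({}) above, both outputs should stay 0-d instead of
# being reported with shape [1] on CPU.
print(mn.shape, mx.shape)  # expected: torch.Size([]) torch.Size([])
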
10 changes: 0 additions & 10 deletions torch/_decomp/decompositions.py
@@ -3337,16 +3337,6 @@ def upsample_bicubic2d_vec(
def aminmax(self, *, dim=None, keepdim=False):
amin = torch.amin(self, dim=dim, keepdim=keepdim)
amax = torch.amax(self, dim=dim, keepdim=keepdim)
if (
keepdim
and dim is not None
and self.ndimension() == 0
and self.device.type == "cpu"
):
# the behavior of aminmax differs from amin/amax for 0D tensors on CPU
# https://github.com/pytorch/pytorch/issues/96042
amin = amin.expand([1])
amax = amax.expand([1])
return amin, amax


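With the CPU-only special case removed, the decomposition is a plain amin/amax pair and should now agree with the eager kernel on 0-d inputs. A small sanity-check sketch (illustrative only; aminmax_via_amin_amax is not a real helper):

import torch

def aminmax_via_amin_amax(t, dim=0, keepdim=True):
    # Mirrors the simplified decomposition: no device- or rank-dependent branch.
    return torch.amin(t, dim=dim, keepdim=keepdim), torch.amax(t, dim=dim, keepdim=keepdim)

t = torch.tensor(2.0)
print(aminmax_via_amin_amax(t))               # expected: (tensor(2.), tensor(2.))
print(torch.aminmax(t, dim=0, keepdim=True))  # expected to match after this change
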
6 changes: 1 addition & 5 deletions torch/testing/_internal/common_methods_invocations.py
@@ -11503,11 +11503,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
decorators=(onlyNativeDeviceTypes,),
supports_autograd=False,
sample_inputs_func=sample_inputs_aminmax,
error_inputs_func=error_inputs_aminmax_amax_amin,
skips=(
# AssertionError: Resizing an out= argument with no elements threw a resize warning!
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'),
)),
error_inputs_func=error_inputs_aminmax_amax_amin),
OpInfo('as_strided',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
supports_out=False,