fix aminmax output resize issue when input is a zero dimension tensor
ghstack-source-id: 53a1bc6012b9a4ca290c3429fd6496751e561129
Pull Request resolved: #96171
mingfeima committed Mar 15, 2023
1 parent 1cc32ae commit c08642b
Showing 4 changed files with 10 additions and 48 deletions.
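
For context, the sketch below illustrates the behavior this commit targets: on CPU, `aminmax` on a zero-dimensional input resized its outputs to shape `[1]`, diverging from `amin`/`amax` and from CUDA (see issues #64008 and #96042 referenced in the diff). This is a minimal illustration assuming a build that includes this fix, not output captured from the commit.

```python
import torch

x = torch.tensor(3.0)  # zero-dimensional (scalar) tensor

# With this fix, the CPU kernel keeps 0-d outputs for 0-d inputs,
# matching amin/amax and the CUDA behavior, instead of resizing the
# results to shape [1].
mn, mx = torch.aminmax(x, dim=0, keepdim=True)
print(mn.shape, mx.shape)  # torch.Size([]) torch.Size([])

# The out= variant no longer has to resize zero-element outputs, which is
# what the removed test_out skip below was working around.
out_min, out_max = torch.empty(()), torch.empty(())
torch.aminmax(x, dim=0, keepdim=True, out=(out_min, out_max))
```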
34 changes: 1 addition & 33 deletions aten/src/ATen/functorch/BatchRulesReduceOps.cpp
@@ -316,38 +316,6 @@ std::tuple<Tensor,optional<int64_t>> _log_softmax_backward_batch_rule(
return std::make_tuple(at::_log_softmax_backward_data(grad_output_, output_, dim, input_dtype), 0);
}

// aminmax has divergent behavior for 0-d tensors.
// reference: https://github.com/pytorch/pytorch/issues/64008
// TODO: Once the divergent behavior for 0-d scalar is fixed, we should use REDUCTION_BOXED_ARGS
std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> aminmax_batching_rule(
    const Tensor &self, optional<int64_t> self_bdim, optional<int64_t> dim, bool keep_dim)
{
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto logical_rank = rankWithoutBatchDim(self_, self_bdim);
  if (logical_rank == 0) {
    self_ = self_.unsqueeze(-1);
  }

  if (dim.has_value()) {
    dim = maybe_wrap_dim(dim.value(), logical_rank) + 1;
  } else {
    // flatten the input except for batch-dim
    auto bsize = self_.size(0);
    self_ = self_.view({bsize, -1});
    dim = 1;
  }

  Tensor min, max;
  std::tie(min, max) = at::aminmax(self_, dim, keep_dim);

  if (logical_rank == 0 && self_.device().is_cuda()) {
    // behaviour diverges between cpu and cuda
    min = min.squeeze(-1);
    max = max.squeeze(-1);
  }
  return std::make_tuple(min, 0, max, 0);
}

std::tuple<Tensor,optional<int64_t>> searchsorted_batch_rule(
    const Tensor& sorted_sequence,
    optional<int64_t> sorted_sequence_bdim,
@@ -466,6 +434,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  REDUCTION_NO_KEEPDIM_ARG(_fft_c2c);
  REDUCTION_WITH_KEEPDIM_ARG(amax);
  REDUCTION_WITH_KEEPDIM_ARG(amin);
  REDUCTION_WITH_KEEPDIM_ARG(aminmax);
  m.impl("all", all_decomp);
  REDUCTION_WITH_KEEPDIM_ARG(all.dim);
  m.impl("any", any_decomp);
@@ -511,7 +480,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  REDUCTION_BOXED_ARGS(var_mean.correction, 1, KEEPDIM_CASE_VARIABLE, 3);
  REDUCTION_NO_KEEPDIM_ARG(_log_softmax);
  REDUCTION_BOXED_ARGS(rot90, 2, KEEPDIM_CASE_TRUE, -1);
  VMAP_SUPPORT(aminmax, aminmax_batching_rule);
  VMAP_SUPPORT(_log_softmax_backward_data, _log_softmax_backward_batch_rule);
  VMAP_SUPPORT(_softmax_backward_data, _softmax_backward_batch_rule);
  VMAP_SUPPORT(_is_all_true, _is_all_true_batch_rule);
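With the dedicated batching rule removed, vmap routes `aminmax` through the generic `REDUCTION_WITH_KEEPDIM_ARG` path already used for `amin`/`amax`. A small sketch of what that covers, assuming a build with this change; the shapes are what the generic rule should produce, not output recorded here.

```python
import torch

x = torch.randn(4, 3, 5)

# Batched aminmax goes through the same generic reduction batching rule as
# amin/amax once the hand-written aminmax_batching_rule is gone.
mn, mx = torch.vmap(lambda t: torch.aminmax(t, dim=-1))(x)
print(mn.shape, mx.shape)  # torch.Size([4, 3]) torch.Size([4, 3])

# Per-example 0-d inputs no longer need the CUDA-only squeeze that the
# removed rule applied.
s = torch.randn(4)
mn0, mx0 = torch.vmap(torch.aminmax)(s)
print(mn0.shape, mx0.shape)  # torch.Size([4]) torch.Size([4])
```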
8 changes: 8 additions & 0 deletions aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -177,6 +177,14 @@ static void aminmax_kernel(
"Expect min and max dtype ", self.scalar_type(),
" but got ", min_result.scalar_type(), " and ", max_result.scalar_type());

if (self.numel() == 1 && self.ndimension() == 0) {
min_result.resize_({});
max_result.resize_({});
min_result.fill_(self);
max_result.fill_(self);
return;
}

AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "aminmax_cpu", [&] {
compare_base_kernel<scalar_t, scalar_t>(min_result, max_result, self, wrap_dim, keepdim, [&] (
scalar_t* min_result_data, scalar_t* max_result_data,
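The kernel change above short-circuits zero-dimensional inputs: the outputs are resized to shape `()` and filled with the single element before the dispatch to the comparison kernel. A rough sketch of the resulting behavior, assuming this build; since a 0-d tensor has no dimension to keep, keepdim cannot change the output shape.

```python
import torch

x = torch.tensor(7)

# A 0-d input now produces 0-d outputs regardless of keepdim, matching
# amin/amax on the same input.
for keepdim in (False, True):
    mn, mx = torch.aminmax(x, dim=0, keepdim=keepdim)
    assert mn.shape == torch.Size([]) and mx.shape == torch.Size([])

assert torch.amin(x, dim=0, keepdim=True).shape == torch.Size([])
assert torch.amax(x, dim=0, keepdim=True).shape == torch.Size([])
```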
10 changes: 0 additions & 10 deletions torch/_decomp/decompositions.py
@@ -3337,16 +3337,6 @@ def upsample_bicubic2d_vec(
def aminmax(self, *, dim=None, keepdim=False):
    amin = torch.amin(self, dim=dim, keepdim=keepdim)
    amax = torch.amax(self, dim=dim, keepdim=keepdim)
    if (
        keepdim
        and dim is not None
        and self.ndimension() == 0
        and self.device.type == "cpu"
    ):
        # the behavior of aminmax differs from amin/amax for 0D tensors on CPU
        # https://github.com/pytorch/pytorch/issues/96042
        amin = amin.expand([1])
        amax = amax.expand([1])
    return amin, amax


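With the eager CPU kernel fixed, the decomposition no longer needs the CPU-only `expand([1])` workaround and can mirror eager directly. A hedged sketch of the equivalence; `aminmax_ref` is an illustrative stand-in for the simplified decomposition, not a name from this diff.

```python
import torch

def aminmax_ref(self, *, dim=None, keepdim=False):
    # The simplified decomposition is just amin + amax, with no
    # device-specific shape fix-up for 0-d inputs.
    return (torch.amin(self, dim=dim, keepdim=keepdim),
            torch.amax(self, dim=dim, keepdim=keepdim))

x = torch.tensor(2.5)
ref_min, ref_max = aminmax_ref(x, dim=0, keepdim=True)
eager_min, eager_max = torch.aminmax(x, dim=0, keepdim=True)
assert ref_min.shape == eager_min.shape == torch.Size([])
```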
6 changes: 1 addition & 5 deletions torch/testing/_internal/common_methods_invocations.py
@@ -11503,11 +11503,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
           decorators=(onlyNativeDeviceTypes,),
           supports_autograd=False,
           sample_inputs_func=sample_inputs_aminmax,
           error_inputs_func=error_inputs_aminmax_amax_amin,
           skips=(
               # AssertionError: Resizing an out= argument with no elements threw a resize warning!
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'),
           )),
           error_inputs_func=error_inputs_aminmax_amax_amin),
    OpInfo('as_strided',
           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
           supports_out=False,
