fix aminmax output resize issue when input is a zero dimension tensor
ghstack-source-id: f6e5c57e1d08759176dc1ef596d3cb9ceec74dba
Pull Request resolved: #96171
mingfeima committed Mar 9, 2023
1 parent fe05266 commit fbbc81b
Showing 4 changed files with 10 additions and 15 deletions.
8 changes: 8 additions & 0 deletions aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -177,6 +177,14 @@ static void aminmax_kernel(
     "Expect min and max dtype ", self.scalar_type(),
     " but got ", min_result.scalar_type(), " and ", max_result.scalar_type());
 
+  if (self.numel() == 1 && self.ndimension() == 0) {
+    min_result.resize_({});
+    max_result.resize_({});
+    min_result.fill_(self);
+    max_result.fill_(self);
+    return;
+  }
+
   AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "aminmax_cpu", [&] {
     compare_base_kernel<scalar_t, scalar_t>(min_result, max_result, self, wrap_dim, keepdim, [&] (
       scalar_t* min_result_data, scalar_t* max_result_data,
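
For illustration only (not part of the commit), a small Python snippet showing the behavior this kernel change is meant to produce on CPU: aminmax on a zero-dimensional input now yields zero-dimensional outputs, matching torch.amin / torch.amax, even when dim and keepdim are specified.

import torch

x = torch.tensor(3.0)                          # zero-dimensional CPU tensor
min_val, max_val = torch.aminmax(x, dim=0, keepdim=True)
# With this fix both results are 0-dim, consistent with amin/amax;
# previously the CPU path resized them to shape [1].
print(min_val.shape, max_val.shape)            # torch.Size([]) torch.Size([])
assert min_val.shape == torch.amin(x, dim=0, keepdim=True).shape
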
1 change: 1 addition & 0 deletions test/functorch/test_vmap.py
@@ -3595,6 +3595,7 @@ def test_vmap_exhaustive(self, device, dtype, op):
     @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04), torch.complex64: tol(atol=1e-04, rtol=1e-04)})
     @skipOps('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', vmap_fail.union({
         xfail('as_strided', 'partial_views'),
+        skip('aminmax'),
         skip('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
         xfail('complex'),
         xfail('copysign'),
10 changes: 0 additions & 10 deletions torch/_decomp/decompositions.py
@@ -3337,16 +3337,6 @@ def upsample_bicubic2d_vec(
 def aminmax(self, *, dim=None, keepdim=False):
     amin = torch.amin(self, dim=dim, keepdim=keepdim)
     amax = torch.amax(self, dim=dim, keepdim=keepdim)
-    if (
-        keepdim
-        and dim is not None
-        and self.ndimension() == 0
-        and self.device.type == "cpu"
-    ):
-        # the behavior of aminmax differs from amin/amax for 0D tensors on CPU
-        # https://github.com/pytorch/pytorch/issues/96042
-        amin = amin.expand([1])
-        amax = amax.expand([1])
     return amin, amax
 
 
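
Because the kernel now handles the zero-dimensional case directly, the CPU-specific expand([1]) branch removed above is no longer needed. A hedged sanity check (illustrative only; it assumes the decomposition remains importable as torch._decomp.decompositions.aminmax):

import torch
from torch._decomp.decompositions import aminmax as aminmax_decomp

x = torch.tensor(2.0)  # zero-dimensional CPU tensor
eager_min, eager_max = torch.aminmax(x, dim=0, keepdim=True)
decomp_min, decomp_max = aminmax_decomp(x, dim=0, keepdim=True)
# After the kernel fix, the eager op and the decomposition agree on shape
# for 0-dim inputs without any special-casing.
assert eager_min.shape == decomp_min.shape == torch.Size([])
assert eager_max.shape == decomp_max.shape == torch.Size([])
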
6 changes: 1 addition & 5 deletions torch/testing/_internal/common_methods_invocations.py
@@ -11503,11 +11503,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
            decorators=(onlyNativeDeviceTypes,),
            supports_autograd=False,
            sample_inputs_func=sample_inputs_aminmax,
-           error_inputs_func=error_inputs_aminmax_amax_amin,
-           skips=(
-               # AssertionError: Resizing an out= argument with no elements threw a resize warning!
-               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'),
-           )),
+           error_inputs_func=error_inputs_aminmax_amax_amin),
     OpInfo('as_strided',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
            supports_out=False,
