From 4803eaf502fb9510a729bd1e1d3bb52744a84886 Mon Sep 17 00:00:00 2001
From: kiyosora
Date: Wed, 20 Jan 2021 06:42:21 -0800
Subject: [PATCH] Implement NumPy-like function torch.fmax() & torch.fmin() (#49312)

Summary:
- Implementing the NumPy-like functions `torch.fmax()` and `torch.fmin()` recommended in https://github.com/pytorch/pytorch/issues/48440

Pull Request resolved: https://github.com/pytorch/pytorch/pull/49312

Reviewed By: izdeby

Differential Revision: D25887246

Pulled By: heitorschueroff

fbshipit-source-id: d762eeff8b328bfcbe7d48b7ee9d2da72c249691
---
 aten/src/ATen/core/aten_interned_strings.h | 2 +
 aten/src/ATen/native/BinaryOps.cpp | 36 ++++++++++++
 aten/src/ATen/native/BinaryOps.h | 2 +
 aten/src/ATen/native/cpu/BinaryOpsKernel.cpp | 28 +++++++++
 .../native/cuda/MaxMinElementwiseKernel.cu | 26 +++++++++
 aten/src/ATen/native/native_functions.yaml | 22 +++++++
 docs/source/tensors.rst | 2 +
 docs/source/torch.rst | 2 +
 test/test_binary_ufuncs.py | 35 ++++++-----
 tools/autograd/derivatives.yaml | 8 +++
 torch/_tensor_docs.py | 14 +++++
 torch/_torch_docs.py | 58 +++++++++++++++++++
 torch/overrides.py | 2 +
 .../_internal/common_methods_invocations.py | 2 +
 14 files changed, 226 insertions(+), 13 deletions(-)

diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 518e74b95d54..24ce8a512d58 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -331,6 +331,8 @@ _(aten, flipud) \
 _(aten, floor) \
 _(aten, fmod) \
 _(aten, fmod_) \
+_(aten, fmax) \
+_(aten, fmin) \
 _(aten, frac) \
 _(aten, fractional_max_pool2d) \
 _(aten, fractional_max_pool2d_backward) \
diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index 47f75b392f9a..2589835d3c9f 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -51,6 +51,8 @@ DEFINE_DISPATCH(logit_backward_stub);
 DEFINE_DISPATCH(tanh_backward_stub);
 DEFINE_DISPATCH(maximum_stub);
 DEFINE_DISPATCH(minimum_stub);
+DEFINE_DISPATCH(fmax_stub);
+DEFINE_DISPATCH(fmin_stub);
 DEFINE_DISPATCH(fmod_stub);
 DEFINE_DISPATCH(logaddexp_stub);
 DEFINE_DISPATCH(logaddexp2_stub);
@@ -847,6 +849,23 @@ Tensor max(const Tensor& self, const Tensor& other) {
   return at::maximum(self, other);
 }

+Tensor& fmax_out(const Tensor& self, const Tensor& other, Tensor& result) {
+  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "fmax not implemented for complex tensors.");
+
+  auto iter = TensorIterator::binary_op(result, self, other);
+  fmax_stub(iter.device_type(), iter);
+  return result;
+}
+
+Tensor fmax(const Tensor& self, const Tensor& other) {
+  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "fmax not implemented for complex tensors.");
+
+  Tensor result;
+  auto iter = TensorIterator::binary_op(result, self, other);
+  fmax_stub(iter.device_type(), iter);
+  return iter.output();
+}
+
 Tensor& minimum_out(Tensor& result, const Tensor& self, const Tensor& other) {
   auto iter = TensorIterator::binary_op(result, self, other);
   minimum_stub(iter.device_type(), iter);
@@ -869,6 +888,23 @@ Tensor min(const Tensor& self, const Tensor& other) {
   return at::minimum(self, other);
 }

+Tensor& fmin_out(const Tensor& self, const Tensor& other, Tensor& result) {
+  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "fmin not implemented for complex tensors.");
+
+  auto iter = TensorIterator::binary_op(result, self, other);
+  fmin_stub(iter.device_type(), iter);
+  return result;
+}
+
+Tensor fmin(const Tensor& self, const
Tensor& other) { + TORCH_CHECK(!self.is_complex() && !other.is_complex(), "fmin not implemented for complex tensors."); + + Tensor result; + auto iter = TensorIterator::binary_op(result, self, other); + fmin_stub(iter.device_type(), iter); + return iter.output(); +} + Tensor floor_divide(const Tensor& self, Scalar other) { return at::floor_divide(self, wrapped_scalar_tensor(other)); } diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h index 191611875f08..56b1dfb6e7ee 100644 --- a/aten/src/ATen/native/BinaryOps.h +++ b/aten/src/ATen/native/BinaryOps.h @@ -58,6 +58,8 @@ DECLARE_DISPATCH(binary_fn, max_elementwise_stub); DECLARE_DISPATCH(binary_fn, min_elementwise_stub); DECLARE_DISPATCH(binary_fn, maximum_stub); DECLARE_DISPATCH(binary_fn, minimum_stub); +DECLARE_DISPATCH(binary_fn, fmax_stub); +DECLARE_DISPATCH(binary_fn, fmin_stub); DECLARE_DISPATCH(binary_fn_beta, smooth_l1_stub); DECLARE_DISPATCH(binary_fn, sigmoid_backward_stub); DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub); diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 12301dc4a38e..9cadfd2f29ff 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -505,6 +505,32 @@ void minimum_kernel(TensorIterator& iter) { } } +void fmax_kernel(TensorIterator& iter) { + if (isFloatingType(iter.common_dtype())) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "fmax_cpu", [&]() { + cpu_kernel(iter, + [](scalar_t a, scalar_t b) -> scalar_t { + return std::fmax(a, b); + }); + }); + } else { + maximum_kernel(iter); + } +} + +void fmin_kernel(TensorIterator& iter) { + if (isFloatingType(iter.common_dtype())) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "fmin_cpu", [&]() { + cpu_kernel(iter, + [](scalar_t a, scalar_t b) -> scalar_t { + return std::fmin(a, b); + }); + }); + } else { + minimum_kernel(iter); + } +} + void smooth_l1_kernel(TensorIterator& iter, double beta) { AT_DISPATCH_FLOATING_TYPES_AND2( kBFloat16, kHalf, iter.dtype(), "smooth_l1_cpu", [&]() { @@ -854,6 +880,8 @@ REGISTER_DISPATCH(eq_stub, &eq_kernel); REGISTER_DISPATCH(ne_stub, &ne_kernel); REGISTER_DISPATCH(maximum_stub, &maximum_kernel); REGISTER_DISPATCH(minimum_stub, &minimum_kernel); +REGISTER_DISPATCH(fmax_stub, &fmax_kernel); +REGISTER_DISPATCH(fmin_stub, &fmin_kernel); REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel); REGISTER_DISPATCH(sigmoid_backward_stub, &sigmoid_backward_kernel); REGISTER_DISPATCH(logit_backward_stub, &logit_backward_kernel); diff --git a/aten/src/ATen/native/cuda/MaxMinElementwiseKernel.cu b/aten/src/ATen/native/cuda/MaxMinElementwiseKernel.cu index 6142e427ffd1..29e7aa03dfe1 100644 --- a/aten/src/ATen/native/cuda/MaxMinElementwiseKernel.cu +++ b/aten/src/ATen/native/cuda/MaxMinElementwiseKernel.cu @@ -62,7 +62,33 @@ void minimum_kernel_cuda(TensorIterator& iter) { } } +void fmax_kernel_cuda(TensorIterator& iter) { + if (isFloatingType(iter.common_dtype())) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "fmax_cuda", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return ::fmax(a, b); + }); + }); + } else { + maximum_kernel_cuda(iter); + } +} + +void fmin_kernel_cuda(TensorIterator& iter) { + if (isFloatingType(iter.common_dtype())) { + 
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "fmin_cuda", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return ::fmin(a, b); + }); + }); + } else { + minimum_kernel_cuda(iter); + } +} + REGISTER_DISPATCH(maximum_stub, &maximum_kernel_cuda); REGISTER_DISPATCH(minimum_stub, &minimum_kernel_cuda); +REGISTER_DISPATCH(fmax_stub, &fmax_kernel_cuda); +REGISTER_DISPATCH(fmin_stub, &fmin_kernel_cuda); }} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 91b4dd7fcd8b..d2512127da3e 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6241,12 +6241,34 @@ CPU, CUDA: min QuantizedCPU: min_quantized_cpu +- func: fmin(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: full + variants: method, function + dispatch: + CPU, CUDA: fmin + +- func: fmin.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: full + dispatch: + CPU, CUDA: fmin_out + - func: max(Tensor self) -> Tensor variants: method, function dispatch: CPU, CUDA: max QuantizedCPU: max_quantized_cpu +- func: fmax(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: full + variants: method, function + dispatch: + CPU, CUDA: fmax + +- func: fmax.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: full + dispatch: + CPU, CUDA: fmax_out + - func: maximum(Tensor self, Tensor other) -> Tensor variants: method, function dispatch: diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 1baf34dd955e..55cc2751a891 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -288,6 +288,8 @@ view of a storage and defines numeric operations on it. .. automethod:: diagflat .. automethod:: diagonal .. automethod:: fill_diagonal_ + .. automethod:: fmax + .. automethod:: fmin .. automethod:: digamma .. automethod:: digamma_ .. 
automethod:: dim diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 922e1434bae1..6a18c79282ea 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -416,6 +416,8 @@ Comparison Ops less maximum minimum + fmax + fmin ne not_equal sort diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 89e3c58be498..605d0c5f39fe 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -979,13 +979,14 @@ def test_binary_ops_with_scalars(self, device): def test_maximum_minimum_type_promotion(self, device, dtypes): a = torch.tensor((0, 1), device=device, dtype=dtypes[0]) b = torch.tensor((1, 0), device=device, dtype=dtypes[1]) - for op in (torch.maximum, torch.max, torch.minimum, torch.min): + for op in (torch.maximum, torch.max, torch.fmax, torch.minimum, torch.min, torch.fmin): result = op(a, b) self.assertEqual(result.dtype, torch.result_type(a, b)) @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool])) def test_maximum_minimum_int_and_bool(self, device, dtype): - ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum)) + ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum), + (torch.fmax, None, np.fmax), (torch.fmin, None, np.fmin)) rng = np.random.default_rng() a_np = np.array(rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype]) b_np = np.array(rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype]) @@ -994,21 +995,24 @@ def test_maximum_minimum_int_and_bool(self, device, dtype): a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype) b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype) tensor_result = torch_op(a_tensor, b_tensor) - alias_result = alias(a_tensor, b_tensor) out = torch.empty_like(a_tensor) torch_op(a_tensor, b_tensor, out=out) numpy_result = numpy_op(a_np, b_np) - self.assertEqual(alias_result, tensor_result) + if alias is not None: + alias_result = alias(a_tensor, b_tensor) + self.assertEqual(alias_result, tensor_result) + self.assertEqual(tensor_result, numpy_result) self.assertEqual(out, numpy_result) @precisionOverride({torch.bfloat16: 1e-2}) @dtypes(*(torch.testing.get_all_fp_dtypes())) def test_maximum_minimum_float(self, device, dtype): - ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum)) + ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum), + (torch.fmax, None, np.fmax), (torch.fmin, None, np.fmin)) if dtype == torch.bfloat16: a_np = np.random.randn(10).astype(np.float64) @@ -1023,11 +1027,13 @@ def test_maximum_minimum_float(self, device, dtype): a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype) b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype) tensor_result = torch_op(a_tensor, b_tensor) - alias_result = alias(a_tensor, b_tensor) out = torch.empty_like(a_tensor) torch_op(a_tensor, b_tensor, out=out) - self.assertEqual(alias_result, tensor_result) + if alias is not None: + alias_result = alias(a_tensor, b_tensor) + self.assertEqual(alias_result, tensor_result) + self.assertEqual(tensor_result, numpy_result) self.assertEqual(out, numpy_result) @@ -1035,9 +1041,10 @@ def test_maximum_minimum_float(self, device, dtype): def test_maximum_minimum_float_nan_and_inf(self, device, dtype): # np.maximum and np.minimum functions compare input arrays element-wisely. # if one of the elements being compared is a NaN, then that element is returned. 
- ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum)) - a_vals = (float('inf'), -float('inf'), float('nan'), float('nan')) - b_vals = (-float('inf'), float('inf'), float('inf'), float('nan')) + ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum), + (torch.fmax, None, np.fmax), (torch.fmin, None, np.fmin)) + a_vals = (float('inf'), -float('inf'), float('nan'), float('inf'), float('nan'), float('nan'), 1, float('nan')) + b_vals = (-float('inf'), float('inf'), float('inf'), float('nan'), float('nan'), 0, float('nan'), -5) if dtype == torch.bfloat16: a_np = np.array(a_vals, dtype=np.float64) b_np = np.array(b_vals, dtype=np.float64) @@ -1051,12 +1058,14 @@ def test_maximum_minimum_float_nan_and_inf(self, device, dtype): a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype) b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype) tensor_result = torch_op(a_tensor, b_tensor) - alias_result = alias(a_tensor, b_tensor) out = torch.empty_like(a_tensor) torch_op(a_tensor, b_tensor, out=out) - self.assertEqual(alias_result, tensor_result) + if alias is not None: + alias_result = alias(a_tensor, b_tensor) + self.assertEqual(alias_result, tensor_result) + if dtype == torch.bfloat16: self.assertEqual(tensor_result, numpy_result, exact_dtype=False) self.assertEqual(out, numpy_result, exact_dtype=False) @@ -1066,7 +1075,7 @@ def test_maximum_minimum_float_nan_and_inf(self, device, dtype): @dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes())) def test_maximum_minimum_complex(self, device, dtypes): - for torch_op in (torch.maximum, torch.minimum, torch.max, torch.min): + for torch_op in (torch.maximum, torch.minimum, torch.max, torch.min, torch.fmax, torch.fmin): with self.assertRaisesRegex(RuntimeError, '.+not implemented for.+'): torch_op(torch.ones(1, device=device, dtype=dtypes[0]), torch.ones(1, device=device, dtype=dtypes[1])) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index a8dd3b18fe54..d37480ead923 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -717,6 +717,10 @@ self: grad.clone().masked_fill_(self <= other, 0) other: grad.clone().masked_fill_(self > other, 0) +- name: fmax(Tensor self, Tensor other) -> Tensor + self: grad.clone().masked_fill_((self >= other).logical_or_(other.isnan()).logical_not_(), 0) + other: grad.clone().masked_fill_((self >= other).logical_or_(other.isnan()), 0) + - name: mean(Tensor self, *, ScalarType? 
dtype=None) -> Tensor
   self: grad.expand(self.sizes()).to(self.scalar_type()) / self.numel()

@@ -759,6 +763,10 @@
   self: grad.clone().masked_fill_(self >= other, 0)
   other: grad.clone().masked_fill_(self < other, 0)

+- name: fmin(Tensor self, Tensor other) -> Tensor
+  self: grad.clone().masked_fill_((self <= other).logical_or_(other.isnan()).logical_not_(), 0)
+  other: grad.clone().masked_fill_((self <= other).logical_or_(other.isnan()), 0)
+
 - name: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
   self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim)

diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 83bc04113672..d5335563dc77 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -2273,6 +2273,13 @@ def callable(a, b) -> number
 See :func:`torch.maximum`
 """)

+add_docstr_all('fmax',
+               r"""
+fmax(other) -> Tensor
+
+See :func:`torch.fmax`
+""")
+
 add_docstr_all('argmax',
                r"""
 argmax(dim=None, keepdim=False) -> LongTensor
@@ -2322,6 +2329,13 @@ def callable(a, b) -> number
 See :func:`torch.minimum`
 """)

+add_docstr_all('fmin',
+               r"""
+fmin(other) -> Tensor
+
+See :func:`torch.fmin`
+""")
+
 add_docstr_all('argmin',
                r"""
 argmin(dim=None, keepdim=False) -> LongTensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index d25a8d5b38cf..62a52be2efdb 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -5105,6 +5105,35 @@ def merge_dicts(*dicts):
     tensor([3, 2, 4])
 """.format(**common_args))

+add_docstr(torch.fmax, r"""
+fmax(input, other, *, out=None) -> Tensor
+
+Computes the element-wise maximum of :attr:`input` and :attr:`other`.
+
+This is like :func:`torch.maximum` except it handles NaNs differently:
+if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the maximum.
+Only if both elements are NaN is NaN propagated.
+
+This function is a wrapper around C++'s ``std::fmax`` and is similar to NumPy's ``fmax`` function.
+
+Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`,
+:ref:`type promotion <type-promotion-doc>`, and integer and floating-point inputs.
+
+Args:
+    {input}
+    other (Tensor): the second input tensor
+
+Keyword args:
+    {out}
+
+Example::
+
+    >>> a = torch.tensor([9.7, float('nan'), 3.1, float('nan')])
+    >>> b = torch.tensor([-2.2, 0.5, float('nan'), float('nan')])
+    >>> torch.fmax(a, b)
+    tensor([9.7000, 0.5000, 3.1000, nan])
+""".format(**common_args))
+
 add_docstr(torch.amax, r"""
 amax(input, dim, keepdim=False, *, out=None) -> Tensor

@@ -5551,6 +5580,35 @@ def merge_dicts(*dicts):
     tensor([1, 0, -1])
 """.format(**common_args))

+add_docstr(torch.fmin, r"""
+fmin(input, other, *, out=None) -> Tensor
+
+Computes the element-wise minimum of :attr:`input` and :attr:`other`.
+
+This is like :func:`torch.minimum` except it handles NaNs differently:
+if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the minimum.
+Only if both elements are NaN is NaN propagated.
+
+This function is a wrapper around C++'s ``std::fmin`` and is similar to NumPy's ``fmin`` function.
+
+Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`,
+:ref:`type promotion <type-promotion-doc>`, and integer and floating-point inputs.
+ +Args: + {input} + other (Tensor): the second input tensor + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([2.2, float('nan'), 2.1, float('nan')]) + >>> b = torch.tensor([-9.3, 0.1, float('nan'), float('nan')]) + >>> torch.fmin(a, b) + tensor([-9.3000, 0.1000, 2.1000, nan]) +""".format(**common_args)) + add_docstr(torch.amin, r""" amin(input, dim, keepdim=False, *, out=None) -> Tensor diff --git a/torch/overrides.py b/torch/overrides.py index 3f8b587eedf9..452abc8e8bef 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -527,6 +527,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.matrix_exp: lambda input: -1, torch.max: lambda input, out=None: -1, torch.maximum: lambda input, other, out=None: -1, + torch.fmax: lambda input, other, out=None: -1, torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1, torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1, torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1, @@ -538,6 +539,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.meshgrid: lambda *tensors, **kwargs: -1, torch.min: lambda input, out=None: -1, torch.minimum: lambda input, other, out=None: -1, + torch.fmin: lambda input, other, out=None: -1, torch.miopen_batch_norm: (lambda input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon: -1), torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 3683e2168901..3f21c0723968 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2552,6 +2552,8 @@ def method_tests(): 'scalar_broadcast_rhs'), ('maximum', (S, S), ((S, S),)), ('minimum', (S, S), ((S, S),)), + ('fmax', (S, S), ((S, S),)), + ('fmin', (S, S), ((S, S),)), ('resize_', (S, S, S), (torch.Size([S * S, S])), 'fewer_dims'), ('resize_', (), (dont_convert(()),), 'scalar'), ('resize_', (), (torch.Size([1, 1, 1])), 'scalar_to_dims'),
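
Usage sketch (an illustrative doctest, not part of the diff itself, assuming a PyTorch build that includes this patch): the non-NaN operand wins, NaN is returned only when both operands are NaN, non-floating-point inputs fall back to maximum/minimum, and the derivatives.yaml entries route gradients only to the selected input.

    >>> import torch
    >>> a = torch.tensor([1., float('nan'), float('nan')])
    >>> b = torch.tensor([2., 0.5, float('nan')])
    >>> torch.fmax(a, b)
    tensor([2.0000, 0.5000, nan])
    >>> torch.fmin(a, b)
    tensor([1.0000, 0.5000, nan])
    >>> torch.fmax(torch.tensor([1, 4]), torch.tensor([3, 2]))  # integer inputs dispatch to the maximum kernel
    tensor([3, 4])
    >>> x = torch.tensor([1., float('nan')], requires_grad=True)
    >>> y = torch.tensor([2., 3.], requires_grad=True)
    >>> torch.fmax(x, y).sum().backward()
    >>> x.grad, y.grad  # gradient flows only to the input that produced the output
    (tensor([0., 0.]), tensor([1., 1.]))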