Implement NumPy-like functions torch.fmax() & torch.fmin() (#49312)
Summary:
- Implements the NumPy-like functions `torch.fmax()` and `torch.fmin()`, as recommended in #48440; a brief usage sketch follows the changed-file summary below.

Pull Request resolved: #49312

Reviewed By: izdeby

Differential Revision: D25887246

Pulled By: heitorschueroff

fbshipit-source-id: d762eeff8b328bfcbe7d48b7ee9d2da72c249691
Kiyosora authored and facebook-github-bot committed Jan 20, 2021
1 parent 2ace4fc commit 4803eaf
Showing 14 changed files with 226 additions and 13 deletions.
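
For reference, a minimal usage sketch of the new ops (illustrative only; the expected values follow NumPy's np.fmax/np.fmin semantics, which the tests in this commit compare against). Unlike torch.maximum and torch.minimum, which propagate NaN, torch.fmax and torch.fmin return the non-NaN operand when exactly one operand is NaN; NaN is produced only when both operands are NaN.

import torch

a = torch.tensor([1.0, float('nan'), float('nan'), 3.0])
b = torch.tensor([2.0, 5.0, float('nan'), -1.0])

torch.maximum(a, b)   # tensor([2., nan, nan, 3.])  - NaN propagates
torch.fmax(a, b)      # tensor([2., 5., nan, 3.])   - a single NaN is ignored
torch.fmin(a, b)      # tensor([1., 5., nan, -1.])

# Broadcasting and type promotion follow the usual binary-op rules.
torch.fmax(torch.tensor([[1.0], [float('nan')]]), torch.tensor([0.0, 2.0]))
# tensor([[1., 2.],
#         [0., 2.]])
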
2 changes: 2 additions & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -331,6 +331,8 @@ _(aten, flipud) \
_(aten, floor) \
_(aten, fmod) \
_(aten, fmod_) \
_(aten, fmax) \
_(aten, fmin) \
_(aten, frac) \
_(aten, fractional_max_pool2d) \
_(aten, fractional_max_pool2d_backward) \
36 changes: 36 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -51,6 +51,8 @@ DEFINE_DISPATCH(logit_backward_stub);
DEFINE_DISPATCH(tanh_backward_stub);
DEFINE_DISPATCH(maximum_stub);
DEFINE_DISPATCH(minimum_stub);
DEFINE_DISPATCH(fmax_stub);
DEFINE_DISPATCH(fmin_stub);
DEFINE_DISPATCH(fmod_stub);
DEFINE_DISPATCH(logaddexp_stub);
DEFINE_DISPATCH(logaddexp2_stub);
@@ -847,6 +849,23 @@ Tensor max(const Tensor& self, const Tensor& other) {
return at::maximum(self, other);
}

Tensor& fmax_out(const Tensor& self, const Tensor& other, Tensor& result) {
TORCH_CHECK(!self.is_complex() && !other.is_complex(), "fmax not implemented for complex tensors.");

auto iter = TensorIterator::binary_op(result, self, other);
fmax_stub(iter.device_type(), iter);
return result;
}

Tensor fmax(const Tensor& self, const Tensor& other) {
TORCH_CHECK(!self.is_complex() && !other.is_complex(), "fmax not implemented for complex tensors.");

Tensor result;
auto iter = TensorIterator::binary_op(result, self, other);
fmax_stub(iter.device_type(), iter);
return iter.output();
}

Tensor& minimum_out(Tensor& result, const Tensor& self, const Tensor& other) {
auto iter = TensorIterator::binary_op(result, self, other);
minimum_stub(iter.device_type(), iter);
@@ -869,6 +888,23 @@ Tensor min(const Tensor& self, const Tensor& other) {
return at::minimum(self, other);
}

Tensor& fmin_out(const Tensor& self, const Tensor& other, Tensor& result) {
TORCH_CHECK(!self.is_complex() && !other.is_complex(), "fmin not implemented for complex tensors.");

auto iter = TensorIterator::binary_op(result, self, other);
fmin_stub(iter.device_type(), iter);
return result;
}

Tensor fmin(const Tensor& self, const Tensor& other) {
TORCH_CHECK(!self.is_complex() && !other.is_complex(), "fmin not implemented for complex tensors.");

Tensor result;
auto iter = TensorIterator::binary_op(result, self, other);
fmin_stub(iter.device_type(), iter);
return iter.output();
}

Tensor floor_divide(const Tensor& self, Scalar other) {
return at::floor_divide(self, wrapped_scalar_tensor(other));
}
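
The wrappers above reject complex inputs up front via TORCH_CHECK and otherwise route through TensorIterator::binary_op to the fmax_stub/fmin_stub dispatch. A small sketch of the resulting Python-level behaviour for complex tensors (illustrative only; the error type and message come from the TORCH_CHECK above):

import torch

z = torch.ones(2, dtype=torch.complex64)
try:
    torch.fmax(z, z)
except RuntimeError as e:
    print(e)  # fmax not implemented for complex tensors.
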
2 changes: 2 additions & 0 deletions aten/src/ATen/native/BinaryOps.h
@@ -58,6 +58,8 @@ DECLARE_DISPATCH(binary_fn, max_elementwise_stub);
DECLARE_DISPATCH(binary_fn, min_elementwise_stub);
DECLARE_DISPATCH(binary_fn, maximum_stub);
DECLARE_DISPATCH(binary_fn, minimum_stub);
DECLARE_DISPATCH(binary_fn, fmax_stub);
DECLARE_DISPATCH(binary_fn, fmin_stub);
DECLARE_DISPATCH(binary_fn_beta, smooth_l1_stub);
DECLARE_DISPATCH(binary_fn, sigmoid_backward_stub);
DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub);
28 changes: 28 additions & 0 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -505,6 +505,32 @@ void minimum_kernel(TensorIterator& iter) {
}
}

void fmax_kernel(TensorIterator& iter) {
if (isFloatingType(iter.common_dtype())) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "fmax_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> scalar_t {
return std::fmax(a, b);
});
});
} else {
maximum_kernel(iter);
}
}

void fmin_kernel(TensorIterator& iter) {
if (isFloatingType(iter.common_dtype())) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "fmin_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> scalar_t {
return std::fmin(a, b);
});
});
} else {
minimum_kernel(iter);
}
}

void smooth_l1_kernel(TensorIterator& iter, double beta) {
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16, kHalf, iter.dtype(), "smooth_l1_cpu", [&]() {
@@ -854,6 +880,8 @@ REGISTER_DISPATCH(eq_stub, &eq_kernel);
REGISTER_DISPATCH(ne_stub, &ne_kernel);
REGISTER_DISPATCH(maximum_stub, &maximum_kernel);
REGISTER_DISPATCH(minimum_stub, &minimum_kernel);
REGISTER_DISPATCH(fmax_stub, &fmax_kernel);
REGISTER_DISPATCH(fmin_stub, &fmin_kernel);
REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel);
REGISTER_DISPATCH(sigmoid_backward_stub, &sigmoid_backward_kernel);
REGISTER_DISPATCH(logit_backward_stub, &logit_backward_kernel);
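
For floating dtypes (including Half and BFloat16) the CPU kernels defer to std::fmax/std::fmin, which ignore a single NaN operand per IEEE semantics; for integral and bool dtypes, where NaN cannot occur, they fall back to the existing maximum_kernel/minimum_kernel, so the results coincide with torch.maximum/torch.minimum. A quick sketch of that equivalence (illustrative, not part of the commit):

import torch

a = torch.tensor([3, -1, 7], dtype=torch.int64)
b = torch.tensor([2, 4, 7], dtype=torch.int64)

assert torch.equal(torch.fmax(a, b), torch.maximum(a, b))  # tensor([3, 4, 7])
assert torch.equal(torch.fmin(a, b), torch.minimum(a, b))  # tensor([2, -1, 7])
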
26 changes: 26 additions & 0 deletions aten/src/ATen/native/cuda/MaxMinElementwiseKernel.cu
@@ -62,7 +62,33 @@ void minimum_kernel_cuda(TensorIterator& iter) {
}
}

void fmax_kernel_cuda(TensorIterator& iter) {
if (isFloatingType(iter.common_dtype())) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "fmax_cuda", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return ::fmax(a, b);
});
});
} else {
maximum_kernel_cuda(iter);
}
}

void fmin_kernel_cuda(TensorIterator& iter) {
if (isFloatingType(iter.common_dtype())) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "fmin_cuda", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return ::fmin(a, b);
});
});
} else {
minimum_kernel_cuda(iter);
}
}

REGISTER_DISPATCH(maximum_stub, &maximum_kernel_cuda);
REGISTER_DISPATCH(minimum_stub, &minimum_kernel_cuda);
REGISTER_DISPATCH(fmax_stub, &fmax_kernel_cuda);
REGISTER_DISPATCH(fmin_stub, &fmin_kernel_cuda);

}} // namespace at::native
22 changes: 22 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -6241,12 +6241,34 @@
CPU, CUDA: min
QuantizedCPU: min_quantized_cpu

- func: fmin(Tensor self, Tensor other) -> Tensor
use_c10_dispatcher: full
variants: method, function
dispatch:
CPU, CUDA: fmin

- func: fmin.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
use_c10_dispatcher: full
dispatch:
CPU, CUDA: fmin_out

- func: max(Tensor self) -> Tensor
variants: method, function
dispatch:
CPU, CUDA: max
QuantizedCPU: max_quantized_cpu

- func: fmax(Tensor self, Tensor other) -> Tensor
use_c10_dispatcher: full
variants: method, function
dispatch:
CPU, CUDA: fmax

- func: fmax.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
use_c10_dispatcher: full
dispatch:
CPU, CUDA: fmax_out

- func: maximum(Tensor self, Tensor other) -> Tensor
variants: method, function
dispatch:
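
The schema entries above register an out-of-place variant (exposed both as a free function and as a Tensor method, per `variants: method, function`) and an out= overload, each dispatched to the CPU and CUDA kernels. A brief sketch of the three call forms (illustrative only):

import torch

a = torch.tensor([1.0, float('nan')])
b = torch.tensor([0.5, 2.0])

torch.fmax(a, b)            # function variant
a.fmax(b)                   # method variant
out = torch.empty_like(a)
torch.fmin(a, b, out=out)   # out= overload (fmin.out)
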
2 changes: 2 additions & 0 deletions docs/source/tensors.rst
@@ -288,6 +288,8 @@ view of a storage and defines numeric operations on it.
.. automethod:: diagflat
.. automethod:: diagonal
.. automethod:: fill_diagonal_
.. automethod:: fmax
.. automethod:: fmin
.. automethod:: digamma
.. automethod:: digamma_
.. automethod:: dim
2 changes: 2 additions & 0 deletions docs/source/torch.rst
@@ -416,6 +416,8 @@ Comparison Ops
less
maximum
minimum
fmax
fmin
ne
not_equal
sort
35 changes: 22 additions & 13 deletions test/test_binary_ufuncs.py
@@ -979,13 +979,14 @@ def test_binary_ops_with_scalars(self, device):
def test_maximum_minimum_type_promotion(self, device, dtypes):
a = torch.tensor((0, 1), device=device, dtype=dtypes[0])
b = torch.tensor((1, 0), device=device, dtype=dtypes[1])
for op in (torch.maximum, torch.max, torch.minimum, torch.min):
for op in (torch.maximum, torch.max, torch.fmax, torch.minimum, torch.min, torch.fmin):
result = op(a, b)
self.assertEqual(result.dtype, torch.result_type(a, b))

@dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool]))
def test_maximum_minimum_int_and_bool(self, device, dtype):
ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum))
ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum),
(torch.fmax, None, np.fmax), (torch.fmin, None, np.fmin))
rng = np.random.default_rng()
a_np = np.array(rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype])
b_np = np.array(rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype])
@@ -994,21 +995,24 @@ def test_maximum_minimum_int_and_bool(self, device, dtype):
a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype)
b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype)
tensor_result = torch_op(a_tensor, b_tensor)
alias_result = alias(a_tensor, b_tensor)

out = torch.empty_like(a_tensor)
torch_op(a_tensor, b_tensor, out=out)

numpy_result = numpy_op(a_np, b_np)

self.assertEqual(alias_result, tensor_result)
if alias is not None:
alias_result = alias(a_tensor, b_tensor)
self.assertEqual(alias_result, tensor_result)

self.assertEqual(tensor_result, numpy_result)
self.assertEqual(out, numpy_result)

@precisionOverride({torch.bfloat16: 1e-2})
@dtypes(*(torch.testing.get_all_fp_dtypes()))
def test_maximum_minimum_float(self, device, dtype):
ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum))
ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum),
(torch.fmax, None, np.fmax), (torch.fmin, None, np.fmin))

if dtype == torch.bfloat16:
a_np = np.random.randn(10).astype(np.float64)
@@ -1023,21 +1027,24 @@ def test_maximum_minimum_float(self, device, dtype):
a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype)
b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype)
tensor_result = torch_op(a_tensor, b_tensor)
alias_result = alias(a_tensor, b_tensor)
out = torch.empty_like(a_tensor)
torch_op(a_tensor, b_tensor, out=out)

self.assertEqual(alias_result, tensor_result)
if alias is not None:
alias_result = alias(a_tensor, b_tensor)
self.assertEqual(alias_result, tensor_result)

self.assertEqual(tensor_result, numpy_result)
self.assertEqual(out, numpy_result)

@dtypes(*(torch.testing.get_all_fp_dtypes()))
def test_maximum_minimum_float_nan_and_inf(self, device, dtype):
# np.maximum and np.minimum functions compare input arrays element-wisely.
# if one of the elements being compared is a NaN, then that element is returned.
ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum))
a_vals = (float('inf'), -float('inf'), float('nan'), float('nan'))
b_vals = (-float('inf'), float('inf'), float('inf'), float('nan'))
ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum),
(torch.fmax, None, np.fmax), (torch.fmin, None, np.fmin))
a_vals = (float('inf'), -float('inf'), float('nan'), float('inf'), float('nan'), float('nan'), 1, float('nan'))
b_vals = (-float('inf'), float('inf'), float('inf'), float('nan'), float('nan'), 0, float('nan'), -5)
if dtype == torch.bfloat16:
a_np = np.array(a_vals, dtype=np.float64)
b_np = np.array(b_vals, dtype=np.float64)
@@ -1051,12 +1058,14 @@ def test_maximum_minimum_float_nan_and_inf(self, device, dtype):
a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype)
b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype)
tensor_result = torch_op(a_tensor, b_tensor)
alias_result = alias(a_tensor, b_tensor)

out = torch.empty_like(a_tensor)
torch_op(a_tensor, b_tensor, out=out)

self.assertEqual(alias_result, tensor_result)
if alias is not None:
alias_result = alias(a_tensor, b_tensor)
self.assertEqual(alias_result, tensor_result)

if dtype == torch.bfloat16:
self.assertEqual(tensor_result, numpy_result, exact_dtype=False)
self.assertEqual(out, numpy_result, exact_dtype=False)
@@ -1066,7 +1075,7 @@ def test_maximum_minimum_float_nan_and_inf(self, device, dtype):

@dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes()))
def test_maximum_minimum_complex(self, device, dtypes):
for torch_op in (torch.maximum, torch.minimum, torch.max, torch.min):
for torch_op in (torch.maximum, torch.minimum, torch.max, torch.min, torch.fmax, torch.fmin):
with self.assertRaisesRegex(RuntimeError, '.+not implemented for.+'):
torch_op(torch.ones(1, device=device, dtype=dtypes[0]),
torch.ones(1, device=device, dtype=dtypes[1]))
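
The extra value pairs added to a_vals/b_vals above cover the cases where exactly one operand is NaN, which is precisely where fmax/fmin diverge from maximum/minimum. For those pairs NumPy's reference functions pick the non-NaN operand (a small sanity sketch, not part of the commit):

import numpy as np

np.fmax(np.array([1.0, np.nan]), np.array([np.nan, -5.0]))  # array([ 1., -5.])
np.fmin(np.array([1.0, np.nan]), np.array([np.nan, -5.0]))  # array([ 1., -5.])
# np.maximum / np.minimum would instead return NaN at both positions.
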
8 changes: 8 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -717,6 +717,10 @@
self: grad.clone().masked_fill_(self <= other, 0)
other: grad.clone().masked_fill_(self > other, 0)

- name: fmax(Tensor self, Tensor other) -> Tensor
self: grad.clone().masked_fill_((self >= other).logical_or_(other.isnan()).logical_not_(), 0)
other: grad.clone().masked_fill_((self >= other).logical_or_(other.isnan()), 0)

- name: mean(Tensor self, *, ScalarType? dtype=None) -> Tensor
self: grad.expand(self.sizes()).to(self.scalar_type()) / self.numel()

@@ -759,6 +763,10 @@
self: grad.clone().masked_fill_(self >= other, 0)
other: grad.clone().masked_fill_(self < other, 0)

- name: fmin(Tensor self, Tensor other) -> Tensor
self: grad.clone().masked_fill_((self <= other).logical_or_(other.isnan()).logical_not_(), 0)
other: grad.clone().masked_fill_((self <= other).logical_or_(other.isnan()), 0)

- name: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim)

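
The derivative entries route the incoming gradient to whichever input was selected in the forward pass: for fmax, `self` receives the gradient where `self >= other` or `other` is NaN, and `other` receives it elsewhere (symmetrically for fmin). A minimal autograd sketch consistent with those formulas (illustrative only):

import torch

a = torch.tensor([1.0, float('nan'), 3.0], requires_grad=True)
b = torch.tensor([2.0, 5.0, 0.0], requires_grad=True)

torch.fmax(a, b).sum().backward()
a.grad  # tensor([0., 0., 1.])  - a is selected only at index 2
b.grad  # tensor([1., 1., 0.])  - b is selected where a < b or a is NaN
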
14 changes: 14 additions & 0 deletions torch/_tensor_docs.py
@@ -2273,6 +2273,13 @@ def callable(a, b) -> number
See :func:`torch.maximum`
""")

add_docstr_all('fmax',
r"""
fmax(other) -> Tensor
See :func:`torch.fmax`
""")

add_docstr_all('argmax',
r"""
argmax(dim=None, keepdim=False) -> LongTensor
@@ -2322,6 +2329,13 @@ def callable(a, b) -> number
See :func:`torch.minimum`
""")

add_docstr_all('fmin',
r"""
fmin(other) -> Tensor
See :func:`torch.fmin`
""")

add_docstr_all('argmin',
r"""
argmin(dim=None, keepdim=False) -> LongTensor
