diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index 7b0759c3671b..644d75c04c06 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -436,6 +436,7 @@ _(aten, logdet) \ _(aten, logit) \ _(aten, logspace) \ _(aten, logsumexp) \ +_(aten, xlogy) \ _(aten, lstm) \ _(aten, lstm_cell) \ _(aten, lstsq) \ diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index e8751be55387..9103eafb1f12 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -62,6 +62,7 @@ DEFINE_DISPATCH(igammac_stub); DEFINE_DISPATCH(nextafter_stub); DEFINE_DISPATCH(heaviside_stub); DEFINE_DISPATCH(copysign_stub); +DEFINE_DISPATCH(xlogy_stub); static Tensor wrapped_scalar_tensor(Scalar scalar) { auto tensor = scalar_to_tensor(scalar); @@ -1101,5 +1102,42 @@ Tensor& ldexp_(Tensor& self, const Tensor& other) { return at::ldexp_out(self, self, other); } +Tensor& xlogy_out(Tensor& result, const Tensor& self, const Tensor& other) { + auto iter = TensorIterator::binary_float_op(result, self, other); + xlogy_stub(iter.device_type(), iter); + return result; +} + +Tensor& xlogy_out(Tensor& result, Scalar self, const Tensor& other) { + return at::xlogy_out(result, c10::scalar_to_tensor(self, other.device()), other); +} + +Tensor& xlogy_out(Tensor& result, const Tensor& self, Scalar other) { + return at::xlogy_out(result, self, c10::scalar_to_tensor(other, self.device())); +} + +Tensor xlogy(const Tensor& x, const Tensor& y) { + Tensor result; + auto iter = TensorIterator::binary_float_op(result, x, y); + xlogy_stub(iter.device_type(), iter); + return iter.output(); +} + +Tensor xlogy(Scalar x, const Tensor& y) { + return at::xlogy(c10::scalar_to_tensor(x, y.device()), y); +} + +Tensor xlogy(const Tensor& x, Scalar y) { + return at::xlogy(x, c10::scalar_to_tensor(y, x.device())); +} + +Tensor& xlogy_(Tensor& x, const Tensor& y) { + return 
at::xlogy_out(x, x, y); +} + +Tensor& xlogy_(Tensor& x, Scalar y) { + return at::xlogy_out(x, x, c10::scalar_to_tensor(y, x.device())); +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h index 1fdb80590b5a..191611875f08 100644 --- a/aten/src/ATen/native/BinaryOps.h +++ b/aten/src/ATen/native/BinaryOps.h @@ -74,5 +74,6 @@ DECLARE_DISPATCH(binary_fn, igammac_stub); DECLARE_DISPATCH(binary_fn, nextafter_stub); DECLARE_DISPATCH(binary_fn, heaviside_stub); DECLARE_DISPATCH(binary_fn, copysign_stub); +DECLARE_DISPATCH(binary_fn, xlogy_stub); }} // namespace at::native diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index ddfa8a2d3d95..3dfe130ced70 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -818,6 +818,20 @@ void copysign_kernel(TensorIterator& iter) { }); } +void xlogy_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "xlogy_cpu", [&]() { + cpu_kernel(iter, [](scalar_t x, scalar_t y) -> scalar_t { + if (at::_isnan(y)){ + return NAN; + } + if (x == 0){ + return 0; + } + return x * std::log(y); + }); + }); +} + } // namespace REGISTER_DISPATCH(add_stub, &add_kernel); @@ -859,6 +873,7 @@ REGISTER_DISPATCH(igammac_stub, &igammac_kernel); REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel); REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel); REGISTER_DISPATCH(copysign_stub, &copysign_kernel); +REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu index c0efde1671d1..2379877e91ba 100644 --- a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu +++ b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu @@ -3,6 +3,7 @@ #include #include #include +#include // NOTE: CUDA on Windows requires 
that the enclosing function // of a __device__ lambda not have internal linkage. @@ -29,8 +30,23 @@ void mse_kernel_cuda(TensorIterator& iter) { }); } +void xlogy_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "xlogy_cuda", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t x, scalar_t y) -> scalar_t { + if (at::_isnan(y)){ + return NAN; + } + if (x == 0){ + return 0; + } + return x * std::log(y); + }); + }); +} + REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda); REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda); +REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel_cuda); // DO NOT ADD ANY NEW KERNELS HERE // CUDA compilation times grow quickly. It's perfectly acceptable to have a file per kernel. diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 78ad11229428..9c0053f40b7e 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2560,6 +2560,56 @@ dispatch: DefaultBackend: logaddexp2 +- func: xlogy.Tensor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: full + variants: function, method + dispatch: + CPU, CUDA: xlogy + +- func: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor + use_c10_dispatcher: full + variants: function + dispatch: + CPU, CUDA: xlogy + +- func: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: full + variants: function, method + dispatch: + CPU, CUDA: xlogy + +# xlogy: inplace variant +- func: xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: full + variants: function, method + dispatch: + CPU, CUDA: xlogy_ + +- func: xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!) + use_c10_dispatcher: full + variants: function, method + dispatch: + CPU, CUDA: xlogy_ + +# xlogy: out variant +- func: xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + variants: function + dispatch: + CPU, CUDA: xlogy_out + +- func: xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + variants: function + dispatch: + CPU, CUDA: xlogy_out + +- func: xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures + variants: function + dispatch: + CPU, CUDA: xlogy_out + - func: logdet(Tensor self) -> Tensor use_c10_dispatcher: full variants: function, method diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index f73753743d59..315cc9dc5309 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -645,6 +645,8 @@ view of a storage and defines numeric operations on it. .. automethod:: view .. automethod:: view_as .. automethod:: where + .. automethod:: xlogy + .. automethod:: xlogy_ .. automethod:: zero_ .. class:: BoolTensor() diff --git a/docs/source/torch.rst b/docs/source/torch.rst index c82035eb8684..3057339aa811 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -350,6 +350,7 @@ Pointwise Ops tanh true_divide trunc + xlogy Reduction Ops ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/test/test_autograd.py b/test/test_autograd.py index 3d29529cab9a..a8a130596855 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -10,7 +10,7 @@ import warnings from copy import deepcopy from collections import OrderedDict -from itertools import product +from itertools import product, permutations from operator import mul from functools import reduce import torch @@ -7396,6 +7396,54 @@ def test_atleast(self, device): self._test_atleast(device, torch.atleast_2d) self._test_atleast(device, torch.atleast_3d) + def test_xlogy(self, device): + + def _tensor_tensor_helper(x, y): + gradcheck(lambda x, y: torch.xlogy(x, y), (x, y)) + gradgradcheck(lambda x, y: torch.xlogy(x, y), (x, y)) + + 
with torch.no_grad(): + x = x.clone() + x[torch.rand_like(x) > 0.5] = 0 + + gradcheck(lambda y: torch.xlogy(x, y), (y)) + gradgradcheck(lambda y: torch.xlogy(x, y), (y)) + + shapes = ((4,), (1, 4), (1, 1, 4), (1, 1, 1, 4)) + + # For broadcastible shapes and scalar. + for x_shape, y_shape in permutations(shapes, 2): + x = torch.rand(*x_shape, dtype=torch.double, device=device, requires_grad=True) + y = torch.rand(*y_shape, dtype=torch.double, device=device, requires_grad=True) + + _tensor_tensor_helper(x, y) + _tensor_tensor_helper(y, x) + + gradcheck(lambda y: torch.xlogy(0, y), (y)) + gradgradcheck(lambda y: torch.xlogy(0, y), (y)) + + gradcheck(lambda y: torch.xlogy(2, y), (y)) + gradgradcheck(lambda y: torch.xlogy(2, y), (y)) + gradcheck(lambda y: torch.xlogy(y, 2), (y)) + gradgradcheck(lambda y: torch.xlogy(y, 2), (y)) + + # Different shape + x = torch.rand(2, 3, 4, 5, dtype=torch.double, device=device, requires_grad=True) + y = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True) + _tensor_tensor_helper(x, y) + _tensor_tensor_helper(y, x) + _tensor_tensor_helper(x, x) + _tensor_tensor_helper(y, y) + + # Same shape + x = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True) + y = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True) + _tensor_tensor_helper(x, y) + _tensor_tensor_helper(y, x) + _tensor_tensor_helper(x, x) + _tensor_tensor_helper(y, y) + + class TestMultithreadAutograd(TestCase): def _run_py_multithread_fn(self, fn, args=(), num_threads=10, kwargs=None): threads = [] diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 9888c29130bb..5739fb569628 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -8,15 +8,19 @@ import unittest import warnings import operator +from functools import partial from torch._six import inf, nan from torch.testing._internal.common_utils import ( TestCase, iter_indices, TEST_WITH_ASAN, run_tests, - 
torch_to_numpy_dtype_dict, make_tensor) + torch_to_numpy_dtype_dict, make_tensor, TEST_SCIPY) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, precisionOverride, onlyOnCPUAndCUDA, - skipCUDAIfRocm) + skipCUDAIfRocm, skipIf) + +if TEST_SCIPY: + import scipy.special # TODO: remove this def _generate_input(shape, dtype, device, with_extremal): @@ -2488,6 +2492,103 @@ def _promo_helper(x, y): with self.assertRaisesRegex(RuntimeError, "is not the desired type"): torch.Tensor.float_power_(base.clone(), exp) + @skipIf(not TEST_SCIPY, "Scipy required for the test.") + @dtypes(*product(torch.testing.get_all_dtypes(include_complex=False, include_bfloat16=False), + torch.testing.get_all_dtypes(include_complex=False, include_bfloat16=False))) + def test_xlogy(self, device, dtypes): + def out_variant_helper(torch_fn, x, y): + expected = torch_fn(x, y) + out = torch.empty_like(expected) + torch_fn(x, y, out=out) + self.assertEqual(expected, out) + + def inplace_variant_helper(x, y): + if x.dtype in torch.testing.get_all_int_dtypes() + [torch.bool]: + with self.assertRaisesRegex(RuntimeError, + "can't be cast to the desired output type"): + x.clone().xlogy_(y) + else: + expected = torch.empty_like(x) + torch.xlogy(x, y, out=expected) + inplace_out = x.clone().xlogy_(y) + self.assertEqual(expected, inplace_out) + + x_dtype, y_dtype = dtypes + + # Tensor-Tensor Test (tensor of same and different shape) + x = make_tensor((3, 2, 4, 5), device, x_dtype, low=0.5, high=1000) + y = make_tensor((3, 2, 4, 5), device, y_dtype, low=0.5, high=1000) + z = make_tensor((4, 5), device, y_dtype, low=0.5, high=1000) + + torch_fn = partial(torch.xlogy, x) + reference_fn = partial(scipy.special.xlogy, x.cpu().numpy()) + + self.compare_with_numpy(torch_fn, reference_fn, x, exact_dtype=False) + self.compare_with_numpy(torch_fn, reference_fn, y, exact_dtype=False) + 
self.compare_with_numpy(torch_fn, reference_fn, z, exact_dtype=False) + out_variant_helper(torch.xlogy, x, x) + out_variant_helper(torch.xlogy, x, y) + out_variant_helper(torch.xlogy, x, z) + inplace_variant_helper(x, x) + inplace_variant_helper(x, y) + inplace_variant_helper(x, z) + + # Scalar-Tensor Test + torch_fn = partial(torch.xlogy, 3.14) + reference_fn = partial(scipy.special.xlogy, 3.14) + + self.compare_with_numpy(torch_fn, reference_fn, x, exact_dtype=False) + self.compare_with_numpy(torch_fn, reference_fn, y, exact_dtype=False) + self.compare_with_numpy(torch_fn, reference_fn, z, exact_dtype=False) + out_variant_helper(torch.xlogy, 3.14, x) + out_variant_helper(torch.xlogy, 3.14, y) + out_variant_helper(torch.xlogy, 3.14, z) + + # Special Values Tensor-Tensor + t = torch.tensor([0., 1., 2., float('inf'), -float('inf'), float('nan')], device=device) + zeros = torch.zeros(6, dtype=y_dtype, device=device) + + torch_fn = partial(torch.xlogy, zeros) + reference_fn = partial(scipy.special.xlogy, zeros.cpu().numpy()) + self.compare_with_numpy(torch_fn, reference_fn, t, exact_dtype=False) + out_variant_helper(torch.xlogy, zeros, t) + inplace_variant_helper(zeros, t) + + # Special Values Scalar-Tensor + torch_fn = partial(torch.xlogy, 0) + reference_fn = partial(scipy.special.xlogy, 0) + self.compare_with_numpy(torch_fn, reference_fn, t, exact_dtype=False) + out_variant_helper(torch.xlogy, 0, t) + + @skipIf(not TEST_SCIPY, "Scipy required for the test.") + def test_xlogy_bfloat16(self, device): + def _compare_helper(x, y): + x_np = x if isinstance(x, float) else x.cpu().to(torch.float).numpy() + y_np = y if isinstance(y, float) else y.cpu().to(torch.float).numpy() + expected = torch.from_numpy(scipy.special.xlogy(x_np, y_np)) + actual = torch.xlogy(x, y) + self.assertEqual(expected, actual, exact_dtype=False) + + x_dtype, y_dtype = torch.bfloat16, torch.bfloat16 + + # Tensor-Tensor Test (tensor of same and different shape) + x = make_tensor((3, 2, 4, 5), device, 
x_dtype, low=0.5, high=1000) + y = make_tensor((3, 2, 4, 5), device, y_dtype, low=0.5, high=1000) + z = make_tensor((4, 5), device, y_dtype, low=0.5, high=1000) + + _compare_helper(x, x) + _compare_helper(x, y) + _compare_helper(x, z) + + _compare_helper(x, 3.14) + _compare_helper(y, 3.14) + _compare_helper(z, 3.14) + + # Special Values Tensor-Tensor + t = torch.tensor([0., 1., 2., float('inf'), -float('inf'), float('nan')], device=device) + zeros = torch.zeros(6, dtype=y_dtype, device=device) + _compare_helper(t, zeros) + _compare_helper(t, 0.) tensor_binary_ops = [ '__lt__', '__le__', diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 7a619b926612..9f68622e7691 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -647,6 +647,16 @@ self: grad / (1 + pow(2, other - self)) other: grad / (1 + pow(2, self - other)) +- name: xlogy.Tensor(Tensor self, Tensor other) -> Tensor + self: grad * at::xlogy((self != 0), other) + other: grad * self / other + +- name: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor + other: grad * self / other + +- name: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor + self: grad * at::xlogy((self != 0), other) + - name: logdet(Tensor self) -> Tensor self: logdet_backward(grad, self, result) diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index f081b595de2f..e9443202785d 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -4472,6 +4472,20 @@ def callable(a, b) -> number Out-of-place version of :meth:`torch.Tensor.masked_scatter_` """) +add_docstr_all('xlogy', + r""" +xlogy(other) -> Tensor + +See :func:`torch.xlogy` +""") + +add_docstr_all('xlogy_', + r""" +xlogy_(other) -> Tensor + +In-place version of :meth:`~Tensor.xlogy` +""") + add_docstr_all('masked_fill', r""" masked_fill(mask, value) -> Tensor diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 91da41bf05d4..029494284f39 100644 --- a/torch/_torch_docs.py +++ 
b/torch/_torch_docs.py @@ -4371,6 +4371,48 @@ def merge_dicts(*dicts): {out} """.format(**common_args)) +add_docstr(torch.xlogy, + r""" +xlogy(input, other, *, out=None) -> Tensor + +Computes ``input * log(other)`` with the following cases. + +.. math:: + \text{out}_{i} = \begin{cases} + \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\ + 0 & \text{if } \text{input}_{i} = 0.0 \\ + \text{input}_{i} * \log{(\text{other}_{i})} & \text{otherwise} + \end{cases} + +Similar to SciPy's `scipy.special.xlogy`. + +""" + r""" + +Args: + input (Number or Tensor): Multiplier + other (Number or Tensor): Argument of the logarithm + +.. note:: At least one of :attr:`input` or :attr:`other` must be a tensor. + +Keyword args: + {out} + +Example:: + + >>> x = torch.zeros(5,) + >>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')]) + >>> torch.xlogy(x, y) + tensor([0., 0., 0., 0., nan]) + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([3, 2, 1]) + >>> torch.xlogy(x, y) + tensor([1.0986, 1.3863, 0.0000]) + >>> torch.xlogy(x, 4) + tensor([1.3863, 2.7726, 4.1589]) + >>> torch.xlogy(2, y) + tensor([2.1972, 1.3863, 0.0000]) +""".format(**common_args)) + add_docstr(torch.logical_and, r""" logical_and(input, other, *, out=None) -> Tensor diff --git a/torch/overrides.py b/torch/overrides.py index c0e34634fd67..d23e34831bdd 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -495,6 +495,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.logaddexp: lambda input, other, out=None: -1, torch.logaddexp2: lambda input, other, out=None: -1, torch.logdet: lambda input: -1, + torch.xlogy: lambda input, other, out=None: -1, torch.logical_and: lambda input, other, out=None: -1, torch.logical_not: lambda input, out=None: -1, torch.logical_or: lambda input, other, out=None: -1, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index ba29c42f39ff..55b97b38a4da 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ 
b/torch/testing/_internal/common_methods_invocations.py @@ -291,12 +291,21 @@ def sample_inputs_addmm(op_info, device, dtype, requires_grad): return (SampleInput((make_tensor((S, S), device, dtype, low=None, high=None, requires_grad=requires_grad), - make_tensor((S, S), device, dtype, - low=None, high=None, - requires_grad=requires_grad), - make_tensor((S, S), device, dtype, - low=None, high=None, - requires_grad=False))),) + make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=False))),) + + +def sample_inputs_xlogy(self, device, dtype, requires_grad): + return (SampleInput((make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad), + make_tensor((S, S), device, dtype, + low=0, high=None, + requires_grad=requires_grad))),) def np_sinc_with_fp16_as_fp32(x): # Wraps numpy's sinc function so that fp16 values are promoted to fp32 @@ -1084,6 +1093,14 @@ def reference_sigmoid(x): dtypes=[torch.bfloat16]),), assert_autodiffed=True, promotes_integers_to_float=True), + OpInfo('xlogy', + dtypes=all_types_and(torch.bool), + dtypesIfCPU=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + test_inplace_grad=True, + supports_tensor_out=True, + promotes_integers_to_float=True, + sample_inputs_func=sample_inputs_xlogy), ] op_db = op_db + op_db_scipy_reference