[numpy] Add torch.xlogy (#48777)
Summary:
Reference #38349
Fixes #22656

TODO:
* [x] Add docs
* [x] Add tests

Pull Request resolved: #48777

Reviewed By: ngimel

Differential Revision: D25681346

Pulled By: mruberry

fbshipit-source-id: 369e0a29ac8a2c44de95eec115bf75943fe1aa45
kshitij12345 authored and facebook-github-bot committed Dec 22, 2020
1 parent be09160 commit 2780400
Showing 15 changed files with 366 additions and 9 deletions.
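
For orientation, a minimal usage sketch of the operator this commit adds (illustrative only, not part of the diff; it assumes a PyTorch build that includes this change). It follows the special-value handling in the CPU/CUDA kernels below: a NaN in y propagates, otherwise a zero in x yields zero, otherwise the result is x * log(y).

import torch

x = torch.tensor([0.0, 1.0, 2.0])
y = torch.tensor([0.0, 3.0, float('nan')])

# xlogy(x, y) computes x * log(y) elementwise; a NaN in y propagates,
# and x == 0 gives 0 even where y == 0.
print(torch.xlogy(x, y))  # expected roughly: tensor([0.0000, 1.0986, nan])
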
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -436,6 +436,7 @@ _(aten, logdet) \
_(aten, logit) \
_(aten, logspace) \
_(aten, logsumexp) \
_(aten, xlogy) \
_(aten, lstm) \
_(aten, lstm_cell) \
_(aten, lstsq) \
38 changes: 38 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -62,6 +62,7 @@ DEFINE_DISPATCH(igammac_stub);
DEFINE_DISPATCH(nextafter_stub);
DEFINE_DISPATCH(heaviside_stub);
DEFINE_DISPATCH(copysign_stub);
DEFINE_DISPATCH(xlogy_stub);

static Tensor wrapped_scalar_tensor(Scalar scalar) {
auto tensor = scalar_to_tensor(scalar);
@@ -1101,5 +1102,42 @@ Tensor& ldexp_(Tensor& self, const Tensor& other) {
return at::ldexp_out(self, self, other);
}

Tensor& xlogy_out(Tensor& result, const Tensor& self, const Tensor& other) {
auto iter = TensorIterator::binary_float_op(result, self, other);
xlogy_stub(iter.device_type(), iter);
return result;
}

Tensor& xlogy_out(Tensor& result, Scalar self, const Tensor& other) {
return at::xlogy_out(result, c10::scalar_to_tensor(self, other.device()), other);
}

Tensor& xlogy_out(Tensor& result, const Tensor& self, Scalar other) {
return at::xlogy_out(result, self, c10::scalar_to_tensor(other, self.device()));
}

Tensor xlogy(const Tensor& x, const Tensor& y) {
Tensor result;
auto iter = TensorIterator::binary_float_op(result, x, y);
xlogy_stub(iter.device_type(), iter);
return iter.output();
}

Tensor xlogy(Scalar x, const Tensor& y) {
return at::xlogy(c10::scalar_to_tensor(x, y.device()), y);
}

Tensor xlogy(const Tensor& x, Scalar y) {
return at::xlogy(x, c10::scalar_to_tensor(y, x.device()));
}

Tensor& xlogy_(Tensor& x, const Tensor& y) {
return at::xlogy_out(x, x, y);
}

Tensor& xlogy_(Tensor& x, Scalar y) {
return at::xlogy_out(x, x, c10::scalar_to_tensor(y, x.device()));
}

} // namespace native
} // namespace at
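
The Scalar overloads above wrap the scalar into a tensor on the other operand's device (via c10::scalar_to_tensor) before dispatching. A small illustrative sketch of the expected equivalence from Python (not part of the diff):

import torch

y = torch.rand(4)
# The Scalar_Self path wraps the scalar onto y's device, so it should
# agree with passing an explicit 0-dim tensor.
a = torch.xlogy(2.0, y)
b = torch.xlogy(torch.tensor(2.0, device=y.device), y)
assert torch.allclose(a, b)
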
1 change: 1 addition & 0 deletions aten/src/ATen/native/BinaryOps.h
@@ -74,5 +74,6 @@ DECLARE_DISPATCH(binary_fn, igammac_stub);
DECLARE_DISPATCH(binary_fn, nextafter_stub);
DECLARE_DISPATCH(binary_fn, heaviside_stub);
DECLARE_DISPATCH(binary_fn, copysign_stub);
DECLARE_DISPATCH(binary_fn, xlogy_stub);

}} // namespace at::native
15 changes: 15 additions & 0 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -818,6 +818,20 @@ void copysign_kernel(TensorIterator& iter) {
});
}

void xlogy_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "xlogy_cpu", [&]() {
cpu_kernel(iter, [](scalar_t x, scalar_t y) -> scalar_t {
if (at::_isnan(y)){
return NAN;
}
if (x == 0){
return 0;
}
return x * std::log(y);
});
});
}

} // namespace

REGISTER_DISPATCH(add_stub, &add_kernel);
@@ -859,6 +873,7 @@ REGISTER_DISPATCH(igammac_stub, &igammac_kernel);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
REGISTER_DISPATCH(copysign_stub, &copysign_kernel);
REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel);

} // namespace native
} // namespace at
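
As an aid to reading the kernel above, here is the same scalar logic as a plain-Python reference (an illustrative sketch only, not part of the diff):

import math

def xlogy_reference(x, y):
    # Mirrors xlogy_kernel: NaN in y propagates, x == 0 short-circuits to 0,
    # otherwise x * log(y).
    if math.isnan(y):
        return float('nan')
    if x == 0:
        return 0.0
    return x * math.log(y)

# e.g. xlogy_reference(0.0, 0.0) == 0.0, whereas math.log(0.0) alone would raise.
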
16 changes: 16 additions & 0 deletions aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -3,6 +3,7 @@
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/NumericUtils.h>

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
@@ -29,8 +30,23 @@ void mse_kernel_cuda(TensorIterator& iter) {
});
}

void xlogy_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "xlogy_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t x, scalar_t y) -> scalar_t {
if (at::_isnan(y)){
return NAN;
}
if (x == 0){
return 0;
}
return x * std::log(y);
});
});
}

REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel_cuda);

// DO NOT ADD ANY NEW KERNELS HERE
// CUDA compilation times grow quickly. It's perfectly acceptable to have a file per kernel.
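
A quick, illustrative way to exercise the CUDA dispatch registered above (assumes a CUDA-enabled build; not part of the diff):

import torch

if torch.cuda.is_available():
    x = torch.rand(8, device='cuda')
    y = torch.rand(8, device='cuda')
    # Routes through xlogy_kernel_cuda via the xlogy_stub dispatch.
    print(torch.xlogy(x, y))
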
50 changes: 50 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -2560,6 +2560,56 @@
dispatch:
DefaultBackend: logaddexp2

- func: xlogy.Tensor(Tensor self, Tensor other) -> Tensor
use_c10_dispatcher: full
variants: function, method
dispatch:
CPU, CUDA: xlogy

- func: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor
use_c10_dispatcher: full
variants: function
dispatch:
CPU, CUDA: xlogy

- func: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor
use_c10_dispatcher: full
variants: function, method
dispatch:
CPU, CUDA: xlogy

# xlogy: inplace variant
- func: xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
use_c10_dispatcher: full
variants: function, method
dispatch:
CPU, CUDA: xlogy_

- func: xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!)
use_c10_dispatcher: full
variants: function, method
dispatch:
CPU, CUDA: xlogy_

# xlogy: out variant
- func: xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
variants: function
dispatch:
CPU, CUDA: xlogy_out

- func: xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
variants: function
dispatch:
CPU, CUDA: xlogy_out

- func: xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
variants: function
dispatch:
CPU, CUDA: xlogy_out

- func: logdet(Tensor self) -> Tensor
use_c10_dispatcher: full
variants: function, method
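
The schema entries above register functional, method, in-place, and out= variants. An illustrative sketch of the corresponding Python call patterns (not part of the diff):

import torch

t = torch.rand(3, 4)
u = torch.rand(3, 4)
out = torch.empty_like(t)

torch.xlogy(t, u)            # xlogy.Tensor
torch.xlogy(2.0, u)          # xlogy.Scalar_Self
torch.xlogy(t, 2.0)          # xlogy.Scalar_Other
torch.xlogy(t, u, out=out)   # xlogy.OutTensor
t.xlogy_(u)                  # xlogy_.Tensor (in-place method)
t.xlogy_(2.0)                # xlogy_.Scalar_Other (in-place method)
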
2 changes: 2 additions & 0 deletions docs/source/tensors.rst
@@ -645,6 +645,8 @@ view of a storage and defines numeric operations on it.
.. automethod:: view
.. automethod:: view_as
.. automethod:: where
.. automethod:: xlogy
.. automethod:: xlogy_
.. automethod:: zero_

.. class:: BoolTensor()
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -350,6 +350,7 @@ Pointwise Ops
tanh
true_divide
trunc
xlogy

Reduction Ops
~~~~~~~~~~~~~~~~~~~~~~
50 changes: 49 additions & 1 deletion test/test_autograd.py
@@ -10,7 +10,7 @@
import warnings
from copy import deepcopy
from collections import OrderedDict
from itertools import product
from itertools import product, permutations
from operator import mul
from functools import reduce
import torch
@@ -7396,6 +7396,54 @@ def test_atleast(self, device):
self._test_atleast(device, torch.atleast_2d)
self._test_atleast(device, torch.atleast_3d)

def test_xlogy(self, device):

def _tensor_tensor_helper(x, y):
gradcheck(lambda x, y: torch.xlogy(x, y), (x, y))
gradgradcheck(lambda x, y: torch.xlogy(x, y), (x, y))

with torch.no_grad():
x = x.clone()
x[torch.rand_like(x) > 0.5] = 0

gradcheck(lambda y: torch.xlogy(x, y), (y))
gradgradcheck(lambda y: torch.xlogy(x, y), (y))

shapes = ((4,), (1, 4), (1, 1, 4), (1, 1, 1, 4))

# For broadcastable shapes and scalars.
for x_shape, y_shape in permutations(shapes, 2):
x = torch.rand(*x_shape, dtype=torch.double, device=device, requires_grad=True)
y = torch.rand(*y_shape, dtype=torch.double, device=device, requires_grad=True)

_tensor_tensor_helper(x, y)
_tensor_tensor_helper(y, x)

gradcheck(lambda y: torch.xlogy(0, y), (y))
gradgradcheck(lambda y: torch.xlogy(0, y), (y))

gradcheck(lambda y: torch.xlogy(2, y), (y))
gradgradcheck(lambda y: torch.xlogy(2, y), (y))
gradcheck(lambda y: torch.xlogy(y, 2), (y))
gradgradcheck(lambda y: torch.xlogy(y, 2), (y))

# Different shape
x = torch.rand(2, 3, 4, 5, dtype=torch.double, device=device, requires_grad=True)
y = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True)
_tensor_tensor_helper(x, y)
_tensor_tensor_helper(y, x)
_tensor_tensor_helper(x, x)
_tensor_tensor_helper(y, y)

# Same shape
x = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True)
y = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True)
_tensor_tensor_helper(x, y)
_tensor_tensor_helper(y, x)
_tensor_tensor_helper(x, x)
_tensor_tensor_helper(y, y)


class TestMultithreadAutograd(TestCase):
def _run_py_multithread_fn(self, fn, args=(), num_threads=10, kwargs=None):
threads = []
105 changes: 103 additions & 2 deletions test/test_binary_ufuncs.py
@@ -8,15 +8,19 @@
import unittest
import warnings
import operator
from functools import partial

from torch._six import inf, nan
from torch.testing._internal.common_utils import (
TestCase, iter_indices, TEST_WITH_ASAN, run_tests,
torch_to_numpy_dtype_dict, make_tensor)
torch_to_numpy_dtype_dict, make_tensor, TEST_SCIPY)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA,
dtypesIfCPU, deviceCountAtLeast, precisionOverride, onlyOnCPUAndCUDA,
skipCUDAIfRocm)
skipCUDAIfRocm, skipIf)

if TEST_SCIPY:
import scipy.special

# TODO: remove this
def _generate_input(shape, dtype, device, with_extremal):
@@ -2488,6 +2492,103 @@ def _promo_helper(x, y):
with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
torch.Tensor.float_power_(base.clone(), exp)

@skipIf(not TEST_SCIPY, "Scipy required for the test.")
@dtypes(*product(torch.testing.get_all_dtypes(include_complex=False, include_bfloat16=False),
torch.testing.get_all_dtypes(include_complex=False, include_bfloat16=False)))
def test_xlogy(self, device, dtypes):
def out_variant_helper(torch_fn, x, y):
expected = torch_fn(x, y)
out = torch.empty_like(expected)
torch_fn(x, y, out=out)
self.assertEqual(expected, out)

def inplace_variant_helper(x, y):
if x.dtype in torch.testing.get_all_int_dtypes() + [torch.bool]:
with self.assertRaisesRegex(RuntimeError,
"can't be cast to the desired output type"):
x.clone().xlogy_(y)
else:
expected = torch.empty_like(x)
torch.xlogy(x, y, out=expected)
inplace_out = x.clone().xlogy_(y)
self.assertEqual(expected, inplace_out)

x_dtype, y_dtype = dtypes

# Tensor-Tensor Test (tensor of same and different shape)
x = make_tensor((3, 2, 4, 5), device, x_dtype, low=0.5, high=1000)
y = make_tensor((3, 2, 4, 5), device, y_dtype, low=0.5, high=1000)
z = make_tensor((4, 5), device, y_dtype, low=0.5, high=1000)

torch_fn = partial(torch.xlogy, x)
reference_fn = partial(scipy.special.xlogy, x.cpu().numpy())

self.compare_with_numpy(torch_fn, reference_fn, x, exact_dtype=False)
self.compare_with_numpy(torch_fn, reference_fn, y, exact_dtype=False)
self.compare_with_numpy(torch_fn, reference_fn, z, exact_dtype=False)
out_variant_helper(torch.xlogy, x, x)
out_variant_helper(torch.xlogy, x, y)
out_variant_helper(torch.xlogy, x, z)
inplace_variant_helper(x, x)
inplace_variant_helper(x, y)
inplace_variant_helper(x, z)

# Scalar-Tensor Test
torch_fn = partial(torch.xlogy, 3.14)
reference_fn = partial(scipy.special.xlogy, 3.14)

self.compare_with_numpy(torch_fn, reference_fn, x, exact_dtype=False)
self.compare_with_numpy(torch_fn, reference_fn, y, exact_dtype=False)
self.compare_with_numpy(torch_fn, reference_fn, z, exact_dtype=False)
out_variant_helper(torch.xlogy, 3.14, x)
out_variant_helper(torch.xlogy, 3.14, y)
out_variant_helper(torch.xlogy, 3.14, z)

# Special Values Tensor-Tensor
t = torch.tensor([0., 1., 2., float('inf'), -float('inf'), float('nan')], device=device)
zeros = torch.zeros(6, dtype=y_dtype, device=device)

torch_fn = partial(torch.xlogy, zeros)
reference_fn = partial(scipy.special.xlogy, zeros.cpu().numpy())
self.compare_with_numpy(torch_fn, reference_fn, t, exact_dtype=False)
out_variant_helper(torch.xlogy, zeros, t)
inplace_variant_helper(zeros, t)

# Special Values Scalar-Tensor
torch_fn = partial(torch.xlogy, 0)
reference_fn = partial(scipy.special.xlogy, 0)
self.compare_with_numpy(torch_fn, reference_fn, t, exact_dtype=False)
out_variant_helper(torch.xlogy, 0, t)

@skipIf(not TEST_SCIPY, "Scipy required for the test.")
def test_xlogy_bfloat16(self, device):
def _compare_helper(x, y):
x_np = x if isinstance(x, float) else x.cpu().to(torch.float).numpy()
y_np = y if isinstance(y, float) else y.cpu().to(torch.float).numpy()
expected = torch.from_numpy(scipy.special.xlogy(x_np, y_np))
actual = torch.xlogy(x, y)
self.assertEqual(expected, actual, exact_dtype=False)

x_dtype, y_dtype = torch.bfloat16, torch.bfloat16

# Tensor-Tensor Test (tensor of same and different shape)
x = make_tensor((3, 2, 4, 5), device, x_dtype, low=0.5, high=1000)
y = make_tensor((3, 2, 4, 5), device, y_dtype, low=0.5, high=1000)
z = make_tensor((4, 5), device, y_dtype, low=0.5, high=1000)

_compare_helper(x, x)
_compare_helper(x, y)
_compare_helper(x, z)

_compare_helper(x, 3.14)
_compare_helper(y, 3.14)
_compare_helper(z, 3.14)

# Special Values Tensor-Tensor
t = torch.tensor([0., 1., 2., float('inf'), -float('inf'), float('nan')], device=device)
zeros = torch.tensor(5, dtype=y_dtype, device=device)
_compare_helper(t, zeros)
_compare_helper(t, 0.)

tensor_binary_ops = [
'__lt__', '__le__',