[numpy] Add torch.xlogy #48777

Closed · kshitij12345 wants to merge 30 commits into master from develop/numpy/xlogy

Commits (30)
All 30 commits are by kshitij12345.

5296f17  init xlogy implementation (Dec 3, 2020)
f0125e5  add autograd tests (Dec 3, 2020)
df62075  fix nan hadling logic (Dec 3, 2020)
44ad8b6  add docs (Dec 3, 2020)
c61a4f8  update torch.rst (Dec 3, 2020)
b7baea2  add tests (Dec 3, 2020)
563349c  add method, in-place and out kwarg variants (Dec 5, 2020)
60abfd4  add entry in opinfo db (Dec 5, 2020)
2c662de  update docs as per review (Dec 5, 2020)
fb3fea3  add documentation for method variants (Dec 5, 2020)
9239b6d  fix doc note (Dec 5, 2020)
0d8c3dc  update docs (Dec 5, 2020)
a24c553  Merge branch 'master' into develop/numpy/xlogy (Dec 7, 2020)
30f0c2a  add test for out variant (Dec 7, 2020)
15f8953  add test for bfloat16 (Dec 7, 2020)
9d74bbc  update docs (Dec 7, 2020)
a897d21  fix: missing backslash (Dec 7, 2020)
3d6df14  update docs (Dec 7, 2020)
39e953b  update docs (Dec 7, 2020)
c5b3a08  add test helper for inplace variant (Dec 8, 2020)
62259f1  address comment (Dec 8, 2020)
f5189fb  Merge branch 'master' into develop/numpy/xlogy (Dec 10, 2020)
4b89abc  Merge branch 'master' into develop/numpy/xlogy (Dec 17, 2020)
bb69f52  use common_dtype (Dec 17, 2020)
8ecd50f  update docs (Dec 17, 2020)
3bcc07a  update dispatch backend and c10 dispatcher (Dec 17, 2020)
07bb65c  enable variant_consistency_jit test for bfloat16 (Dec 17, 2020)
c943beb  address review (Dec 21, 2020)
1550d56  fix formatting (Dec 21, 2020)
3e0fac2  Merge branch 'master' into develop/numpy/xlogy (Dec 22, 2020)
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -436,6 +436,7 @@ _(aten, logdet) \
_(aten, logit) \
_(aten, logspace) \
_(aten, logsumexp) \
_(aten, xlogy) \
_(aten, lstm) \
_(aten, lstm_cell) \
_(aten, lstsq) \
38 changes: 38 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -62,6 +62,7 @@ DEFINE_DISPATCH(igammac_stub);
DEFINE_DISPATCH(nextafter_stub);
DEFINE_DISPATCH(heaviside_stub);
DEFINE_DISPATCH(copysign_stub);
DEFINE_DISPATCH(xlogy_stub);

static Tensor wrapped_scalar_tensor(Scalar scalar) {
auto tensor = scalar_to_tensor(scalar);
@@ -1101,5 +1102,42 @@ Tensor& ldexp_(Tensor& self, const Tensor& other) {
return at::ldexp_out(self, self, other);
}

Tensor& xlogy_out(Tensor& result, const Tensor& self, const Tensor& other) {
auto iter = TensorIterator::binary_float_op(result, self, other);
xlogy_stub(iter.device_type(), iter);
return result;
}

Tensor& xlogy_out(Tensor& result, Scalar self, const Tensor& other) {
return at::xlogy_out(result, c10::scalar_to_tensor(self, other.device()), other);
}

Tensor& xlogy_out(Tensor& result, const Tensor& self, Scalar other) {
return at::xlogy_out(result, self, c10::scalar_to_tensor(other, self.device()));
}

Tensor xlogy(const Tensor& x, const Tensor& y) {
Tensor result;
auto iter = TensorIterator::binary_float_op(result, x, y);
xlogy_stub(iter.device_type(), iter);
return iter.output();
}

Tensor xlogy(Scalar x, const Tensor& y) {
return at::xlogy(c10::scalar_to_tensor(x, y.device()), y);
}

Tensor xlogy(const Tensor& x, Scalar y) {
return at::xlogy(x, c10::scalar_to_tensor(y, x.device()));
}

Tensor& xlogy_(Tensor& x, const Tensor& y) {
return at::xlogy_out(x, x, y);
}

Tensor& xlogy_(Tensor& x, Scalar y) {
return at::xlogy_out(x, x, c10::scalar_to_tensor(y, x.device()));
}

} // namespace native
} // namespace at
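As an editor's illustration (not part of the diff): the functional overloads above build a TensorIterator with binary_float_op, so integral inputs are promoted to a floating-point result, and the Scalar overloads wrap the scalar into a tensor on the other operand's device before dispatching. A minimal Python sketch of both behaviours, assuming a build that includes this PR:

import torch

# Integral inputs: binary_float_op promotes the result to a floating dtype.
xi = torch.tensor([0, 1, 2])
yi = torch.tensor([1, 2, 3])
print(torch.xlogy(xi, yi).dtype)  # torch.float32 with the default dtype settings

# Scalar overloads: the scalar behaves like a tensor filled with that value.
y = torch.tensor([0.5, 1.0, 2.0])
same = torch.allclose(torch.xlogy(3.0, y),
                      torch.xlogy(torch.full_like(y, 3.0), y))
print(same)  # True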
1 change: 1 addition & 0 deletions aten/src/ATen/native/BinaryOps.h
@@ -74,5 +74,6 @@ DECLARE_DISPATCH(binary_fn, igammac_stub);
DECLARE_DISPATCH(binary_fn, nextafter_stub);
DECLARE_DISPATCH(binary_fn, heaviside_stub);
DECLARE_DISPATCH(binary_fn, copysign_stub);
DECLARE_DISPATCH(binary_fn, xlogy_stub);

}} // namespace at::native
15 changes: 15 additions & 0 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -818,6 +818,20 @@ void copysign_kernel(TensorIterator& iter) {
});
}

void xlogy_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "xlogy_cpu", [&]() {
cpu_kernel(iter, [](scalar_t x, scalar_t y) -> scalar_t {
if (at::_isnan(y)){
return NAN;
}
if (x == 0){
return 0;
}
return x * std::log(y);
});
});
}

} // namespace

REGISTER_DISPATCH(add_stub, &add_kernel);
@@ -859,6 +873,7 @@ REGISTER_DISPATCH(igammac_stub, &igammac_kernel);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
REGISTER_DISPATCH(copysign_stub, &copysign_kernel);
REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel);

} // namespace native
} // namespace at
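To make the kernel's piecewise definition above easier to read (an editor's sketch, not part of the diff): the result is NaN whenever y is NaN, exactly 0 whenever x is 0 (even where log(y) alone would be -inf or NaN), and x * log(y) otherwise. A rough Python equivalent using torch.where:

import torch

def xlogy_reference(x, y):
    # Same branch order as the kernel: the NaN check on y wins over x == 0.
    out = x * torch.log(y)
    out = torch.where(x == 0, torch.zeros_like(out), out)
    return torch.where(torch.isnan(y), torch.full_like(out, float('nan')), out)

x = torch.tensor([0.0, 0.0, 2.0, 1.0])
y = torch.tensor([0.0, float('nan'), 3.0, 0.0])
print(xlogy_reference(x, y))  # tensor([0.0000, nan, 2.1972, -inf])
print(torch.xlogy(x, y))      # expected to match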
16 changes: 16 additions & 0 deletions aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -3,6 +3,7 @@
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/NumericUtils.h>

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
@@ -29,8 +30,23 @@ void mse_kernel_cuda(TensorIterator& iter) {
});
}

void xlogy_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "xlogy_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t x, scalar_t y) -> scalar_t {
if (at::_isnan(y)){
return NAN;
}
if (x == 0){
return 0;
}
return x * std::log(y);
});
});
}

REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel_cuda);

// DO NOT ADD ANY NEW KERNELS HERE
// CUDA compilation times grow quickly. It's perfectly acceptable to have a file per kernel.
50 changes: 50 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -2546,6 +2546,56 @@
dispatch:
DefaultBackend: logaddexp2

- func: xlogy.Tensor(Tensor self, Tensor other) -> Tensor
use_c10_dispatcher: full
variants: function, method
dispatch:
CPU, CUDA: xlogy

- func: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor
use_c10_dispatcher: full
variants: function
dispatch:
CPU, CUDA: xlogy

- func: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor
use_c10_dispatcher: full
variants: function, method
dispatch:
CPU, CUDA: xlogy

# xlogy: inplace variant
- func: xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
use_c10_dispatcher: full
variants: function, method
dispatch:
CPU, CUDA: xlogy_

- func: xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!)
use_c10_dispatcher: full
variants: function, method
dispatch:
CPU, CUDA: xlogy_

# xlogy: out variant
- func: xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
variants: function
dispatch:
CPU, CUDA: xlogy_out

- func: xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
variants: function
dispatch:
CPU, CUDA: xlogy_out

- func: xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
variants: function
dispatch:
CPU, CUDA: xlogy_out

- func: logdet(Tensor self) -> Tensor
use_c10_dispatcher: full
variants: function, method
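The schemas above register a functional form, a method, in-place variants, out= variants, and Scalar overloads on either side. A short usage sketch touching each registered entry (editor's illustration, values arbitrary):

import torch

x = torch.rand(3)
y = torch.rand(3)
out = torch.empty(3)

torch.xlogy(x, y)             # xlogy.Tensor
torch.xlogy(2.0, y)           # xlogy.Scalar_Self
torch.xlogy(x, 2.0)           # xlogy.Scalar_Other
x.xlogy(y)                    # method variant
x.clone().xlogy_(y)           # xlogy_.Tensor (in-place)
x.clone().xlogy_(2.0)         # xlogy_.Scalar_Other (in-place)
torch.xlogy(x, y, out=out)    # xlogy.OutTensor
torch.xlogy(2.0, y, out=out)  # xlogy.OutScalar_Self
torch.xlogy(x, 2.0, out=out)  # xlogy.OutScalar_Other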
2 changes: 2 additions & 0 deletions docs/source/tensors.rst
@@ -645,6 +645,8 @@ view of a storage and defines numeric operations on it.
.. automethod:: view
.. automethod:: view_as
.. automethod:: where
.. automethod:: xlogy
.. automethod:: xlogy_
.. automethod:: zero_

.. class:: BoolTensor()
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -350,6 +350,7 @@ Pointwise Ops
tanh
true_divide
trunc
xlogy

Reduction Ops
~~~~~~~~~~~~~~~~~~~~~~
50 changes: 49 additions & 1 deletion test/test_autograd.py
@@ -10,7 +10,7 @@
import warnings
from copy import deepcopy
from collections import OrderedDict
from itertools import product
from itertools import product, permutations
from operator import mul
from functools import reduce
import torch
@@ -7315,6 +7315,54 @@ def test_atleast(self, device):
self._test_atleast(device, torch.atleast_2d)
self._test_atleast(device, torch.atleast_3d)

def test_xlogy(self, device):
Review comment (Collaborator): grad and gradgrad checks are performed for every OpInfo in test_ops.py. Are those tests not sufficient for some reason? Is it because of xlogy's unique behavior at (0, 0)?

Review comment (Collaborator Author, kshitij12345): Yes. This is to particularly verify the behaviour at zero.

Review comment (Collaborator): This is really interesting, cc @albanD. I wonder if we can better support this use case in the future by extending OpInfos with metadata like "special_autograd_values." The OpInfo test may also generate zeros, although it's not very likely. If it did, would those cause the regular autograd check to fail? I'm trying to understand what the OpInfo-based autograd tests can cover vs. what this extra test needs to validate.

def _tensor_tensor_helper(x, y):
gradcheck(lambda x, y: torch.xlogy(x, y), (x, y))
gradgradcheck(lambda x, y: torch.xlogy(x, y), (x, y))

with torch.no_grad():
x = x.clone()
x[torch.rand_like(x) > 0.5] = 0

gradcheck(lambda y: torch.xlogy(x, y), (y))
gradgradcheck(lambda y: torch.xlogy(x, y), (y))

shapes = ((4,), (1, 4), (1, 1, 4), (1, 1, 1, 4))

# For broadcastable shapes and scalars.
for x_shape, y_shape in permutations(shapes, 2):
x = torch.rand(*x_shape, dtype=torch.double, device=device, requires_grad=True)
y = torch.rand(*y_shape, dtype=torch.double, device=device, requires_grad=True)

_tensor_tensor_helper(x, y)
_tensor_tensor_helper(y, x)

gradcheck(lambda y: torch.xlogy(0, y), (y))
gradgradcheck(lambda y: torch.xlogy(0, y), (y))

gradcheck(lambda y: torch.xlogy(2, y), (y))
gradgradcheck(lambda y: torch.xlogy(2, y), (y))
gradcheck(lambda y: torch.xlogy(y, 2), (y))
gradgradcheck(lambda y: torch.xlogy(y, 2), (y))

# Different shape
x = torch.rand(2, 3, 4, 5, dtype=torch.double, device=device, requires_grad=True)
y = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True)
_tensor_tensor_helper(x, y)
_tensor_tensor_helper(y, x)
_tensor_tensor_helper(x, x)
_tensor_tensor_helper(y, y)

# Same shape
x = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True)
y = torch.rand(4, 5, dtype=torch.double, device=device, requires_grad=True)
_tensor_tensor_helper(x, y)
_tensor_tensor_helper(y, x)
_tensor_tensor_helper(x, x)
_tensor_tensor_helper(y, y)


class TestMultithreadAutograd(TestCase):
def _run_py_multithread_fn(self, fn, args=(), num_threads=10, kwargs=None):
threads = []
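The review discussion in this file's diff asks what the extra test adds beyond the OpInfo-based grad checks; the answer is the exact-zero entries of x, where xlogy is defined to return 0 regardless of y, so the gradient with respect to y (x / y) is 0 rather than undefined. A minimal standalone gradcheck sketch of that case (editor's illustration, not part of the test suite):

import torch
from torch.autograd import gradcheck, gradgradcheck

# x contains exact zeros; only y requires grad, mirroring the helper above.
x = torch.tensor([0.0, 0.0, 1.5], dtype=torch.double)
y = torch.tensor([0.7, 2.0, 3.0], dtype=torch.double, requires_grad=True)

assert gradcheck(lambda y: torch.xlogy(x, y), (y,))
assert gradgradcheck(lambda y: torch.xlogy(x, y), (y,))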
105 changes: 103 additions & 2 deletions test/test_binary_ufuncs.py
@@ -8,15 +8,19 @@
import unittest
import warnings
import operator
from functools import partial

from torch._six import inf, nan
from torch.testing._internal.common_utils import (
TestCase, iter_indices, TEST_WITH_ASAN, run_tests,
torch_to_numpy_dtype_dict, make_tensor)
torch_to_numpy_dtype_dict, make_tensor, TEST_SCIPY)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA,
dtypesIfCPU, deviceCountAtLeast, precisionOverride, onlyOnCPUAndCUDA,
skipCUDAIfRocm)
skipCUDAIfRocm, skipIf)

if TEST_SCIPY:
import scipy.special

# TODO: remove this
def _generate_input(shape, dtype, device, with_extremal):
@@ -2487,6 +2491,103 @@ def _promo_helper(x, y):
with self.assertRaisesRegex(RuntimeError, "is not the desired type"):
torch.Tensor.float_power_(base.clone(), exp)

@skipIf(not TEST_SCIPY, "Scipy required for the test.")
@dtypes(*product(torch.testing.get_all_dtypes(include_complex=False, include_bfloat16=False),
torch.testing.get_all_dtypes(include_complex=False, include_bfloat16=False)))
def test_xlogy(self, device, dtypes):
def out_variant_helper(torch_fn, x, y):
expected = torch_fn(x, y)
out = torch.empty_like(expected)
torch_fn(x, y, out=out)
self.assertEqual(expected, out)

def inplace_variant_helper(x, y):
if x.dtype in torch.testing.get_all_int_dtypes() + [torch.bool]:
with self.assertRaisesRegex(RuntimeError,
"can't be cast to the desired output type"):
x.clone().xlogy_(y)
else:
expected = torch.empty_like(x)
torch.xlogy(x, y, out=expected)
inplace_out = x.clone().xlogy_(y)
self.assertEqual(expected, inplace_out)

x_dtype, y_dtype = dtypes

# Tensor-Tensor Test (tensor of same and different shape)
x = make_tensor((3, 2, 4, 5), device, x_dtype, low=0.5, high=1000)
y = make_tensor((3, 2, 4, 5), device, y_dtype, low=0.5, high=1000)
z = make_tensor((4, 5), device, y_dtype, low=0.5, high=1000)

torch_fn = partial(torch.xlogy, x)
reference_fn = partial(scipy.special.xlogy, x.cpu().numpy())

self.compare_with_numpy(torch_fn, reference_fn, x, exact_dtype=False)
self.compare_with_numpy(torch_fn, reference_fn, y, exact_dtype=False)
self.compare_with_numpy(torch_fn, reference_fn, z, exact_dtype=False)
out_variant_helper(torch.xlogy, x, x)
out_variant_helper(torch.xlogy, x, y)
out_variant_helper(torch.xlogy, x, z)
inplace_variant_helper(x, x)
inplace_variant_helper(x, y)
inplace_variant_helper(x, z)

# Scalar-Tensor Test
torch_fn = partial(torch.xlogy, 3.14)
reference_fn = partial(scipy.special.xlogy, 3.14)

self.compare_with_numpy(torch_fn, reference_fn, x, exact_dtype=False)
self.compare_with_numpy(torch_fn, reference_fn, y, exact_dtype=False)
self.compare_with_numpy(torch_fn, reference_fn, z, exact_dtype=False)
out_variant_helper(torch.xlogy, 3.14, x)
out_variant_helper(torch.xlogy, 3.14, y)
out_variant_helper(torch.xlogy, 3.14, z)

# Special Values Tensor-Tensor
t = torch.tensor([0., 1., 2., float('inf'), -float('inf'), float('nan')], device=device)
zeros = torch.zeros(6, dtype=y_dtype, device=device)

torch_fn = partial(torch.xlogy, zeros)
reference_fn = partial(scipy.special.xlogy, zeros.cpu().numpy())
self.compare_with_numpy(torch_fn, reference_fn, t, exact_dtype=False)
out_variant_helper(torch.xlogy, zeros, t)
inplace_variant_helper(zeros, t)

# Special Values Scalar-Tensor
torch_fn = partial(torch.xlogy, 0)
reference_fn = partial(scipy.special.xlogy, 0)
self.compare_with_numpy(torch_fn, reference_fn, t, exact_dtype=False)
out_variant_helper(torch.xlogy, 0, t)

@skipIf(not TEST_SCIPY, "Scipy required for the test.")
def test_xlogy_bfloat16(self, device):
def _compare_helper(x, y):
x_np = x if isinstance(x, float) else x.cpu().to(torch.float).numpy()
y_np = y if isinstance(y, float) else y.cpu().to(torch.float).numpy()
expected = torch.from_numpy(scipy.special.xlogy(x_np, y_np))
actual = torch.xlogy(x, y)
self.assertEqual(expected, actual, exact_dtype=False)

x_dtype, y_dtype = torch.bfloat16, torch.bfloat16

# Tensor-Tensor Test (tensor of same and different shape)
x = make_tensor((3, 2, 4, 5), device, x_dtype, low=0.5, high=1000)
y = make_tensor((3, 2, 4, 5), device, y_dtype, low=0.5, high=1000)
z = make_tensor((4, 5), device, y_dtype, low=0.5, high=1000)

_compare_helper(x, x)
_compare_helper(x, y)
_compare_helper(x, z)

_compare_helper(x, 3.14)
_compare_helper(y, 3.14)
_compare_helper(z, 3.14)

# Special Values Tensor-Tensor
t = torch.tensor([0., 1., 2., float('inf'), -float('inf'), float('nan')], device=device)
zeros = torch.tensor(5, dtype=y_dtype, device=device)
_compare_helper(t, zeros)
_compare_helper(t, 0.)

tensor_binary_ops = [
'__lt__', '__le__',
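The tests above use scipy.special.xlogy as the reference because it follows the same convention for x == 0. A small standalone check of the special-value row used in the test, assuming SciPy is installed (editor's illustration):

import numpy as np
import scipy.special
import torch

t = torch.tensor([0., 1., 2., float('inf'), -float('inf'), float('nan')])
zeros = torch.zeros(6)

expected = scipy.special.xlogy(zeros.numpy(), t.numpy())
actual = torch.xlogy(zeros, t)
# Both produce 0 wherever x == 0 and y is not NaN, and NaN where y is NaN.
print(np.allclose(expected, actual.numpy(), equal_nan=True))  # True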