[numpy] torch.digamma : promote integer inputs to float #48302

Closed
Commits (33 total; the diff below shows changes from 29 commits)
58bcc19
digamma: int -> float promotion
kshitij12345 Nov 20, 2020
1bcced1
update test
kshitij12345 Nov 20, 2020
085cbfe
match negative integer input and zero input against scipy
kshitij12345 Nov 20, 2020
8a2e4bb
undo changes on polygamma
kshitij12345 Nov 20, 2020
65805b8
match scipy behaviour at 0
kshitij12345 Nov 20, 2020
034a989
add test with special values based on SciPy test
kshitij12345 Nov 21, 2020
a90952c
add comments for special value return
kshitij12345 Nov 28, 2020
7aaa3d5
update docs
kshitij12345 Nov 28, 2020
6e54eb0
Merge branch 'master' into develop/numpy/unary-float-op/digamma
kshitij12345 Dec 8, 2020
63d322d
remove MathTestMeta entry for digamma
kshitij12345 Dec 8, 2020
401d443
test digamma to use scipy reference operator
kshitij12345 Dec 8, 2020
630ebad
remove stray space
kshitij12345 Dec 8, 2020
d011561
handle failure with test_variant_consistency_jit
kshitij12345 Dec 8, 2020
18aacb1
remove problematic value
kshitij12345 Dec 8, 2020
b82855d
Merge branch 'master' into develop/numpy/unary-float-op/digamma
kshitij12345 Dec 15, 2020
eb2e176
update Note name and use trunc
kshitij12345 Dec 17, 2020
7684f3b
use iter.common_dtype
kshitij12345 Dec 17, 2020
6ecf8df
remove redundant test
kshitij12345 Dec 17, 2020
644bc74
update the skips
kshitij12345 Dec 17, 2020
ddbcb6c
update docs
kshitij12345 Dec 17, 2020
d84bd1f
update docs related to term non-positive
kshitij12345 Dec 17, 2020
4c44a6d
wording revision
Dec 21, 2020
3e26162
Merge branch 'master' into develop/numpy/unary-float-op/digamma
kshitij12345 Dec 21, 2020
332c48e
remove digamma from tensor_op_tests
kshitij12345 Dec 21, 2020
83e890a
remove skip
kshitij12345 Dec 21, 2020
3119436
Merge branch 'master' into develop/numpy/unary-float-op/digamma
kshitij12345 Dec 22, 2020
a3375c2
add float16 skip
kshitij12345 Dec 22, 2020
86a33a4
fix argument name
kshitij12345 Dec 22, 2020
e16fd9c
make lint happy : int -> bool x_is_integer
kshitij12345 Dec 22, 2020
22fbfc5
Merge branch 'master' into develop/numpy/unary-float-op/digamma
kshitij12345 Dec 23, 2020
d518ed5
Merge branch 'develop/numpy/unary-float-op/digamma' into develop/nump…
kshitij12345 Dec 23, 2020
86bf59b
fix stray merge
kshitij12345 Dec 23, 2020
d3e4dee
Merge branch 'master' into develop/numpy/unary-float-op/digamma
kshitij12345 Dec 24, 2020
Files changed (9)
22 changes: 16 additions & 6 deletions aten/src/ATen/native/Math.h
@@ -277,15 +277,20 @@ static inline float trigamma(float x) {
  * See note [3-Clause BSD License for the Cephes Math Library].
  */
 static inline double calc_digamma(double x) {
+  // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
   static double PSI_10 = 2.25175258906672110764;
   if (x == 0) {
-    return INFINITY;
+    // As per C++ standard for gamma related functions and SciPy,
+    // If the argument is ±0, ±∞ is returned
+    return std::copysign(INFINITY, -x);
   }
 
-  int x_is_integer = x == floor(x);
+  bool x_is_integer = x == trunc(x);
   if (x < 0) {
     if (x_is_integer) {
-      return INFINITY;
+      // As per C++ standard for gamma related functions and SciPy,
+      // If the argument is a negative integer, NaN is returned
+      return NAN;
     }
     return calc_digamma(1 - x) - M_PI / tan(M_PI * x);
   }
@@ -324,15 +329,20 @@ static inline double calc_digamma(double x) {
  * See note [3-Clause BSD License for the Cephes Math Library].
  */
 static inline float calc_digamma(float x) {
+  // See [C++ Standard Reference: Gamma Function]
   static float PSI_10 = 2.25175258906672110764f;
   if (x == 0) {
-    return INFINITY;
+    // As per C++ standard for gamma related functions and SciPy,
+    // If the argument is ±0, ±∞ is returned
+    return std::copysign(INFINITY, -x);
   }
 
-  int x_is_integer = x == floorf(x);
+  bool x_is_integer = x == truncf(x);
   if (x < 0) {
     if (x_is_integer) {
-      return INFINITY;
+      // As per C++ standard for gamma related functions and SciPy,
+      // If the argument is a negative integer, NaN is returned
+      return NAN;
     }
     // Avoid rounding errors for `tan`'s input.
     // Those make a big difference at extreme values.
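The Math.h changes above give digamma SciPy-style special values: an argument of ±0 returns ∓∞ (via copysign) and a negative integer returns NaN. A minimal sketch of the resulting behavior, illustrative only and assuming a PyTorch build that contains this PR plus an installed SciPy:

    import torch
    from scipy import special

    # After this change: digamma(+0) -> -inf, digamma(-0) -> +inf,
    # digamma(negative integer) -> nan, matching scipy.special.digamma.
    x = torch.tensor([0.0, -0.0, -1.0, -2.0], dtype=torch.float64)
    print(torch.digamma(x))            # expected: tensor([-inf, inf, nan, nan], dtype=torch.float64)
    print(special.digamma(x.numpy()))  # expected to show the same special values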
4 changes: 2 additions & 2 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -304,8 +304,8 @@ Tensor& round_out(Tensor& result, const Tensor& self) { return unary_op_impl_out
 Tensor round(const Tensor& self) { return unary_op_impl(self, at::round_out); }
 Tensor& round_(Tensor& self) { return unary_op_impl_(self, at::round_out); }
 
-Tensor& digamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, digamma_stub); }
-Tensor digamma(const Tensor& self) { return unary_op_impl(self, digamma_out); }
+Tensor& digamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, digamma_stub); }
+Tensor digamma(const Tensor& self) { return unary_op_impl_float(self, digamma_stub); }
 Tensor& digamma_(Tensor& self) { return unary_op_impl_(self, digamma_out); }
 
 Tensor& reciprocal_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, reciprocal_stub); }
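Moving digamma and digamma_out onto the unary_op_impl_float helpers is what implements the integer-to-float promotion in the PR title. A short illustrative sketch, assuming the default dtype is torch.float32 and a build containing this PR:

    import torch

    t = torch.arange(1, 4)        # int64 input
    out = torch.digamma(t)        # promoted to the default floating dtype
    print(out.dtype)              # expected: torch.float32
    # digamma(1) = -gamma, digamma(2) = 1 - gamma, digamma(3) = 3/2 - gamma,
    # where gamma is the Euler-Mascheroni constant (~0.5772)
    print(out)                    # expected: roughly tensor([-0.5772, 0.4228, 0.9228])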
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -360,7 +360,7 @@ static void atanh_kernel(TensorIterator& iter) {
 }
 
 static void digamma_kernel(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "digamma", [&]() {
+  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "digamma", [&]() {
     cpu_kernel(
         iter,
         [=](scalar_t a) -> scalar_t { return calc_digamma(a); });
11 changes: 8 additions & 3 deletions aten/src/ATen/native/cuda/Math.cuh
@@ -93,6 +93,7 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) {
  */
 template <typename scalar_t>
 static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) {
+  // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
   static const double PI_f64 = 3.14159265358979323846;
   const accscalar_t PSI_10 = 2.25175258906672110764;
@@ -108,14 +109,18 @@ static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) {
 
   accscalar_t x = static_cast<accscalar_t>(in);
   if (x == 0) {
-    return static_cast<scalar_t>(INFINITY);
+    // As per C++ standard for gamma related functions and SciPy,
+    // If the argument is ±0, ±∞ is returned
+    return std::copysign(static_cast<scalar_t>(INFINITY), -x);
Review discussion on the line above:

  Collaborator: Same comments for CUDA, too.

  Collaborator: Is the static_cast needed for INFINITY here? It doesn't appear in the CPU code.

  Author (kshitij12345): For CPU there are two separate overloads for float and double; for CUDA there is a single templatized function. But I don't think the cast is necessary.

  Collaborator: No worries. Let's just leave it.
   }
 
-  bool x_is_integer = x == ::floor(x);
+  bool x_is_integer = x == ::trunc(x);
   accscalar_t result = 0;
   if (x < 0) {
     if (x_is_integer) {
-      return static_cast<scalar_t>(INFINITY);
+      // As per C++ standard for gamma related functions and SciPy,
+      // If the argument is a negative integer, NaN is returned
+      return static_cast<scalar_t>(NAN);
     }
     // Rounding errors in tan's input can really affect the output
     // for extreme values, so we always perform this computation in double.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/UnaryGammaKernels.cu
@@ -11,7 +11,7 @@
 namespace at { namespace native {
 
 void digamma_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "digamma_cuda", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "digamma_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return calc_digamma(a);
     });
1 change: 0 additions & 1 deletion test/test_torch.py
@@ -6927,7 +6927,6 @@ def inner(self, device, dtype):
     ('trunc', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
     ('ceil', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
     ('lgamma', '', _small_3d, lambda t, d: [], 1e-2, 1e-1, 1e-5, _float_types_no_half, [torch.bfloat16]),
-    ('digamma', 'op', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e0, _float_types_no_half),
 ]
 
 # Creates and decorates a generic test and adds it to the class.
56 changes: 29 additions & 27 deletions test/test_unary_ufuncs.py
@@ -495,6 +495,35 @@ def test_sqrt_complex_edge_values(self, device, dtype):
         x = torch.tensor(-1.0000e+20 - 4988429.2000j, dtype=dtype, device=device)
         self.compare_with_numpy(torch.sqrt, np.sqrt, x)
 
+    @unittest.skipIf(not TEST_SCIPY, "Requires SciPy")
+    @dtypes(torch.float, torch.double)
+    def test_digamma_special(self, device, dtype):
+        # Based on SciPy test for the following special values.
+        # Reference:
+        # https://github.com/scipy/scipy/blob/3a8a3a1d4657254a6611e77e9c28feafa26e6645/scipy/special/tests/test_digamma.py#L22
+        euler = 0.57721566490153286
+        dataset = [(0., -0.),
+                   (1, -euler),
+                   (0.5, -2 * math.log(2) - euler),
+                   (1 / 3, -math.pi / (2 * math.sqrt(3)) - 3 * math.log(3) / 2 - euler),
+                   (1 / 4, -math.pi / 2 - 3 * math.log(2) - euler),
+                   (1 / 6, -math.pi * math.sqrt(3) / 2 - 2 * math.log(2) - 3 * math.log(3) / 2 - euler),
+                   (1 / 8, -math.pi / 2 - 4 * math.log(2) -
+                    (math.pi + math.log(2 + math.sqrt(2)) - math.log(2 - math.sqrt(2))) / math.sqrt(2) - euler)]
+        x = torch.tensor(dataset, device=device, dtype=dtype)
+        self.compare_with_numpy(torch.digamma, scipy.special.digamma, x)
+
+    @unittest.skipIf(not TEST_SCIPY, "Requires SciPy")
+    @dtypes(torch.float, torch.double)
+    def test_digamma(self, device, dtype):
+        # Tests pole behavior
+        # TODO: Add value `-1931.99999994`, to the tensor below when
+        # https://github.com/pytorch/pytorch/issues/49015 is fixed
+        tensor = torch.tensor([-0.999999994, -1.999999994, -2.0000000111,
+                               -100.99999994, 0.000000111,
+                               -0.000000111, 0, -0, -1, -2, -931], dtype=dtype, device=device)
+        self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor)
+
     # TODO opinfo mvlgamma
     @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
     def test_mvlgamma(self, device):
@@ -1138,30 +1167,6 @@ def test_polygamma(self, device, dtype):
             torch.autograd.gradcheck(lambda x: x.polygamma(n),
                                      cpu_tensor)
 
-    # Note: fails when using float tensors
-    # TODO: update this test to just compare against NumPy
-    @onlyCUDA
-    @dtypes(torch.double)
-    def test_digamma(self, device, dtype):
-        cpu_tensor = torch.randn(10, 10, 10, dtype=dtype)
-        device_tensor = cpu_tensor.to(device)
-        zeros = torch.zeros(10, 10, 10, dtype=dtype)
-        cpu_out = cpu_tensor.digamma()
-        device_out = device_tensor.digamma()
-        norm_errors = (device_out - cpu_out.to(device)) / device_out
-        self.assertEqual(norm_errors, zeros)
-
-        # Tests pole behavior
-        cpu_tensor = torch.tensor([-0.999999994, -1.999999994, -2.0000000111,
-                                   -100.99999994, -1931.99999994, 0.000000111,
-                                   -0.000000111, 0, -1, -2, -931], dtype=dtype)
-        expected_errors = torch.tensor([0, 0, 0, 0, 0, 0, 0, nan, nan, nan, nan], dtype=dtype)
-        device_tensor = cpu_tensor.to(device)
-        cpu_out = cpu_tensor.digamma()
-        device_out = device_tensor.digamma()
-        norm_errors = (device_out - cpu_out.to(device)) / device_out
-        self.assertEqual(norm_errors, expected_errors)
-
     # TODO: update to compare against NumPy by rationalizing with OpInfo
     @onlyCUDA
     @dtypes(torch.float, torch.double)
@@ -1743,9 +1748,6 @@ def _medium_2d(dtype, device):
     _TorchMathTestMeta('polygamma', args=[2], substr='_2', reffn='polygamma',
                        refargs=lambda x: (2, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
                        ref_backend='scipy', rtol=0.0008, atol=1e-5),
-    _TorchMathTestMeta('digamma',
-                       input_fn=_generate_gamma_input, inputargs=[True], ref_backend='scipy',
-                       replace_inf_with_nan=True),
     _TorchMathTestMeta('abs', input_fn=_medium_2d, dtypes=_types_no_half, rtol=0., atol=0.),
     _TorchMathTestMeta('logit', ref_backend='scipy')]
 
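The values in test_digamma_special come from closed-form identities for the digamma function, e.g. psi(1) = -gamma and psi(1/2) = -2*ln(2) - gamma, with gamma the Euler-Mascheroni constant. A standalone sanity check of two of these identities against SciPy (illustrative only, not part of the test suite; assumes SciPy is installed):

    import math
    from scipy import special

    euler = 0.57721566490153286  # Euler-Mascheroni constant, as in the test above

    # psi(1) = -gamma and psi(1/2) = -2*ln(2) - gamma
    assert math.isclose(special.digamma(1.0), -euler, rel_tol=1e-12)
    assert math.isclose(special.digamma(0.5), -2 * math.log(2) - euler, rel_tol=1e-12)
    print("closed-form digamma identities agree with scipy.special.digamma")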
8 changes: 6 additions & 2 deletions torch/_torch_docs.py
@@ -2527,8 +2527,7 @@ def merge_dicts(*dicts):
                        [ 1.0500,  0.7336, -0.3836, -1.1015]]])
 """.format(**common_args))
 
-add_docstr(torch.digamma,
-           r"""
+add_docstr(torch.digamma, r"""
 digamma(input, *, out=None) -> Tensor
 
 Computes the logarithmic derivative of the gamma function on `input`.
@@ -2542,6 +2541,11 @@ def merge_dicts(*dicts):
 Keyword args:
     {out}
 
+.. note:: This function is similar to SciPy's `scipy.special.digamma`.
+
+.. note:: From PyTorch 1.8 onwards, the digamma function returns `-Inf` for `0`.
+          Previously it returned `NaN` for `0`.
+
 Example::
 
     >>> a = torch.tensor([1, 0.5])
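The new docstring notes describe the behavior change at zero. A hedged example of what they mean in practice (outputs are what this PR is expected to produce; releases before 1.8 returned NaN at 0):

    import torch

    # PyTorch >= 1.8 (with this PR): digamma(0) is -inf instead of nan
    print(torch.digamma(torch.tensor([0.0])))   # expected: tensor([-inf])

    # The example from the docstring
    a = torch.tensor([1, 0.5])
    print(torch.digamma(a))                     # expected: approximately tensor([-0.5772, -1.9635])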
13 changes: 13 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -1089,6 +1089,19 @@ def reference_sigmoid(x):
                            dtypes=[torch.bfloat16]),),
                    assert_autodiffed=True,
                    promotes_integers_to_float=True),
+    UnaryUfuncInfo('digamma',
+                   ref=scipy.special.digamma,
+                   decorators=(precisionOverride({torch.float16: 5e-1}),),
+                   dtypes=all_types_and(torch.bool),
+                   dtypesIfCPU=all_types_and(torch.bool),
+                   dtypesIfCUDA=all_types_and(torch.bool, torch.half),
+                   skips=(
+                       # In some cases, output is NaN (for input close to
+                       # negative integers) especially due to reduced precision
+                       # in float16 and NaN's can't be tested for equality.
+                       SkipInfo('TestCommon', 'test_variant_consistency_jit',
+                                device_type='cuda', dtypes=[torch.float16]),),
+                   promotes_integers_to_float=True)
 ]
 op_db = op_db + op_db_scipy_reference
 
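The float16 skip in the digamma OpInfo entry exists because half precision can round an input that is merely close to a negative integer onto the pole itself, turning a large finite result into NaN. A small sketch of that effect, illustrative only; the float16 digamma kernel is CUDA-only here, so the last step assumes a CUDA device:

    import torch

    # -100.99999994 is representable in float64 but rounds to exactly -101.0 in float16
    # (the float16 spacing near 101 is 0.0625), and -101 is a pole of digamma.
    x64 = torch.tensor(-100.99999994, dtype=torch.float64)
    x16 = x64.to(torch.float16)
    print(x16.item())           # expected: -101.0, i.e. exactly on the pole
    print(torch.digamma(x64))   # expected: a large finite value (input is just off the pole)
    # On a CUDA device, torch.digamma(x16.cuda()) is expected to be nan,
    # which is why test_variant_consistency_jit is skipped for float16 on CUDA.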