[numpy] torch.digamma : promote integer inputs to float (#48302)
Summary:
**BC-breaking Note:**

This PR updates PyTorch's digamma function to be consistent with SciPy's special.digamma function. It changes the result of digamma at the nonpositive integers, where the gamma function is not defined: since the gamma function is undefined there, the (typical) derivative of the logarithm of the gamma function is also undefined, and this PR updates digamma to return NaN for negative integers. For zero, however, it returns -inf to be consistent with SciPy.

Interestingly, SciPy made a similar change, which was noticed by at least one user: scipy/scipy#9663 (comment).

SciPy intentionally returns negative infinity at zero:
https://github.com/scipy/scipy/blob/59347ae8b86bcc92c339efe213128f64ab6df98c/scipy/special/cephes/psi.c#L163

This change is consistent with the C++ standard for the gamma function:
https://en.cppreference.com/w/cpp/numeric/math/tgamma
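
For illustration, a minimal sketch of the behavior described above (the comments state the expected results of this change, not captured output; a build that includes this PR is assumed):

```python
import torch

x = torch.tensor([0.0, -0.0, -1.0, -2.0, 1.0])
print(torch.digamma(x))
# Expected with this change: digamma(+0) -> -inf, digamma(-0) -> +inf
# (the sign follows std::copysign(INFINITY, -x)), negative integers -> nan,
# and digamma(1) -> -0.5772... (minus the Euler-Mascheroni constant).
```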

**PR Summary:**
Reference #42515

Pull Request resolved: #48302

Reviewed By: ngimel

Differential Revision: D25664087

Pulled By: mruberry

fbshipit-source-id: 1168e81e218bf9fe5b849db0e07e7b22e590cf73
kshitij12345 authored and facebook-github-bot committed Dec 25, 2020
1 parent 46cf6d3 commit 963f762
Showing 9 changed files with 76 additions and 43 deletions.
22 changes: 16 additions & 6 deletions aten/src/ATen/native/Math.h
@@ -277,15 +277,20 @@ static inline float trigamma(float x) {
  * See note [3-Clause BSD License for the Cephes Math Library].
  */
 static inline double calc_digamma(double x) {
+  // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
   static double PSI_10 = 2.25175258906672110764;
   if (x == 0) {
-    return INFINITY;
+    // As per C++ standard for gamma related functions and SciPy,
+    // If the argument is ±0, ±∞ is returned
+    return std::copysign(INFINITY, -x);
   }

-  int x_is_integer = x == floor(x);
+  bool x_is_integer = x == trunc(x);
   if (x < 0) {
     if (x_is_integer) {
-      return INFINITY;
+      // As per C++ standard for gamma related functions and SciPy,
+      // If the argument is a negative integer, NaN is returned
+      return NAN;
     }
     return calc_digamma(1 - x) - M_PI / tan(M_PI * x);
   }
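
As background (not part of the diff): the negative, non-integer branch above relies on the standard digamma reflection formula, and the explicit integer check keeps the poles, where tan(πx) vanishes and the reflection term is undefined, out of this branch:

$$\psi(x) = \psi(1 - x) - \pi\cot(\pi x) = \psi(1 - x) - \frac{\pi}{\tan(\pi x)}$$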
@@ -324,15 +329,20 @@ static inline double calc_digamma(double x) {
  * See note [3-Clause BSD License for the Cephes Math Library].
  */
 static inline float calc_digamma(float x) {
+  // See [C++ Standard Reference: Gamma Function]
   static float PSI_10 = 2.25175258906672110764f;
   if (x == 0) {
-    return INFINITY;
+    // As per C++ standard for gamma related functions and SciPy,
+    // If the argument is ±0, ±∞ is returned
+    return std::copysign(INFINITY, -x);
   }

-  int x_is_integer = x == floorf(x);
+  bool x_is_integer = x == truncf(x);
   if (x < 0) {
     if (x_is_integer) {
-      return INFINITY;
+      // As per C++ standard for gamma related functions and SciPy,
+      // If the argument is a negative integer, NaN is returned
+      return NAN;
     }
     // Avoid rounding errors for `tan`'s input.
     // Those make a big difference at extreme values.
4 changes: 2 additions & 2 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -318,8 +318,8 @@ Tensor& round_out(Tensor& result, const Tensor& self) { return unary_op_impl_out
 Tensor round(const Tensor& self) { return unary_op_impl(self, at::round_out); }
 Tensor& round_(Tensor& self) { return unary_op_impl_(self, at::round_out); }

-Tensor& digamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, digamma_stub); }
-Tensor digamma(const Tensor& self) { return unary_op_impl(self, digamma_out); }
+Tensor& digamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, digamma_stub); }
+Tensor digamma(const Tensor& self) { return unary_op_impl_float(self, digamma_stub); }
 Tensor& digamma_(Tensor& self) { return unary_op_impl_(self, digamma_out); }

 Tensor& reciprocal_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, reciprocal_stub); }
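
The switch from `unary_op_impl_out`/`unary_op_impl` to the `unary_op_impl_float*` helpers above is what produces the integer-to-float promotion named in the title. A minimal sketch of the user-visible effect, assuming the default dtype is `torch.float32`:

```python
import torch

t = torch.arange(1, 5)       # int64 input
y = torch.digamma(t)
print(y.dtype)               # expected: torch.float32 (the default floating point dtype)
print(torch.digamma(torch.tensor(2.0)).dtype)  # floating point inputs keep their dtype
```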
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -360,7 +360,7 @@ static void atanh_kernel(TensorIterator& iter) {
 }

 static void digamma_kernel(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "digamma", [&]() {
+  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "digamma", [&]() {
     cpu_kernel(
         iter,
         [=](scalar_t a) -> scalar_t { return calc_digamma(a); });
11 changes: 8 additions & 3 deletions aten/src/ATen/native/cuda/Math.cuh
@@ -93,6 +93,7 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) {
  */
 template <typename scalar_t>
 static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) {
+  // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
   static const double PI_f64 = 3.14159265358979323846;
   const accscalar_t PSI_10 = 2.25175258906672110764;
@@ -108,14 +109,18 @@ static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) {

   accscalar_t x = static_cast<accscalar_t>(in);
   if (x == 0) {
-    return static_cast<scalar_t>(INFINITY);
+    // As per C++ standard for gamma related functions and SciPy,
+    // If the argument is ±0, ±∞ is returned
+    return std::copysign(static_cast<scalar_t>(INFINITY), -x);
   }

-  bool x_is_integer = x == ::floor(x);
+  bool x_is_integer = x == ::trunc(x);
   accscalar_t result = 0;
   if (x < 0) {
     if (x_is_integer) {
-      return static_cast<scalar_t>(INFINITY);
+      // As per C++ standard for gamma related functions and SciPy,
+      // If the argument is a negative integer, NaN is returned
+      return static_cast<scalar_t>(NAN);
     }
     // Rounding errors in tan's input can really affect the output
     // for extreme values, so we always perform this computation in double.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/UnaryGammaKernels.cu
@@ -11,7 +11,7 @@
 namespace at { namespace native {

 void digamma_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "digamma_cuda", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "digamma_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return calc_digamma(a);
     });
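
Since the CUDA kernel above dispatches on the common dtype and includes half precision, `torch.float16` inputs go through the same path. A minimal sketch, assuming a CUDA device is available:

```python
import torch

if torch.cuda.is_available():
    x = torch.tensor([0.5, 1.0, 2.5], dtype=torch.float16, device="cuda")
    print(torch.digamma(x))  # computed by the half-precision kernel dispatched above
```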
1 change: 0 additions & 1 deletion test/test_torch.py
@@ -6925,7 +6925,6 @@ def inner(self, device, dtype):
     ('trunc', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
     ('ceil', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
     ('lgamma', '', _small_3d, lambda t, d: [], 1e-2, 1e-1, 1e-5, _float_types_no_half, [torch.bfloat16]),
-    ('digamma', 'op', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e0, _float_types_no_half),
 ]

 # Creates and decorates a generic test and adds it to the class.
56 changes: 29 additions & 27 deletions test/test_unary_ufuncs.py
@@ -497,6 +497,35 @@ def test_sqrt_complex_edge_values(self, device, dtype):
         x = torch.tensor(-1.0000e+20 - 4988429.2000j, dtype=dtype, device=device)
         self.compare_with_numpy(torch.sqrt, np.sqrt, x)

+    @unittest.skipIf(not TEST_SCIPY, "Requires SciPy")
+    @dtypes(torch.float, torch.double)
+    def test_digamma_special(self, device, dtype):
+        # Based on SciPy test for the following special values.
+        # Reference:
+        # https://github.com/scipy/scipy/blob/3a8a3a1d4657254a6611e77e9c28feafa26e6645/scipy/special/tests/test_digamma.py#L22
+        euler = 0.57721566490153286
+        dataset = [(0., -0.),
+                   (1, -euler),
+                   (0.5, -2 * math.log(2) - euler),
+                   (1 / 3, -math.pi / (2 * math.sqrt(3)) - 3 * math.log(3) / 2 - euler),
+                   (1 / 4, -math.pi / 2 - 3 * math.log(2) - euler),
+                   (1 / 6, -math.pi * math.sqrt(3) / 2 - 2 * math.log(2) - 3 * math.log(3) / 2 - euler),
+                   (1 / 8, -math.pi / 2 - 4 * math.log(2) -
+                    (math.pi + math.log(2 + math.sqrt(2)) - math.log(2 - math.sqrt(2))) / math.sqrt(2) - euler)]
+        x = torch.tensor(dataset, device=device, dtype=dtype)
+        self.compare_with_numpy(torch.digamma, scipy.special.digamma, x)
+
+    @unittest.skipIf(not TEST_SCIPY, "Requires SciPy")
+    @dtypes(torch.float, torch.double)
+    def test_digamma(self, device, dtype):
+        # Tests pole behavior
+        # TODO: Add value `-1931.99999994`, to the tensor below when
+        # https://github.com/pytorch/pytorch/issues/49015 is fixed
+        tensor = torch.tensor([-0.999999994, -1.999999994, -2.0000000111,
+                               -100.99999994, 0.000000111,
+                               -0.000000111, 0, -0, -1, -2, -931], dtype=dtype, device=device)
+        self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor)
+
     # TODO opinfo mvlgamma
     @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
     def test_mvlgamma(self, device):
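
As background on why inputs such as `-0.999999994` in the new `test_digamma` above probe pole behavior: digamma has simple poles with residue -1 at the nonpositive integers, so near a pole

$$\psi(x) = -\frac{1}{x + n} + O(1) \qquad (x \to -n,\ n = 0, 1, 2, \dots),$$

and values just off a pole are very large and extremely sensitive to the precision of the input.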
@@ -1120,30 +1149,6 @@ def test_polygamma(self, device, dtype):
                 torch.autograd.gradcheck(lambda x: x.polygamma(n),
                                          cpu_tensor)

-    # Note: fails when using float tensors
-    # TODO: update this test to just compare against NumPy
-    @onlyCUDA
-    @dtypes(torch.double)
-    def test_digamma(self, device, dtype):
-        cpu_tensor = torch.randn(10, 10, 10, dtype=dtype)
-        device_tensor = cpu_tensor.to(device)
-        zeros = torch.zeros(10, 10, 10, dtype=dtype)
-        cpu_out = cpu_tensor.digamma()
-        device_out = device_tensor.digamma()
-        norm_errors = (device_out - cpu_out.to(device)) / device_out
-        self.assertEqual(norm_errors, zeros)
-
-        # Tests pole behavior
-        cpu_tensor = torch.tensor([-0.999999994, -1.999999994, -2.0000000111,
-                                   -100.99999994, -1931.99999994, 0.000000111,
-                                   -0.000000111, 0, -1, -2, -931], dtype=dtype)
-        expected_errors = torch.tensor([0, 0, 0, 0, 0, 0, 0, nan, nan, nan, nan], dtype=dtype)
-        device_tensor = cpu_tensor.to(device)
-        cpu_out = cpu_tensor.digamma()
-        device_out = device_tensor.digamma()
-        norm_errors = (device_out - cpu_out.to(device)) / device_out
-        self.assertEqual(norm_errors, expected_errors)
-
     # TODO: update to compare against NumPy by rationalizing with OpInfo
     @onlyCUDA
     @dtypes(torch.float, torch.double)
@@ -1725,9 +1730,6 @@ def _medium_2d(dtype, device):
     _TorchMathTestMeta('polygamma', args=[2], substr='_2', reffn='polygamma',
                        refargs=lambda x: (2, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
                        ref_backend='scipy', rtol=0.0008, atol=1e-5),
-    _TorchMathTestMeta('digamma',
-                       input_fn=_generate_gamma_input, inputargs=[True], ref_backend='scipy',
-                       replace_inf_with_nan=True),
     _TorchMathTestMeta('abs', input_fn=_medium_2d, dtypes=_types_no_half, rtol=0., atol=0.),
     _TorchMathTestMeta('logit', ref_backend='scipy')]

8 changes: 6 additions & 2 deletions torch/_torch_docs.py
@@ -2532,8 +2532,7 @@ def merge_dicts(*dicts):
         [ 1.0500, 0.7336, -0.3836, -1.1015]]])
 """.format(**common_args))

-add_docstr(torch.digamma,
-           r"""
+add_docstr(torch.digamma, r"""
 digamma(input, *, out=None) -> Tensor
 Computes the logarithmic derivative of the gamma function on `input`.
@@ -2547,6 +2546,11 @@ def merge_dicts(*dicts):
 Keyword args:
     {out}

+.. note:: This function is similar to SciPy's `scipy.special.digamma`.
+
+.. note:: From PyTorch 1.8 onwards, the digamma function returns `-Inf` for `0`.
+          Previously it returned `NaN` for `0`.

 Example::
     >>> a = torch.tensor([1, 0.5])
13 changes: 13 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -1087,6 +1087,19 @@ def reference_sigmoid(x):
                      promotes_integers_to_float=True,
                      assert_autodiffed=True,
                      test_complex_grad=False),  # Reference: https://github.com/pytorch/pytorch/issues/48552
+    UnaryUfuncInfo('digamma',
+                   ref=scipy.special.digamma,
+                   decorators=(precisionOverride({torch.float16: 5e-1}),),
+                   dtypes=all_types_and(torch.bool),
+                   dtypesIfCPU=all_types_and(torch.bool),
+                   dtypesIfCUDA=all_types_and(torch.bool, torch.half),
+                   skips=(
+                       # In some cases, output is NaN (for input close to
+                       # negative integers) especially due to reduced precision
+                       # in float16 and NaN's can't be tested for equality.
+                       SkipInfo('TestCommon', 'test_variant_consistency_jit',
+                                device_type='cuda', dtypes=[torch.float16]),),
+                   promotes_integers_to_float=True),
     UnaryUfuncInfo('erf',
                    ref=scipy.special.erf,
                    decorators=(precisionOverride({torch.float16: 1e-2,
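
Taken together with the new `UnaryUfuncInfo('digamma', ...)` entry above, boolean and integer inputs are now expected to be accepted and promoted to floating point. A hedged end-to-end sketch of that dtype coverage (promotion behavior only; tolerances and device handling are left to the OpInfo machinery):

```python
import torch

for t in (torch.tensor([True, False]),             # bool
          torch.arange(1, 4, dtype=torch.int32),   # integer
          torch.tensor([0.5, 1.0, 1.5])):          # floating point
    print(t.dtype, "->", torch.digamma(t).dtype)
# Expected: bool and integer inputs promoted to the default floating point dtype;
# floating point inputs keep their dtype.
```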
