diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h
index 04c6925933a3..6cd0464de921 100644
--- a/aten/src/ATen/native/Math.h
+++ b/aten/src/ATen/native/Math.h
@@ -277,15 +277,20 @@ static inline float trigamma(float x) {
  * See note [3-Clause BSD License for the Cephes Math Library].
  */
 static inline double calc_digamma(double x) {
+  // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
   static double PSI_10 = 2.25175258906672110764;
   if (x == 0) {
-    return INFINITY;
+    // As per C++ standard for gamma related functions and SciPy,
+    // If the argument is ±0, ±∞ is returned
+    return std::copysign(INFINITY, -x);
   }
 
-  int x_is_integer = x == floor(x);
+  bool x_is_integer = x == trunc(x);
   if (x < 0) {
     if (x_is_integer) {
-      return INFINITY;
+      // As per C++ standard for gamma related functions and SciPy,
+      // If the argument is a negative integer, NaN is returned
+      return NAN;
     }
     return calc_digamma(1 - x) - M_PI / tan(M_PI * x);
   }
@@ -324,15 +329,20 @@ static inline double calc_digamma(double x) {
  * See note [3-Clause BSD License for the Cephes Math Library].
  */
 static inline float calc_digamma(float x) {
+  // See [C++ Standard Reference: Gamma Function]
   static float PSI_10 = 2.25175258906672110764f;
   if (x == 0) {
-    return INFINITY;
+    // As per C++ standard for gamma related functions and SciPy,
+    // If the argument is ±0, ±∞ is returned
+    return std::copysign(INFINITY, -x);
   }
 
-  int x_is_integer = x == floorf(x);
+  bool x_is_integer = x == truncf(x);
   if (x < 0) {
     if (x_is_integer) {
-      return INFINITY;
+      // As per C++ standard for gamma related functions and SciPy,
+      // If the argument is a negative integer, NaN is returned
+      return NAN;
     }
     // Avoid rounding errors for `tan`'s input.
     // Those make a big difference at extreme values.
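
[Illustrative sketch, not part of the patch: how the new pole convention in calc_digamma
shows up from Python. Assumes a build containing this change; scipy.special.digamma is
used only as the reference, mirroring the comments in the hunks above.]

    import torch
    from scipy import special

    # std::copysign(INFINITY, -x) lets the sign of the zero argument pick the sign
    # of the returned infinity: digamma(+0) -> -inf, digamma(-0) -> +inf.
    print(torch.digamma(torch.tensor(0.0)))    # tensor(-inf)
    print(torch.digamma(torch.tensor(-0.0)))   # tensor(inf)

    # Negative integers are poles of the digamma function, so NaN is returned,
    # matching SciPy.
    print(torch.digamma(torch.tensor(-2.0)))   # tensor(nan)
    print(special.digamma(-2.0))               # nan
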
diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index 4eb1f393e47c..e6dd1bc4afde 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -318,8 +318,8 @@ Tensor& round_out(Tensor& result, const Tensor& self) { return unary_op_impl_out
 Tensor round(const Tensor& self) { return unary_op_impl(self, at::round_out); }
 Tensor& round_(Tensor& self) { return unary_op_impl_(self, at::round_out); }
 
-Tensor& digamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, digamma_stub); }
-Tensor digamma(const Tensor& self) { return unary_op_impl(self, digamma_out); }
+Tensor& digamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, digamma_stub); }
+Tensor digamma(const Tensor& self) { return unary_op_impl_float(self, digamma_stub); }
 Tensor& digamma_(Tensor& self) { return unary_op_impl_(self, digamma_out); }
 
 Tensor& reciprocal_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, reciprocal_stub); }
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index 42a761439ac0..049b3eff6b5b 100644
--- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -360,7 +360,7 @@ static void atanh_kernel(TensorIterator& iter) {
 }
 
 static void digamma_kernel(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "digamma", [&]() {
+  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "digamma", [&]() {
     cpu_kernel(
         iter,
         [=](scalar_t a) -> scalar_t { return calc_digamma(a); });
diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh
index 1daba76c9446..17c30cd00ea7 100644
--- a/aten/src/ATen/native/cuda/Math.cuh
+++ b/aten/src/ATen/native/cuda/Math.cuh
@@ -93,6 +93,7 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) {
  */
 template <typename scalar_t>
 static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) {
+  // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
   static const double PI_f64 = 3.14159265358979323846;
   const accscalar_t PSI_10 = 2.25175258906672110764;
@@ -108,14 +109,18 @@ static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) {
 
   accscalar_t x = static_cast<accscalar_t>(in);
   if (x == 0) {
-    return static_cast<scalar_t>(INFINITY);
+    // As per C++ standard for gamma related functions and SciPy,
+    // If the argument is ±0, ±∞ is returned
+    return std::copysign(static_cast<scalar_t>(INFINITY), -x);
   }
 
-  bool x_is_integer = x == ::floor(x);
+  bool x_is_integer = x == ::trunc(x);
   accscalar_t result = 0;
   if (x < 0) {
     if (x_is_integer) {
-      return static_cast<scalar_t>(INFINITY);
+      // As per C++ standard for gamma related functions and SciPy,
+      // If the argument is a negative integer, NaN is returned
+      return static_cast<scalar_t>(NAN);
     }
     // Rounding errors in tan's input can really affect the output
     // for extreme values, so we always perform this computation in double.
diff --git a/aten/src/ATen/native/cuda/UnaryGammaKernels.cu b/aten/src/ATen/native/cuda/UnaryGammaKernels.cu
index d752d606474d..97dbeefccc77 100644
--- a/aten/src/ATen/native/cuda/UnaryGammaKernels.cu
+++ b/aten/src/ATen/native/cuda/UnaryGammaKernels.cu
@@ -11,7 +11,7 @@
 namespace at { namespace native {
 
 void digamma_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "digamma_cuda", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "digamma_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return calc_digamma(a);
     });
diff --git a/test/test_torch.py b/test/test_torch.py
index 04fadcb65c66..6532c2e5e17d 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -6925,7 +6925,6 @@ def inner(self, device, dtype):
     ('trunc', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
     ('ceil', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
     ('lgamma', '', _small_3d, lambda t, d: [], 1e-2, 1e-1, 1e-5, _float_types_no_half, [torch.bfloat16]),
-    ('digamma', 'op', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e0, _float_types_no_half),
 ]
 
 # Creates and decorates a generic test and adds it to the class.
diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py
index 37ef90514803..776482306f4d 100644
--- a/test/test_unary_ufuncs.py
+++ b/test/test_unary_ufuncs.py
@@ -497,6 +497,35 @@ def test_sqrt_complex_edge_values(self, device, dtype):
         x = torch.tensor(-1.0000e+20 - 4988429.2000j, dtype=dtype, device=device)
         self.compare_with_numpy(torch.sqrt, np.sqrt, x)
 
+    @unittest.skipIf(not TEST_SCIPY, "Requires SciPy")
+    @dtypes(torch.float, torch.double)
+    def test_digamma_special(self, device, dtype):
+        # Based on SciPy test for the following special values.
+        # Reference:
+        # https://github.com/scipy/scipy/blob/3a8a3a1d4657254a6611e77e9c28feafa26e6645/scipy/special/tests/test_digamma.py#L22
+        euler = 0.57721566490153286
+        dataset = [(0., -0.),
+                   (1, -euler),
+                   (0.5, -2 * math.log(2) - euler),
+                   (1 / 3, -math.pi / (2 * math.sqrt(3)) - 3 * math.log(3) / 2 - euler),
+                   (1 / 4, -math.pi / 2 - 3 * math.log(2) - euler),
+                   (1 / 6, -math.pi * math.sqrt(3) / 2 - 2 * math.log(2) - 3 * math.log(3) / 2 - euler),
+                   (1 / 8, -math.pi / 2 - 4 * math.log(2) -
+                    (math.pi + math.log(2 + math.sqrt(2)) - math.log(2 - math.sqrt(2))) / math.sqrt(2) - euler)]
+        x = torch.tensor(dataset, device=device, dtype=dtype)
+        self.compare_with_numpy(torch.digamma, scipy.special.digamma, x)
+
+    @unittest.skipIf(not TEST_SCIPY, "Requires SciPy")
+    @dtypes(torch.float, torch.double)
+    def test_digamma(self, device, dtype):
+        # Tests pole behavior
+        # TODO: Add value `-1931.99999994`, to the tensor below when
+        # https://github.com/pytorch/pytorch/issues/49015 is fixed
+        tensor = torch.tensor([-0.999999994, -1.999999994, -2.0000000111,
+                               -100.99999994, 0.000000111,
+                               -0.000000111, 0, -0, -1, -2, -931], dtype=dtype, device=device)
+        self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor)
+
     # TODO opinfo mvlgamma
     @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
     def test_mvlgamma(self, device):
@@ -1120,30 +1149,6 @@ def test_polygamma(self, device, dtype):
             torch.autograd.gradcheck(lambda x: x.polygamma(n),
                                      cpu_tensor)
 
-    # Note: fails when using float tensors
-    # TODO: update this test to just compare against NumPy
-    @onlyCUDA
-    @dtypes(torch.double)
-    def test_digamma(self, device, dtype):
-        cpu_tensor = torch.randn(10, 10, 10, dtype=dtype)
-        device_tensor = cpu_tensor.to(device)
-        zeros = torch.zeros(10, 10, 10, dtype=dtype)
-        cpu_out = cpu_tensor.digamma()
-        device_out = device_tensor.digamma()
-        norm_errors = (device_out - cpu_out.to(device)) / device_out
-        self.assertEqual(norm_errors, zeros)
-
-        # Tests pole behavior
-        cpu_tensor = torch.tensor([-0.999999994, -1.999999994, -2.0000000111,
-                                   -100.99999994, -1931.99999994, 0.000000111,
-                                   -0.000000111, 0, -1, -2, -931], dtype=dtype)
-        expected_errors = torch.tensor([0, 0, 0, 0, 0, 0, 0, nan, nan, nan, nan], dtype=dtype)
-        device_tensor = cpu_tensor.to(device)
-        cpu_out = cpu_tensor.digamma()
-        device_out = device_tensor.digamma()
-        norm_errors = (device_out - cpu_out.to(device)) / device_out
-        self.assertEqual(norm_errors, expected_errors)
-
     # TODO: update to compare against NumPy by rationalizing with OpInfo
     @onlyCUDA
     @dtypes(torch.float, torch.double)
@@ -1725,9 +1730,6 @@ def _medium_2d(dtype, device):
     _TorchMathTestMeta('polygamma', args=[2], substr='_2', reffn='polygamma',
                        refargs=lambda x: (2, x.numpy()),
                        input_fn=_generate_gamma_input, inputargs=[False], ref_backend='scipy',
                        rtol=0.0008, atol=1e-5),
-    _TorchMathTestMeta('digamma',
-                       input_fn=_generate_gamma_input, inputargs=[True], ref_backend='scipy',
-                       replace_inf_with_nan=True),
     _TorchMathTestMeta('abs', input_fn=_medium_2d, dtypes=_types_no_half, rtol=0., atol=0.),
     _TorchMathTestMeta('logit', ref_backend='scipy')]
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index d46a6b1bcf84..e05784cbcc22 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -2532,8 +2532,7 @@ def merge_dicts(*dicts):
           [ 1.0500,  0.7336, -0.3836, -1.1015]]])
 """.format(**common_args))
 
-add_docstr(torch.digamma,
-           r"""
+add_docstr(torch.digamma, r"""
 digamma(input, *, out=None) -> Tensor
 
 Computes the logarithmic derivative of the gamma function on `input`.
@@ -2547,6 +2546,11 @@ def merge_dicts(*dicts):
 Keyword args:
     {out}
 
+.. note:: This function is similar to SciPy's `scipy.special.digamma`.
+
+.. note:: From PyTorch 1.8 onwards, the digamma function returns `-Inf` for `0`.
+          Previously it returned `NaN` for `0`.
+
 Example::
 
     >>> a = torch.tensor([1, 0.5])
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 808506dc6809..87d0baa895e8 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1087,6 +1087,19 @@ def reference_sigmoid(x):
                    promotes_integers_to_float=True,
                    assert_autodiffed=True,
                    test_complex_grad=False),
     # Reference: https://github.com/pytorch/pytorch/issues/48552
+    UnaryUfuncInfo('digamma',
+                   ref=scipy.special.digamma,
+                   decorators=(precisionOverride({torch.float16: 5e-1}),),
+                   dtypes=all_types_and(torch.bool),
+                   dtypesIfCPU=all_types_and(torch.bool),
+                   dtypesIfCUDA=all_types_and(torch.bool, torch.half),
+                   skips=(
+                       # In some cases, output is NaN (for input close to
+                       # negative integers) especially due to reduced precision
+                       # in float16 and NaN's can't be tested for equality.
+                       SkipInfo('TestCommon', 'test_variant_consistency_jit',
+                                device_type='cuda', dtypes=[torch.float16]),),
+                   promotes_integers_to_float=True),
     UnaryUfuncInfo('erf',
                    ref=scipy.special.erf,
                    decorators=(precisionOverride({torch.float16: 1e-2,
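
[Illustrative sketch, not part of the patch: routing digamma_out through
unary_op_impl_float_out and dispatching the kernels on iter.common_dtype() means integer
and bool inputs are promoted to the default floating point dtype, which is what
promotes_integers_to_float=True in the new OpInfo entry asserts. Assumes a build
containing this change.]

    import torch

    x = torch.arange(1, 5)            # int64 input
    y = torch.digamma(x)
    print(y.dtype)                    # torch.float32 (default dtype), not torch.int64

    # bool inputs are promoted the same way, matching dtypes=all_types_and(torch.bool).
    print(torch.digamma(torch.tensor([True])).dtype)   # torch.float32
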