From 58bcc199cd8ed7ef2d4c83e0ebee28ac492317d6 Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Fri, 20 Nov 2020 02:53:46 -0600
Subject: [PATCH 01/26] digamma: int -> float promotion

---
 aten/src/ATen/native/UnaryOps.cpp              | 14 ++++++++------
 aten/src/ATen/native/cuda/UnaryGammaKernels.cu |  6 +++---
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index 8e01aff472ff..8a34d6649d0a 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -304,8 +304,8 @@ Tensor& round_out(Tensor& result, const Tensor& self) { return unary_op_impl_out
 Tensor round(const Tensor& self) { return unary_op_impl(self, at::round_out); }
 Tensor& round_(Tensor& self) { return unary_op_impl_(self, at::round_out); }
 
-Tensor& digamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, digamma_stub); }
-Tensor digamma(const Tensor& self) { return unary_op_impl(self, digamma_out); }
+Tensor& digamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, digamma_stub); }
+Tensor digamma(const Tensor& self) { return unary_op_impl_float(self, digamma_stub); }
 Tensor& digamma_(Tensor& self) { return unary_op_impl_(self, digamma_out); }
 
 Tensor& reciprocal_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, reciprocal_stub); }
@@ -602,16 +602,18 @@ Tensor& clip_(Tensor& self, optional<Scalar> min, optional<Scalar> max) {
 }
 
 Tensor polygamma(int64_t n, const Tensor& self) {
-  Tensor result = at::empty({0}, self.options());
-  at::polygamma_out(result, n, self);
-  return result;
+  TORCH_CHECK(n >= 0, "polygamma(n, x) does not support negative n.");
+  Tensor result;
+  auto iter = TensorIterator::unary_float_op(result, self);
+  polygamma_stub(iter.device_type(), iter, n);
+  return iter.output();
 }
 Tensor& polygamma_(Tensor& self, int64_t n) {
   return at::polygamma_out(self, n, self);
 }
 Tensor& polygamma_out(Tensor& result, int64_t n, const Tensor& self) {
   TORCH_CHECK(n >= 0, "polygamma(n, x) does not support negative n.");
-  auto iter = TensorIterator::unary_op(result, self);
+  auto iter = TensorIterator::unary_float_op(result, self);
   polygamma_stub(iter.device_type(), iter, n);
   return result;
 }
diff --git a/aten/src/ATen/native/cuda/UnaryGammaKernels.cu b/aten/src/ATen/native/cuda/UnaryGammaKernels.cu
index d752d606474d..d451aae0a444 100644
--- a/aten/src/ATen/native/cuda/UnaryGammaKernels.cu
+++ b/aten/src/ATen/native/cuda/UnaryGammaKernels.cu
@@ -11,7 +11,7 @@ namespace at { namespace native {
 
 void digamma_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "digamma_cuda", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "digamma_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return calc_digamma(a);
     });
@@ -19,7 +19,7 @@ void digamma_kernel_cuda(TensorIterator& iter) {
 }
 
 void trigamma_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "trigamma_cuda", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "trigamma_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return calc_trigamma(a);
     });
@@ -32,7 +32,7 @@ void polygamma_kernel_cuda(TensorIterator& iter, int64_t n) {
   } else if (n == 1) {
     trigamma_kernel_cuda(iter);
   } else {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "polygamma_cuda", [&]() {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "polygamma_cuda", [&]() {
gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t a) -> scalar_t { return calc_polygamma(int(n), a); }); From 1bcced1774b16d9f12021b93715289f31f939a2c Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 20 Nov 2020 02:54:01 -0600 Subject: [PATCH 02/26] update test --- test/test_unary_ufuncs.py | 2 +- .../_internal/common_methods_invocations.py | 19 ++++++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 95287da20755..ebe66e1d181b 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -236,7 +236,7 @@ def assertEqualHelper(self, actual, expected, msg, *, dtype, exact_dtype=True, * # Allows array dtype to be float32 when comparing with bfloat16 tensors # since NumPy doesn't support the bfloat16 dtype if expected.dtype == np.float32: - assert actual.dtype in (torch.bfloat16, torch.float32) + assert actual.dtype in (torch.float16, torch.bfloat16, torch.float32) else: assert expected.dtype == torch_to_numpy_dtype_dict[actual.dtype] diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 88e5d9ce0fbc..449922a4701d 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -21,8 +21,10 @@ random_symmetric_matrix, random_symmetric_psd_matrix, random_symmetric_pd_matrix, make_nonzero_det, random_fullrank_matrix_distinct_singular_value, set_rng_seed, - TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor) + TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor, TEST_SCIPY) +if TEST_SCIPY: + import scipy.special class SkipInfo(object): """Describes which test, or type of tests, should be skipped when testing @@ -469,6 +471,21 @@ def sample_inputs(self, device, dtype, requires_grad=False): handles_complex_extremals=False), ] +if TEST_SCIPY: + op_db_scipy_reference = [ + UnaryUfuncInfo('digamma', + ref=scipy.special.digamma, + decorators=(precisionOverride({torch.float16: 5e-1}),), + # 'expit' not supported for the input types + skips=(SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),), + dtypes=all_types_and(torch.bool), + dtypesIfCPU=all_types_and(torch.bool), + dtypesIfCUDA=all_types_and(torch.bool, torch.half), + promotes_integers_to_float=True) + ] + op_db = op_db + op_db_scipy_reference + # Common operator groupings unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo)] From 085cbfe8d820282eb81cfd66d2bd2913ef28952c Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 20 Nov 2020 02:58:25 -0600 Subject: [PATCH 03/26] match negative integer input and zero input against scipy --- aten/src/ATen/native/Math.h | 8 ++++---- aten/src/ATen/native/cuda/Math.cuh | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index 04c6925933a3..1589f5f51b0d 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -279,13 +279,13 @@ static inline float trigamma(float x) { static inline double calc_digamma(double x) { static double PSI_10 = 2.25175258906672110764; if (x == 0) { - return INFINITY; + return -INFINITY; } int x_is_integer = x == floor(x); if (x < 0) { if (x_is_integer) { - return INFINITY; + return NAN; } return calc_digamma(1 - x) - M_PI / tan(M_PI * x); } @@ -326,13 +326,13 @@ static inline double calc_digamma(double x) { static inline float calc_digamma(float x) { static float PSI_10 = 
2.25175258906672110764f;
   if (x == 0) {
-    return INFINITY;
+    return -INFINITY;
   }
 
   int x_is_integer = x == floorf(x);
   if (x < 0) {
     if (x_is_integer) {
-      return INFINITY;
+      return NAN;
     }
     // Avoid rounding errors for `tan`'s input.
     // Those make a big difference at extreme values.
diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh
index 1daba76c9446..6ae3ea989c1a 100644
--- a/aten/src/ATen/native/cuda/Math.cuh
+++ b/aten/src/ATen/native/cuda/Math.cuh
@@ -108,14 +108,14 @@ static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) {
 
   accscalar_t x = static_cast<accscalar_t>(in);
   if (x == 0) {
-    return static_cast<scalar_t>(INFINITY);
+    return static_cast<scalar_t>(-INFINITY);
   }
 
   bool x_is_integer = x == ::floor(x);
   accscalar_t result = 0;
   if (x < 0) {
     if (x_is_integer) {
-      return static_cast<scalar_t>(INFINITY);
+      return static_cast<scalar_t>(NAN);
     }
     // Rounding errors in tan's input can really affect the output
     // for extreme values, so we always perform this computation in double.

From 8a2e4bbb4b1e5326c6ec9f7278e6d65556d3b1e3 Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Fri, 20 Nov 2020 08:50:52 -0600
Subject: [PATCH 04/26] undo changes on polygamma

---
 aten/src/ATen/native/UnaryOps.cpp              | 10 ++++------
 aten/src/ATen/native/cuda/UnaryGammaKernels.cu |  4 ++--
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index 8a34d6649d0a..2b88c8fde7c6 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -602,18 +602,16 @@ Tensor& clip_(Tensor& self, optional<Scalar> min, optional<Scalar> max) {
 }
 
 Tensor polygamma(int64_t n, const Tensor& self) {
-  TORCH_CHECK(n >= 0, "polygamma(n, x) does not support negative n.");
-  Tensor result;
-  auto iter = TensorIterator::unary_float_op(result, self);
-  polygamma_stub(iter.device_type(), iter, n);
-  return iter.output();
+  Tensor result = at::empty({0}, self.options());
+  at::polygamma_out(result, n, self);
+  return result;
 }
 Tensor& polygamma_(Tensor& self, int64_t n) {
   return at::polygamma_out(self, n, self);
 }
 Tensor& polygamma_out(Tensor& result, int64_t n, const Tensor& self) {
   TORCH_CHECK(n >= 0, "polygamma(n, x) does not support negative n.");
-  auto iter = TensorIterator::unary_float_op(result, self);
+  auto iter = TensorIterator::unary_op(result, self);
   polygamma_stub(iter.device_type(), iter, n);
   return result;
 }
diff --git a/aten/src/ATen/native/cuda/UnaryGammaKernels.cu b/aten/src/ATen/native/cuda/UnaryGammaKernels.cu
index d451aae0a444..97dbeefccc77 100644
--- a/aten/src/ATen/native/cuda/UnaryGammaKernels.cu
+++ b/aten/src/ATen/native/cuda/UnaryGammaKernels.cu
@@ -19,7 +19,7 @@ void digamma_kernel_cuda(TensorIterator& iter) {
 }
 
 void trigamma_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "trigamma_cuda", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "trigamma_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return calc_trigamma(a);
     });
@@ -32,7 +32,7 @@ void polygamma_kernel_cuda(TensorIterator& iter, int64_t n) {
   } else if (n == 1) {
     trigamma_kernel_cuda(iter);
   } else {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "polygamma_cuda", [&]() {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "polygamma_cuda", [&]() {
       gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t a) -> scalar_t {
         return calc_polygamma(int(n), a);
       });

From 65805b838ec2294e1beee9fb0a69a0076a01cdf6 Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Fri, 20 Nov 2020 09:01:06 -0600
Subject: [PATCH 05/26] match scipy behaviour at 0

---
 aten/src/ATen/native/Math.h        | 4 ++--
 aten/src/ATen/native/cuda/Math.cuh | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h
index 1589f5f51b0d..c8cbe8424a5a 100644
--- a/aten/src/ATen/native/Math.h
+++ b/aten/src/ATen/native/Math.h
@@ -279,7 +279,7 @@ static inline float trigamma(float x) {
 static inline double calc_digamma(double x) {
   static double PSI_10 = 2.25175258906672110764;
   if (x == 0) {
-    return -INFINITY;
+    return std::copysign(INFINITY, -x);
   }
 
   int x_is_integer = x == floor(x);
@@ -326,7 +326,7 @@ static inline double calc_digamma(double x) {
 static inline float calc_digamma(float x) {
   static float PSI_10 = 2.25175258906672110764f;
   if (x == 0) {
-    return -INFINITY;
+    return std::copysign(INFINITY, -x);
  }
 
   int x_is_integer = x == floorf(x);
diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh
index 6ae3ea989c1a..15389a1d6308 100644
--- a/aten/src/ATen/native/cuda/Math.cuh
+++ b/aten/src/ATen/native/cuda/Math.cuh
@@ -108,7 +108,7 @@ static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) {
 
   accscalar_t x = static_cast<accscalar_t>(in);
   if (x == 0) {
-    return static_cast<scalar_t>(-INFINITY);
+    return std::copysign(static_cast<scalar_t>(INFINITY), -x);
  }
 
   bool x_is_integer = x == ::floor(x);

From 034a989a59e78f595ac619c502e41000c1f8deb4 Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Fri, 20 Nov 2020 22:50:26 -0600
Subject: [PATCH 06/26] add test with special values based on SciPy test

---
 test/test_unary_ufuncs.py | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py
index ebe66e1d181b..58230cf69efd 100644
--- a/test/test_unary_ufuncs.py
+++ b/test/test_unary_ufuncs.py
@@ -9,7 +9,7 @@ from torch.testing._internal.common_utils import \
     (TestCase, run_tests, torch_to_numpy_dtype_dict, suppress_warnings,
-     TEST_NUMPY, IS_MACOS, make_tensor)
+     TEST_NUMPY, IS_MACOS, make_tensor, TEST_SCIPY)
 from torch.testing._internal.common_methods_invocations import \
     (unary_ufuncs)
 from torch.testing._internal.common_device_type import \
@@ -20,6 +20,9 @@
 if TEST_NUMPY:
     import numpy as np
 
+if TEST_SCIPY:
+    import scipy.special
+
 # Tests for unary "universal functions (ufuncs)" that accept a single
 # tensor and have common properties like:
 # - they are elementwise functions
@@ -490,6 +493,24 @@ def test_sqrt_complex_edge_values(self, device, dtype):
         x = torch.tensor(-1.0000e+20 - 4988429.2000j, dtype=dtype, device=device)
         self.compare_with_numpy(torch.sqrt, np.sqrt, x)
 
+    @unittest.skipIf(not TEST_SCIPY, "Requires SciPy")
+    @dtypes(torch.float, torch.double)
+    def test_digamma_special(self, device, dtype):
+        # Based on SciPy test for the following special values.
+ # Reference: + # https://github.com/scipy/scipy/blob/3a8a3a1d4657254a6611e77e9c28feafa26e6645/scipy/special/tests/test_digamma.py#L22 + euler = 0.57721566490153286 + dataset = [(0., -0.), + (1, -euler), + (0.5, -2 * math.log(2) - euler), + (1 / 3, -math.pi / (2 * math.sqrt(3)) - 3 * math.log(3) / 2 - euler), + (1 / 4, -math.pi / 2 - 3 * math.log(2) - euler), + (1 / 6, -math.pi * math.sqrt(3) / 2 - 2 * math.log(2) - 3 * math.log(3) / 2 - euler), + (1 / 8, -math.pi / 2 - 4 * math.log(2) - + (math.pi + math.log(2 + math.sqrt(2)) - math.log(2 - math.sqrt(2))) / math.sqrt(2) - euler)] + x = torch.tensor(dataset, device=device, dtype=dtype) + self.compare_with_numpy(torch.digamma, scipy.special.digamma, x) + instantiate_device_type_tests(TestUnaryUfuncs, globals()) if __name__ == '__main__': From a90952cd7ceced7d3716c79c020d968b9df5f207 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Sat, 28 Nov 2020 00:27:29 -0600 Subject: [PATCH 07/26] add comments for special value return --- aten/src/ATen/native/Math.h | 10 ++++++++++ aten/src/ATen/native/cuda/Math.cuh | 5 +++++ 2 files changed, 15 insertions(+) diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index c8cbe8424a5a..4baa51e296b8 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -277,14 +277,19 @@ static inline float trigamma(float x) { * See note [3-Clause BSD License for the Cephes Math Library]. */ static inline double calc_digamma(double x) { + // [CPP Standard Reference] https://en.cppreference.com/w/cpp/numeric/math/tgamma static double PSI_10 = 2.25175258906672110764; if (x == 0) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned return std::copysign(INFINITY, -x); } int x_is_integer = x == floor(x); if (x < 0) { if (x_is_integer) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned return NAN; } return calc_digamma(1 - x) - M_PI / tan(M_PI * x); @@ -324,14 +329,19 @@ static inline double calc_digamma(double x) { * See note [3-Clause BSD License for the Cephes Math Library]. */ static inline float calc_digamma(float x) { + // See [CPP Standard Reference] static float PSI_10 = 2.25175258906672110764f; if (x == 0) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned return std::copysign(INFINITY, -x); } int x_is_integer = x == floorf(x); if (x < 0) { if (x_is_integer) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned return NAN; } // Avoid rounding errors for `tan`'s input. 
diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh index 15389a1d6308..89e99af945ef 100644 --- a/aten/src/ATen/native/cuda/Math.cuh +++ b/aten/src/ATen/native/cuda/Math.cuh @@ -93,6 +93,7 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) { */ template static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) { + // [CPP Standard Reference] https://en.cppreference.com/w/cpp/numeric/math/tgamma using accscalar_t = at::acc_type; static const double PI_f64 = 3.14159265358979323846; const accscalar_t PSI_10 = 2.25175258906672110764; @@ -108,6 +109,8 @@ static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) { accscalar_t x = static_cast(in); if (x == 0) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned return std::copysign(static_cast(INFINITY), -x); } @@ -115,6 +118,8 @@ static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) { accscalar_t result = 0; if (x < 0) { if (x_is_integer) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned return static_cast(NAN); } // Rounding errors in tan's input can really affect the output From 7aaa3d57c74da5857212110e223c3607ea3cbfc7 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Sat, 28 Nov 2020 00:27:48 -0600 Subject: [PATCH 08/26] update docs --- torch/_torch_docs.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index e7978edc06d6..f6f2a31d08b5 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2515,6 +2515,12 @@ def merge_dicts(*dicts): Keyword args: {out} +.. note:: This function is similar to SciPy's `scipy.special.digamma`. + +.. note:: From version 1.8 onwards, the digamma function returns `NaN` for non-positive integers, + while for `0`, it returns `-Inf` to be consistent with SciPy and C++ Standard. + Prior to version 1.8, the function would return `NaN` for non-positive integers and `0`. 
+ Example:: >>> a = torch.tensor([1, 0.5]) From 63d322d4303ce1f425d3f550ddf18e472c040238 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 8 Dec 2020 06:38:29 -0600 Subject: [PATCH 09/26] remove MathTestMeta entry for digamma --- test/test_unary_ufuncs.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index aea582a08a10..986fa2d8557d 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -1809,9 +1809,6 @@ def _medium_2d(dtype, device): _TorchMathTestMeta('polygamma', args=[2], substr='_2', reffn='polygamma', refargs=lambda x: (2, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False], ref_backend='scipy', rtol=0.0008, atol=1e-5), - _TorchMathTestMeta('digamma', - input_fn=_generate_gamma_input, inputargs=[True], ref_backend='scipy', - replace_inf_with_nan=True), _TorchMathTestMeta('abs', input_fn=_medium_2d, dtypes=_types_no_half, rtol=0., atol=0.), _TorchMathTestMeta('logit', ref_backend='scipy')] From 401d44375ab429bb78a8e77f879dd3011a0b2e98 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 8 Dec 2020 06:41:15 -0600 Subject: [PATCH 10/26] test digamma to use scipy reference operator --- test/test_unary_ufuncs.py | 37 +++++++++++++------------------------ 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 986fa2d8557d..72b0902d1a09 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -514,6 +514,19 @@ def test_digamma_special(self, device, dtype): x = torch.tensor(dataset, device=device, dtype=dtype) self.compare_with_numpy(torch.digamma, scipy.special.digamma, x) + @unittest.skipIf(not TEST_SCIPY, "Requires SciPy") + @dtypes(torch.float, torch.double) + def test_digamma(self, device, dtype): + nan = float('nan') + tensor = torch.randn(10, 10, 10, dtype=dtype, device=device) + self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor) + + # Tests pole behavior + tensor = torch.tensor([-0.999999994, -1.999999994, -2.0000000111, + -100.99999994, -1931.99999994, 0.000000111, + -0.000000111, 0, -0, -1, -2, -931], dtype=dtype, device=device) + self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor) + # TODO opinfo mvlgamma @unittest.skipIf(not TEST_SCIPY, "Scipy not found") def test_mvlgamma(self, device): @@ -1203,30 +1216,6 @@ def test_polygamma(self, device, dtype): torch.autograd.gradcheck(lambda x: x.polygamma(n), cpu_tensor) - # Note: fails when using float tensors - # TODO: update this test to just compare against NumPy - @onlyCUDA - @dtypes(torch.double) - def test_digamma(self, device, dtype): - cpu_tensor = torch.randn(10, 10, 10, dtype=dtype) - device_tensor = cpu_tensor.to(device) - zeros = torch.zeros(10, 10, 10, dtype=dtype) - cpu_out = cpu_tensor.digamma() - device_out = device_tensor.digamma() - norm_errors = (device_out - cpu_out.to(device)) / device_out - self.assertEqual(norm_errors, zeros) - - # Tests pole behavior - cpu_tensor = torch.tensor([-0.999999994, -1.999999994, -2.0000000111, - -100.99999994, -1931.99999994, 0.000000111, - -0.000000111, 0, -1, -2, -931], dtype=dtype) - expected_errors = torch.tensor([0, 0, 0, 0, 0, 0, 0, nan, nan, nan, nan], dtype=dtype) - device_tensor = cpu_tensor.to(device) - cpu_out = cpu_tensor.digamma() - device_out = device_tensor.digamma() - norm_errors = (device_out - cpu_out.to(device)) / device_out - self.assertEqual(norm_errors, expected_errors) - # TODO: update to compare against NumPy by rationalizing with OpInfo @onlyCUDA 
@dtypes(torch.float, torch.double) From 630ebad57d240cf88aec66674e918aa23b9cc587 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 8 Dec 2020 06:43:29 -0600 Subject: [PATCH 11/26] remove stray space --- test/test_unary_ufuncs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 72b0902d1a09..0e7f84a47459 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -24,7 +24,6 @@ if TEST_SCIPY: import scipy - # Tests for unary "universal functions (ufuncs)" that accept a single # tensor and have common properties like: # - they are elementwise functions From d0115619bac335f1785283b37feff1393c813dc7 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 8 Dec 2020 06:49:16 -0600 Subject: [PATCH 12/26] handle failure with test_variant_consistency_jit --- torch/testing/_internal/common_methods_invocations.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 8e1da1effe82..3b77d8091c98 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -858,7 +858,9 @@ def reference_sigmoid(x): ref=scipy.special.digamma, decorators=(precisionOverride({torch.float16: 5e-1}),), skips=(SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', - device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),), + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + SkipInfo('TestCommon', 'test_variant_consistency_jit', + device_type='cuda', dtypes=[torch.float16]),), dtypes=all_types_and(torch.bool), dtypesIfCPU=all_types_and(torch.bool), dtypesIfCUDA=all_types_and(torch.bool, torch.half), From 18aacb1a81e41cfa62963e9ca4209e112e0217cc Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 8 Dec 2020 06:55:07 -0600 Subject: [PATCH 13/26] remove problematic value --- test/test_unary_ufuncs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 0e7f84a47459..2c9a23b7d4c3 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -521,8 +521,10 @@ def test_digamma(self, device, dtype): self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor) # Tests pole behavior + # TODO: Add value `-1931.99999994`, to the tensor below when + # https://github.com/pytorch/pytorch/issues/49015 is fixed tensor = torch.tensor([-0.999999994, -1.999999994, -2.0000000111, - -100.99999994, -1931.99999994, 0.000000111, + -100.99999994, 0.000000111, -0.000000111, 0, -0, -1, -2, -931], dtype=dtype, device=device) self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor) From eb2e176ea0f49a7041e4342f70d6a643331fe356 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 17 Dec 2020 04:38:37 -0600 Subject: [PATCH 14/26] update Note name and use trunc --- aten/src/ATen/native/Math.h | 8 ++++---- aten/src/ATen/native/cuda/Math.cuh | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index 4baa51e296b8..33f997a6ad5e 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -277,7 +277,7 @@ static inline float trigamma(float x) { * See note [3-Clause BSD License for the Cephes Math Library]. 
*/ static inline double calc_digamma(double x) { - // [CPP Standard Reference] https://en.cppreference.com/w/cpp/numeric/math/tgamma + // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma static double PSI_10 = 2.25175258906672110764; if (x == 0) { // As per C++ standard for gamma related functions and SciPy, @@ -285,7 +285,7 @@ static inline double calc_digamma(double x) { return std::copysign(INFINITY, -x); } - int x_is_integer = x == floor(x); + int x_is_integer = x == trunc(x); if (x < 0) { if (x_is_integer) { // As per C++ standard for gamma related functions and SciPy, @@ -329,7 +329,7 @@ static inline double calc_digamma(double x) { * See note [3-Clause BSD License for the Cephes Math Library]. */ static inline float calc_digamma(float x) { - // See [CPP Standard Reference] + // See [C++ Standard Reference: Gamma Function] static float PSI_10 = 2.25175258906672110764f; if (x == 0) { // As per C++ standard for gamma related functions and SciPy, @@ -337,7 +337,7 @@ static inline float calc_digamma(float x) { return std::copysign(INFINITY, -x); } - int x_is_integer = x == floorf(x); + int x_is_integer = x == truncf(x); if (x < 0) { if (x_is_integer) { // As per C++ standard for gamma related functions and SciPy, diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh index 89e99af945ef..17c30cd00ea7 100644 --- a/aten/src/ATen/native/cuda/Math.cuh +++ b/aten/src/ATen/native/cuda/Math.cuh @@ -93,7 +93,7 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) { */ template static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) { - // [CPP Standard Reference] https://en.cppreference.com/w/cpp/numeric/math/tgamma + // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma using accscalar_t = at::acc_type; static const double PI_f64 = 3.14159265358979323846; const accscalar_t PSI_10 = 2.25175258906672110764; @@ -114,7 +114,7 @@ static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) { return std::copysign(static_cast(INFINITY), -x); } - bool x_is_integer = x == ::floor(x); + bool x_is_integer = x == ::trunc(x); accscalar_t result = 0; if (x < 0) { if (x_is_integer) { From 7684f3bf96636410a47ee3aafdc4f3b83babac1c Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 17 Dec 2020 04:38:52 -0600 Subject: [PATCH 15/26] use iter.common_dtype --- aten/src/ATen/native/cpu/UnaryOpsKernel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index f7c4f9c34613..1809ffff714e 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -345,7 +345,7 @@ static void atanh_kernel(TensorIterator& iter) { } static void digamma_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "digamma", [&]() { + AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "digamma", [&]() { cpu_kernel( iter, [=](scalar_t a) -> scalar_t { return calc_digamma(a); }); From 6ecf8dfa8e056a656de6a6a62974e4c9cfce45f8 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 17 Dec 2020 04:39:23 -0600 Subject: [PATCH 16/26] remove redundant test --- test/test_unary_ufuncs.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 14ab54bd33e6..8fd9ecf8b768 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -516,10 +516,6 @@ def 
test_digamma_special(self, device, dtype): @unittest.skipIf(not TEST_SCIPY, "Requires SciPy") @dtypes(torch.float, torch.double) def test_digamma(self, device, dtype): - nan = float('nan') - tensor = torch.randn(10, 10, 10, dtype=dtype, device=device) - self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor) - # Tests pole behavior # TODO: Add value `-1931.99999994`, to the tensor below when # https://github.com/pytorch/pytorch/issues/49015 is fixed From 644bc740258ce4c91b5044cb87c059937eabdd39 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 17 Dec 2020 04:39:54 -0600 Subject: [PATCH 17/26] update the skips --- torch/testing/_internal/common_methods_invocations.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 4af8678ef70d..3bc26b40764f 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -866,9 +866,7 @@ def reference_sigmoid(x): UnaryUfuncInfo('digamma', ref=scipy.special.digamma, decorators=(precisionOverride({torch.float16: 5e-1}),), - skips=(SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', - device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), - SkipInfo('TestCommon', 'test_variant_consistency_jit', + skips=(SkipInfo('TestCommon', 'test_variant_consistency_jit', device_type='cuda', dtypes=[torch.float16]),), dtypes=all_types_and(torch.bool), dtypesIfCPU=all_types_and(torch.bool), From ddbcb6c3d742ede53622597cab344d8a741f25ba Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 17 Dec 2020 04:43:28 -0600 Subject: [PATCH 18/26] update docs --- torch/_torch_docs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index f97903b11477..b57f3c3d43bf 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2511,9 +2511,9 @@ def merge_dicts(*dicts): .. note:: This function is similar to SciPy's `scipy.special.digamma`. -.. note:: From version 1.8 onwards, the digamma function returns `NaN` for non-positive integers, - while for `0`, it returns `-Inf` to be consistent with SciPy and C++ Standard. - Prior to version 1.8, the function would return `NaN` for non-positive integers and `0`. +.. note:: From PyTorch 1.8 onwards, the digamma function returns `NaN` for non-positive integers, + while for `0`, it returns `-Inf` to be consistent with SciPy and the C++ Standard. + Prior to PyTorch 1.8, the function would return `NaN` for non-positive integers and `0`. Example:: From d84bd1f6b025a5a8bc13577ee4f50c5b56cf520f Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 17 Dec 2020 05:19:13 -0600 Subject: [PATCH 19/26] update docs related to term non-positive --- torch/_torch_docs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index b57f3c3d43bf..d9c236f3289a 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2511,9 +2511,9 @@ def merge_dicts(*dicts): .. note:: This function is similar to SciPy's `scipy.special.digamma`. -.. note:: From PyTorch 1.8 onwards, the digamma function returns `NaN` for non-positive integers, +.. note:: From PyTorch 1.8 onwards, the digamma function returns `NaN` for negative integers, while for `0`, it returns `-Inf` to be consistent with SciPy and the C++ Standard. - Prior to PyTorch 1.8, the function would return `NaN` for non-positive integers and `0`. 
+ Prior to PyTorch 1.8, the function would return `NaN` for non-positive integers. Example:: From 4c44a6de08e79e60adc270f5da92512c2bf8c499 Mon Sep 17 00:00:00 2001 From: Mike Ruberry Date: Mon, 21 Dec 2020 04:03:07 -0800 Subject: [PATCH 20/26] wording revision --- torch/_torch_docs.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index d9c236f3289a..84fb3e504c67 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2494,8 +2494,7 @@ def merge_dicts(*dicts): [ 1.0500, 0.7336, -0.3836, -1.1015]]]) """.format(**common_args)) -add_docstr(torch.digamma, - r""" +add_docstr(torch.digamma, r""" digamma(input, *, out=None) -> Tensor Computes the logarithmic derivative of the gamma function on `input`. @@ -2511,9 +2510,8 @@ def merge_dicts(*dicts): .. note:: This function is similar to SciPy's `scipy.special.digamma`. -.. note:: From PyTorch 1.8 onwards, the digamma function returns `NaN` for negative integers, - while for `0`, it returns `-Inf` to be consistent with SciPy and the C++ Standard. - Prior to PyTorch 1.8, the function would return `NaN` for non-positive integers. +.. note:: From PyTorch 1.8 onwards, the digamma function returns `-Inf` for `0`. + Previously it returned `NaN` for `0`. Example:: From 332c48e09dffe4c6e2eedb1c962fb380ecbd1d71 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Mon, 21 Dec 2020 06:19:37 -0600 Subject: [PATCH 21/26] remove digamma from tensor_op_tests --- test/test_torch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_torch.py b/test/test_torch.py index b4d9ad6f23c0..fd97cab65686 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6927,7 +6927,6 @@ def inner(self, device, dtype): ('trunc', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]), ('ceil', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]), ('lgamma', '', _small_3d, lambda t, d: [], 1e-2, 1e-1, 1e-5, _float_types_no_half, [torch.bfloat16]), - ('digamma', 'op', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e0, _float_types_no_half), ] # Creates and decorates a generic test and adds it to the class. 
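A rough sketch of the end state these patches converge on (not part of the patch series itself; it assumes a build that already includes them, SciPy installed for the reference values, and the default float32 dtype for the promotion example):

    import torch
    import scipy.special

    # Integer inputs are promoted to a floating-point dtype (PATCH 01 and the OpInfo entry).
    print(torch.digamma(torch.tensor([1, 2, 3])).dtype)  # torch.float32 under the default dtype

    # Pole behaviour matches SciPy and the C++ standard for gamma-related functions
    # (PATCH 03/05): digamma(+0) -> -inf, digamma(-0) -> +inf, negative integers -> nan.
    x = torch.tensor([0.0, -0.0, -1.0, -2.0, 0.5])
    print(torch.digamma(x))
    print(torch.from_numpy(scipy.special.digamma(x.numpy())))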
From 83e890a63c01f465b5e0668c180066101b7c634f Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Mon, 21 Dec 2020 06:21:40 -0600 Subject: [PATCH 22/26] remove skip --- torch/testing/_internal/common_methods_invocations.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 0becf92ff21b..9c403b346dd3 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1029,8 +1029,6 @@ def reference_sigmoid(x): UnaryUfuncInfo('digamma', ref=scipy.special.digamma, decorators=(precisionOverride({torch.float16: 5e-1}),), - skips=(SkipInfo('TestCommon', 'test_variant_consistency_jit', - device_type='cuda', dtypes=[torch.float16]),), dtypes=all_types_and(torch.bool), dtypesIfCPU=all_types_and(torch.bool), dtypesIfCUDA=all_types_and(torch.bool, torch.half), From a3375c22a78a6e57925fa009e43acc3792f9af09 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Mon, 21 Dec 2020 23:31:35 -0600 Subject: [PATCH 23/26] add float16 skip --- torch/testing/_internal/common_methods_invocations.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 5eeb416c7bea..87c94c24ab25 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1095,6 +1095,12 @@ def reference_sigmoid(x): dtypes=all_types_and(torch.bool), dtypesIfCPU=all_types_and(torch.bool), dtypesIfCUDA=all_types_and(torch.bool, torch.half), + skips=( + # In some cases, output is NaN (for input close to + # negative integers) especially due to reduced precision + # in float16 and NaN's can't be tested for equality. + SkipInfo('TestCommon', 'test_variant_consistency_jit', + device='cuda', dtypes=[torch.float16]),), promotes_integers_to_float=True) ] op_db = op_db + op_db_scipy_reference From 86a33a4ee70092aff5154d0cbd8e8264a17a08ee Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 22 Dec 2020 00:32:00 -0600 Subject: [PATCH 24/26] fix argument name --- torch/testing/_internal/common_methods_invocations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 87c94c24ab25..5d802e715b34 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1100,7 +1100,7 @@ def reference_sigmoid(x): # negative integers) especially due to reduced precision # in float16 and NaN's can't be tested for equality. 
SkipInfo('TestCommon', 'test_variant_consistency_jit', - device='cuda', dtypes=[torch.float16]),), + device_type='cuda', dtypes=[torch.float16]),), promotes_integers_to_float=True) ] op_db = op_db + op_db_scipy_reference From e16fd9c51eb121079bdbae018c338e2975f5be37 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 22 Dec 2020 12:08:41 -0600 Subject: [PATCH 25/26] make lint happy : int -> bool x_is_integer --- aten/src/ATen/native/Math.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index 33f997a6ad5e..6cd0464de921 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -285,7 +285,7 @@ static inline double calc_digamma(double x) { return std::copysign(INFINITY, -x); } - int x_is_integer = x == trunc(x); + bool x_is_integer = x == trunc(x); if (x < 0) { if (x_is_integer) { // As per C++ standard for gamma related functions and SciPy, @@ -337,7 +337,7 @@ static inline float calc_digamma(float x) { return std::copysign(INFINITY, -x); } - int x_is_integer = x == truncf(x); + bool x_is_integer = x == truncf(x); if (x < 0) { if (x_is_integer) { // As per C++ standard for gamma related functions and SciPy, From 86bf59beb29d26fd412ed7b3534f01d8b808dba3 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 22 Dec 2020 23:26:10 -0600 Subject: [PATCH 26/26] fix stray merge --- torch/testing/_internal/common_methods_invocations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index ee3a889037f0..cdeadd1ca6e5 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1118,7 +1118,7 @@ def reference_sigmoid(x): # in float16 and NaN's can't be tested for equality. SkipInfo('TestCommon', 'test_variant_consistency_jit', device_type='cuda', dtypes=[torch.float16]),), - promotes_integers_to_float=True) + promotes_integers_to_float=True), OpInfo('xlogy', dtypes=all_types_and(torch.bool), dtypesIfCPU=all_types_and(torch.bool, torch.half, torch.bfloat16),