Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

To fix inconsistency of digamma with SciPy #56689

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 14 additions & 4 deletions aten/src/ATen/native/Math.h
Expand Up @@ -292,7 +292,13 @@ static inline double calc_digamma(double x) {
// If the argument is a negative integer, NaN is returned
return std::numeric_limits<double>::quiet_NaN();
}
return calc_digamma(1 - x) - c10::pi<double> / tan(c10::pi<double> * x);
// Extracts the fractional part of x as r, since tan(pi * r) is more numerically
// accurate than tan(pi * x). While these operations are mathematically equivalent
// since both x and r are in radians and tan() has a periodicity of pi, in practice
// the computation of pi * x is a source of error (when |x| > 1).
double q, r;
r = std::modf(x, &q);
return calc_digamma(1 - x) - c10::pi<double> / tan(c10::pi<double> * r);
}

// Push x to be >= 10
Expand Down Expand Up @@ -344,9 +350,13 @@ static inline float calc_digamma(float x) {
// If the argument is a negative integer, NaN is returned
return std::numeric_limits<float>::quiet_NaN();
}
// Avoid rounding errors for `tan`'s input.
// Those make a big difference at extreme values.
float pi_over_tan_pi_x = (float)(c10::pi<double> / tan(c10::pi<double> * (double)x));
// Extracts the fractional part of x as r, since tan(pi * r) is more numerically
// accurate than tan(pi * x). While these operations are mathematically equivalent
// since both x and r are in radians and tan() has a periodicity of pi, in practice
// the computation of pi * x is a source of error (when |x| > 1).
double q, r;
r = std::modf(x, &q);
float pi_over_tan_pi_x = (float)(c10::pi<double> / tan(c10::pi<double> * r));
return calc_digamma(1 - x) - pi_over_tan_pi_x;
}

Expand Down
10 changes: 7 additions & 3 deletions aten/src/ATen/native/cuda/Math.cuh
Expand Up @@ -122,9 +122,13 @@ static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
// If the argument is a negative integer, NaN is returned
return static_cast<scalar_t>(NAN);
}
// Rounding errors in tan's input can really affect the output
// for extreme values, so we always perform this computation in double.
result = static_cast<accscalar_t>(- PI_f64 / ::tan(PI_f64 * static_cast<double>(x)));
// Extracts the fractional part of x as r, since tan(pi * r) is more numerically
// accurate than tan(pi * x). While these operations are mathematically equivalent
// since both x and r are in radians and tan() has a periodicity of pi, in practice
// the computation of pi * x is a source of error (when |x| > 1).
double q, r;
r = ::modf(static_cast<double>(x), &q);
result = static_cast<accscalar_t>(- PI_f64 / ::tan(PI_f64 * r));
x = 1 - x;
}

Expand Down
4 changes: 1 addition & 3 deletions test/test_unary_ufuncs.py
Expand Up @@ -565,10 +565,8 @@ def test_digamma_special(self, device, dtype):
@dtypes(torch.float, torch.double)
def test_digamma(self, device, dtype):
# Tests pole behavior
# TODO: Add value `-1931.99999994`, to the tensor below when
# https://github.com/pytorch/pytorch/issues/49015 is fixed
tensor = torch.tensor([-0.999999994, -1.999999994, -2.0000000111,
-100.99999994, 0.000000111,
-100.99999994, 0.000000111, -1931.99999994,
-0.000000111, 0, -0, -1, -2, -931], dtype=dtype, device=device)
self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor)

Expand Down