Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Reference #50345 `zeta` was already present in the codebase to support the computation of `polygamma`. However, before this PR, `zeta` only had a `double(double, double)` signature **for CPU** (which meant that `polygamma` computations were always upcast to `double` for the zeta part). With this PR, float computations take place in float and double computations in double. This PR also refactors the code and moves the duplicated code from `Math.cuh` to `Math.h`. **Note**: For SciPy, `q` is optional, and if it is `None`, it defaults to `1`, which corresponds to the Riemann zeta function. However, for `torch.special.zeta`, I made `q` mandatory, because it feels odd to me that without `q` this is the Riemann zeta function and with `q` it is the general Hurwitz zeta function. I think sticking to just the general form makes more sense, as passing `1` for `q` is trivial. Verify: * [x] Docs https://14234587-65600975-gh.circle-artifacts.com/0/docs/special.html#torch.special.zeta Pull Request resolved: #59623 Reviewed By: ngimel Differential Revision: D29348269 Pulled By: mruberry fbshipit-source-id: a3f9ebe1f7724dbe66de2b391afb9da1cfc3e4bb
- Loading branch information
1 parent
26cdec6
commit dfd2edc
Showing
18 changed files
with
296 additions
and
126 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
#include <c10/util/Half.h> | ||
#include <c10/util/MathConstants.h> | ||
#include <c10/util/math_compat.h> | ||
#include <ATen/AccumulateType.h> | ||
|
||
|
||
/* The next function is taken from https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c. | ||
|
@@ -148,9 +149,14 @@ Date: February 1996 | |
* This function is derived from the implementation of the zeta function in the Cephes Math Library. | ||
* See note [3-Clause BSD License for the Cephes Math Library]. | ||
*/ | ||
static inline double zeta(double x, double q) { | ||
static double MACHEP = 1.11022302462515654042E-16; | ||
static double A[] = { | ||
template <typename scalar_t, bool is_cuda=false> | ||
C10_HOST_DEVICE static inline scalar_t zeta(scalar_t x, scalar_t q) { | ||
using acc_t = at::acc_type<scalar_t, is_cuda>; | ||
const acc_t MACHEP = acc_t{1.11022302462515654042E-16}; | ||
constexpr acc_t zero = acc_t{0.0}; | ||
constexpr acc_t half = acc_t{0.5}; | ||
constexpr acc_t one = acc_t{1.0}; | ||
static const acc_t A[] = { | ||
12.0, | ||
-720.0, | ||
30240.0, | ||
|
@@ -166,58 +172,58 @@ static inline double zeta(double x, double q) { | |
}; | ||
|
||
int i = 0; | ||
double a, b, k, s, t, w; | ||
if (x == 1.0) { | ||
return INFINITY; | ||
acc_t a, b, k, s, t, w; | ||
if (x == one) { | ||
return std::numeric_limits<scalar_t>::infinity(); | ||
} | ||
|
||
if (x < 1.0) { | ||
return std::numeric_limits<double>::quiet_NaN(); | ||
if (x < one) { | ||
return std::numeric_limits<scalar_t>::quiet_NaN(); | ||
} | ||
|
||
if (q <= 0.0) { | ||
if (q == floor(q)) { | ||
return INFINITY; | ||
if (q <= zero) { | ||
if (q == ::floor(q)) { | ||
return std::numeric_limits<scalar_t>::infinity(); | ||
} | ||
if (x != floor(x)) { | ||
return std::numeric_limits<double>::quiet_NaN(); | ||
if (x != ::floor(x)) { | ||
return std::numeric_limits<scalar_t>::quiet_NaN(); | ||
} | ||
} | ||
|
||
s = std::pow(q, -x); | ||
s = ::pow(q, -x); | ||
a = q; | ||
i = 0; | ||
b = 0.0; | ||
while ((i < 9) || (a <= 9.0)) { | ||
b = zero; | ||
while ((i < 9) || (a <= acc_t{9.0})) { | ||
i += 1; | ||
a += 1.0; | ||
b = std::pow(a, -x); | ||
a += one; | ||
b = ::pow(a, -x); | ||
s += b; | ||
if ((-MACHEP * s < b) && (b < MACHEP * s)) { | ||
return s; | ||
return static_cast<scalar_t>(s); | ||
} | ||
}; | ||
|
||
w = a; | ||
s += b * w / (x - 1.0); | ||
s -= 0.5 * b; | ||
a = 1.0; | ||
k = 0.0; | ||
s += b * w / (x - one); | ||
s -= half * b; | ||
a = one; | ||
k = zero; | ||
for (int i = 0; i < 12; i++) { | ||
a *= x + k; | ||
b /= w; | ||
t = a * b / A[i]; | ||
s = s + t; | ||
t = std::abs(t / s); | ||
t = ::abs(t / s); | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
kshitij12345
Author
Collaborator
|
||
if (t < MACHEP) { | ||
return s; | ||
return static_cast<scalar_t>(s); | ||
} | ||
k += 1.0; | ||
k += one; | ||
a *= x + k; | ||
b /= w; | ||
k += 1.0; | ||
k += one; | ||
} | ||
return s; | ||
return static_cast<scalar_t>(s); | ||
} | ||
|
||
/* | ||
|
@@ -397,16 +403,12 @@ static inline float calc_digamma(float x) { | |
return result + logf(x) - (0.5f / x) - y; | ||
} | ||
|
||
static inline double calc_polygamma(int64_t n, double x) { | ||
// already blocked if n <= 1 | ||
return ((n % 2) ? 1.0 : -1.0) * std::exp(lgamma(double(n) + 1.0)) * | ||
zeta(double(n + 1), x); | ||
} | ||
|
||
static inline float calc_polygamma(int64_t n, float x) { | ||
template <typename scalar_t, bool is_cuda=false> | ||
static inline C10_HOST_DEVICE scalar_t calc_polygamma(int n, scalar_t x) { | ||
// already blocked if n <= 1 | ||
return ((n % 2) ? 1.0f : -1.0f) * std::exp(lgamma(double(n) + 1.0)) * | ||
zeta(double(n + 1), x); | ||
return ((n % 2) ? 1.0 : -1.0) * | ||
::exp(::lgamma(static_cast<scalar_t>(n) + 1.0)) * | ||
zeta<scalar_t, is_cuda>(static_cast<scalar_t>(n + 1), x); | ||
} | ||
|
||
// regularized lower incomplete gamma | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Hello @kshitij12345, if `t` is NaN here, what would be a good way to mitigate this situation so that the ASAN CI check won't complain, as in #60444? Thanks!