[special] add zeta (#59623)
Summary:
Reference #50345

`zeta` was already present in the codebase to support the computation of `polygamma`.

However, before this PR, `zeta` only had a `double(double, double)` signature **for CPU**, which meant that `polygamma` computations were always upcast to `double` for the zeta part.

With this PR, float computations take place in float and double computations in double.

The code has also been refactored, and the duplicated implementation has been moved from `Math.cuh` to `Math.h`.

**Note**: In SciPy, `q` is optional and defaults to `1` when it is `None`, which corresponds to the Riemann zeta function. For `torch.special.zeta`, I made `q` mandatory because it felt odd that the function would be the Riemann zeta without `q` and the general Hurwitz zeta with `q`. Sticking to just the general form made more sense, since passing `1` for `q` is trivial.
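
For illustration (not part of this PR), here is a minimal C++ sketch of calling the new ATen entry point; it assumes a build containing this change and that the Tensor/Scalar overloads added in `BinaryOps.cpp` are exposed as `at::special_zeta`. With `q = 1`, the general Hurwitz zeta reduces to the Riemann zeta function.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // float inputs now stay in float; no upcast to double for the zeta part.
  at::Tensor x = at::tensor({2.0, 3.0, 4.0}, at::kFloat);
  // General Hurwitz zeta(x, q) with a tensor q.
  at::Tensor hurwitz = at::special_zeta(x, at::ones_like(x));
  // Scalar overload: zeta(x, 1) corresponds to the Riemann zeta function.
  at::Tensor riemann = at::special_zeta(x, 1.0);
  std::cout << hurwitz << "\n" << riemann << std::endl;
  return 0;
}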

Verify:
* [x] Docs https://14234587-65600975-gh.circle-artifacts.com/0/docs/special.html#torch.special.zeta

Pull Request resolved: #59623

Reviewed By: ngimel

Differential Revision: D29348269

Pulled By: mruberry

fbshipit-source-id: a3f9ebe1f7724dbe66de2b391afb9da1cfc3e4bb
kshitij12345 authored and facebook-github-bot committed Jun 24, 2021
1 parent 26cdec6 commit dfd2edc
Showing 18 changed files with 296 additions and 126 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -351,6 +351,7 @@ namespace c10 {
_(aten, special_i0e) \
_(aten, special_i1) \
_(aten, special_i1e) \
_(aten, special_zeta) \
_(aten, has_torch_function) \
_(aten, hardswish) \
_(aten, hardswish_) \
25 changes: 25 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -57,6 +57,10 @@ TORCH_META_FUNC(special_xlog1py) (const Tensor& self, const Tensor& other) {
build_borrowing_binary_float_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC(special_zeta) (const Tensor& self, const Tensor& other) {
build_borrowing_binary_float_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC2(copysign, Tensor) (
const Tensor& self, const Tensor& other
) {
@@ -221,6 +225,7 @@ DEFINE_DISPATCH(copysign_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(xlogy_stub);
DEFINE_DISPATCH(xlog1py_stub);
DEFINE_DISPATCH(zeta_stub);

TORCH_IMPL_FUNC(add_out) (
const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& result
@@ -262,6 +267,10 @@ TORCH_IMPL_FUNC(special_xlog1py_out) (const Tensor& self, const Tensor& other, c
xlog1py_stub(device_type(), *this);
}

TORCH_IMPL_FUNC(special_zeta_out) (const Tensor& self, const Tensor& other, const Tensor& result) {
zeta_stub(device_type(), *this);
}

#define CREATE_BINARY_TORCH_IMPL_FUNC(func_out, func_stub) \
TORCH_IMPL_FUNC(func_out) (const Tensor& self, const Tensor& other, const Tensor& result) { \
func_stub(device_type(), *this); \
@@ -297,6 +306,22 @@ Tensor& special_xlog1py_out(const Tensor& self, const Scalar& other, Tensor& res
return at::special_xlog1py_out(result, self, wrapped_scalar_tensor(other));
}

Tensor special_zeta(const Scalar& x, const Tensor& y) {
return at::special_zeta(wrapped_scalar_tensor(x), y);
}

Tensor special_zeta(const Tensor& x, const Scalar& y) {
return at::special_zeta(x, wrapped_scalar_tensor(y));
}

Tensor& special_zeta_out(const Scalar& self, const Tensor& other, Tensor& result) {
return at::special_zeta_out(result, wrapped_scalar_tensor(self), other);
}

Tensor& special_zeta_out(const Tensor& self, const Scalar& other, Tensor& result) {
return at::special_zeta_out(result, self, wrapped_scalar_tensor(other));
}

TORCH_IMPL_FUNC(atan2_out) (const Tensor& self, const Tensor& other, const Tensor& result) {
atan2_stub(device_type(), *this);
}
1 change: 1 addition & 0 deletions aten/src/ATen/native/BinaryOps.h
@@ -93,5 +93,6 @@ DECLARE_DISPATCH(structured_binary_fn, heaviside_stub);
DECLARE_DISPATCH(structured_binary_fn, copysign_stub);
DECLARE_DISPATCH(binary_fn, xlogy_stub);
DECLARE_DISPATCH(structured_binary_fn, xlog1py_stub);
DECLARE_DISPATCH(structured_binary_fn, zeta_stub);

}} // namespace at::native
76 changes: 39 additions & 37 deletions aten/src/ATen/native/Math.h
@@ -10,6 +10,7 @@
#include <c10/util/Half.h>
#include <c10/util/MathConstants.h>
#include <c10/util/math_compat.h>
#include <ATen/AccumulateType.h>


/* The next function is taken from https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c.
@@ -148,9 +149,14 @@ Date: February 1996
* This function is derived from the implementation of the zeta function in the Cephes Math Library.
* See note [3-Clause BSD License for the Cephes Math Library].
*/
static inline double zeta(double x, double q) {
static double MACHEP = 1.11022302462515654042E-16;
static double A[] = {
template <typename scalar_t, bool is_cuda=false>
C10_HOST_DEVICE static inline scalar_t zeta(scalar_t x, scalar_t q) {
using acc_t = at::acc_type<scalar_t, is_cuda>;
const acc_t MACHEP = acc_t{1.11022302462515654042E-16};
constexpr acc_t zero = acc_t{0.0};
constexpr acc_t half = acc_t{0.5};
constexpr acc_t one = acc_t{1.0};
static const acc_t A[] = {
12.0,
-720.0,
30240.0,
@@ -166,58 +172,58 @@ static inline double zeta(double x, double q) {
};

int i = 0;
double a, b, k, s, t, w;
if (x == 1.0) {
return INFINITY;
acc_t a, b, k, s, t, w;
if (x == one) {
return std::numeric_limits<scalar_t>::infinity();
}

if (x < 1.0) {
return std::numeric_limits<double>::quiet_NaN();
if (x < one) {
return std::numeric_limits<scalar_t>::quiet_NaN();
}

if (q <= 0.0) {
if (q == floor(q)) {
return INFINITY;
if (q <= zero) {
if (q == ::floor(q)) {
return std::numeric_limits<scalar_t>::infinity();
}
if (x != floor(x)) {
return std::numeric_limits<double>::quiet_NaN();
if (x != ::floor(x)) {
return std::numeric_limits<scalar_t>::quiet_NaN();
}
}

s = std::pow(q, -x);
s = ::pow(q, -x);
a = q;
i = 0;
b = 0.0;
while ((i < 9) || (a <= 9.0)) {
b = zero;
while ((i < 9) || (a <= acc_t{9.0})) {
i += 1;
a += 1.0;
b = std::pow(a, -x);
a += one;
b = ::pow(a, -x);
s += b;
if ((-MACHEP * s < b) && (b < MACHEP * s)) {
return s;
return static_cast<scalar_t>(s);
}
};

w = a;
s += b * w / (x - 1.0);
s -= 0.5 * b;
a = 1.0;
k = 0.0;
s += b * w / (x - one);
s -= half * b;
a = one;
k = zero;
for (int i = 0; i < 12; i++) {
a *= x + k;
b /= w;
t = a * b / A[i];
s = s + t;
t = std::abs(t / s);
t = ::abs(t / s);

Review comment from imaginary-person (Contributor), Jul 9, 2021:

Hello @kshitij12345, if t is -nan here, what would be a good way to mitigate this situation so that the ASAN CI check won't complain, as in #60444? Thanks!

Reply from kshitij12345 (Author, Collaborator), Jul 9, 2021:

The correct thing to do will be to use ::fabs.

Reason: ::abs dispatches to the integer version of abs (with an implicit cast to int).

Note: std::abs, as used previously, handled the floating-point computation correctly.

(I wonder why this wasn't caught earlier by the ASAN build.)

Thanks for the ping, @imaginary-person.

Reproducer:

#include <iostream>
#include <cmath>

int main() {
    float x{-NAN};
    std::cout << x << "\n";
    std::cout << "abs:" << ::abs(x) << "\n";
    std::cout << "std::abs:" << std::abs(x) << "\n";
    std::cout << "fabs:" << ::fabs(x) << "\n";
    return 0;
}

Compile:

$ g++ -fsanitize=undefined -g -Wconversion -lubsan ubcheck.cpp 
ubcheck.cpp: In function 'int main()':
ubcheck.cpp:7:34: warning: conversion from 'float' to 'int' may change value [-Wfloat-conversion]
    7 |     std::cout << "abs:" << ::abs(x) << "\n";
      |                                  ^

On running

 $ ./a.out 
-nan
ubcheck.cpp:7:40: runtime error: negation of -2147483648 cannot be represented in type 'int'; cast to an unsigned type to negate this value to itself
abs:-2147483648
std::abs:nan
fabs:nan

Godbolt Link: https://godbolt.org/z/595PW4b5P
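
If that suggestion were applied here (a hypothetical follow-up, not part of this commit), the convergence check above would read:

// Hypothetical follow-up sketch: ::fabs keeps the check in floating point
// instead of dispatching to the integer ::abs overload.
t = ::fabs(t / s);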

if (t < MACHEP) {
return s;
return static_cast<scalar_t>(s);
}
k += 1.0;
k += one;
a *= x + k;
b /= w;
k += 1.0;
k += one;
}
return s;
return static_cast<scalar_t>(s);
}

/*
@@ -397,16 +403,12 @@ static inline float calc_digamma(float x) {
return result + logf(x) - (0.5f / x) - y;
}

static inline double calc_polygamma(int64_t n, double x) {
// already blocked if n <= 1
return ((n % 2) ? 1.0 : -1.0) * std::exp(lgamma(double(n) + 1.0)) *
zeta(double(n + 1), x);
}

static inline float calc_polygamma(int64_t n, float x) {
template <typename scalar_t, bool is_cuda=false>
static inline C10_HOST_DEVICE scalar_t calc_polygamma(int n, scalar_t x) {
// already blocked if n <= 1
return ((n % 2) ? 1.0f : -1.0f) * std::exp(lgamma(double(n) + 1.0)) *
zeta(double(n + 1), x);
return ((n % 2) ? 1.0 : -1.0) *
::exp(::lgamma(static_cast<scalar_t>(n) + 1.0)) *
zeta<scalar_t, is_cuda>(static_cast<scalar_t>(n + 1), x);
}
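
For reference (not part of the diff), the templated calc_polygamma above implements the standard identity relating the polygamma function to the Hurwitz zeta function, which is why the refactored zeta is reused here:

$$\psi^{(n)}(x) = (-1)^{n+1}\, n!\, \zeta(n+1, x), \qquad n \ge 1$$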

// regularized lower incomplete gamma
9 changes: 9 additions & 0 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -989,6 +989,14 @@ void xlog1py_kernel(TensorIteratorBase& iter) {
});
}

void zeta_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "zeta_cpu", [&]() {
cpu_kernel(iter, [](scalar_t x, scalar_t q) -> scalar_t {
return zeta(x, q);
});
});
}

} // namespace

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -1082,6 +1090,7 @@ REGISTER_DISPATCH(copysign_stub, &copysign_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel);
REGISTER_DISPATCH(xlog1py_stub, &xlog1py_kernel);
REGISTER_DISPATCH(zeta_stub, &zeta_kernel);

} // namespace native
} // namespace at
11 changes: 11 additions & 0 deletions aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -3,6 +3,8 @@
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/Math.h>
#include <ATen/NumericUtils.h>

// NOTE: CUDA on Windows requires that the enclosing function
@@ -67,11 +69,20 @@ void xlog1py_kernel_cuda(TensorIteratorBase& iter) {
});
}

void zeta_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "zeta_cuda", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t x, scalar_t q) -> scalar_t {
return zeta<scalar_t, /*is_cuda=*/true>(x, q);
});
});
}

REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
REGISTER_DISPATCH(huber_stub, &huber_kernel_cuda);
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel_cuda);
REGISTER_DISPATCH(xlog1py_stub, &xlog1py_kernel_cuda);
REGISTER_DISPATCH(zeta_stub, &zeta_kernel_cuda);

// DO NOT ADD ANY NEW KERNELS HERE
// CUDA compilation times grow quickly. It's perfectly acceptable to have a file per kernel.
88 changes: 0 additions & 88 deletions aten/src/ATen/native/cuda/Math.cuh
@@ -6,88 +6,6 @@
namespace at {
namespace native {

/*
* For licensing information, please refer to the the cpu implementation located in "ATen/native/Math.h".
*/
template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t zeta(scalar_t _x, scalar_t _q) {
using accscalar_t = at::acc_type<scalar_t, true>;
static const accscalar_t MACHEP = 1.11022302462515654042E-16;
const accscalar_t A[] = {
12.0,
-720.0,
30240.0,
-1209600.0,
47900160.0,
-1.8924375803183791606e9, /*1.307674368e12/691*/
7.47242496e10,
-2.950130727918164224e12, /*1.067062284288e16/3617*/
1.1646782814350067249e14, /*5.109094217170944e18/43867*/
-4.5979787224074726105e15, /*8.028576626982912e20/174611*/
1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/
-7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/
};
accscalar_t x = static_cast<accscalar_t>(_x);
accscalar_t q = static_cast<accscalar_t>(_q);

int i = 0;
accscalar_t a, b, k, s, t, w;
if( x == 1.0 ) {
return static_cast<scalar_t>(INFINITY);
}

if( x < 1.0 ){
std::numeric_limits<scalar_t>::quiet_NaN();
}
bool q_is_integer = q == ::floor(q);

if(q <= 0.0) {
if(q_is_integer) {
return static_cast<scalar_t>(INFINITY);
}
else {
std::numeric_limits<scalar_t>::quiet_NaN();
}
}

s = ::pow(q, -x);
a = q;
i = 0;
b = 0.0;
while ((i < 9) || (a <= 9.0)) {
i += 1;
a += 1.0;
b = ::pow( a, -x );
s += b;
if ((-MACHEP < (b / s)) && ((b / s) < MACHEP)) {
return static_cast<scalar_t>(s);
}
};
w = a;
s += b * w / (x - 1.0);
s -= 0.5 * b;
a = 1.0;
k = 0.0;
for (int i=0; i < 12; i++) {
a *= x + k;
b /= w;
t = a * b / A[i];
s = s + t;
t = t / s;
if (t < 0){
t = -t;
}
if ((-MACHEP <t) && (t < MACHEP)){
return static_cast<scalar_t>(s);
}
k += 1.0;
a *= x + k;
b /= w;
k += 1.0;
}
return static_cast<scalar_t>(s);
}

/*
* For licensing information, please refer to the the cpu implementation located in "ATen/native/Math.h".
*/
@@ -177,12 +95,6 @@ static inline C10_HOST_DEVICE scalar_t calc_trigamma(scalar_t in) {
return static_cast<scalar_t>(sign * result);
}

template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_polygamma(int n, scalar_t x) {
// already blocked if n <= 1
return ((n % 2) ? 1.0 : -1.0) * ::exp(::lgamma(static_cast<scalar_t>(n) + 1.0)) * zeta(static_cast<scalar_t>(n + 1), x);
}

template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_gcd(scalar_t a_in, scalar_t b_in) {
scalar_t a = ::abs(a_in);
3 changes: 2 additions & 1 deletion aten/src/ATen/native/cuda/UnaryGammaKernels.cu
@@ -7,6 +7,7 @@
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/Math.h>

namespace at { namespace native {

@@ -34,7 +35,7 @@ void polygamma_kernel_cuda(TensorIteratorBase& iter, int64_t n) {
} else {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "polygamma_cuda", [&]() {
gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t a) -> scalar_t {
return calc_polygamma(int(n), a);
return calc_polygamma<scalar_t, /*is_cuda=*/true>(int(n), a);
});
});
}
