Skip to content

Commit

Permalink
Add inf special cases to gamma.c function (pymc-devs#634)
Browse files Browse the repository at this point in the history
  • Loading branch information
amyoshino committed Feb 14, 2024
1 parent 453fb4d commit d8868cc
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 0 deletions.
10 changes: 10 additions & 0 deletions pytensor/scalar/c_code/gamma.c
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ DEVICE double GammaP (double n, double x)
{ /* --- regularized Gamma function P */
if ((n <= 0) || (x < 0)) return NPY_NAN; /* check the function arguments */
if (x <= 0) return 0; /* treat x = 0 as a special case */
if (isinf(n)) {
if (isinf(x)) return NPY_NAN;
return 0;
}
if (isinf(x)) return 1;
if (x < n+1) return _series(n, x) *exp(n *log(x) -x -logGamma(n));
return 1 -_cfrac(n, x) *exp(n *log(x) -x -logGamma(n));
} /* GammaP() */
Expand All @@ -228,6 +233,11 @@ DEVICE double GammaQ (double n, double x)
{ /* --- regularized Gamma function Q */
if ((n <= 0) || (x < 0)) return NPY_NAN; /* check the function arguments */
if (x <= 0) return 1; /* treat x = 0 as a special case */
if (isinf(n)) {
if (isinf(x)) return NPY_NAN;
return 1;
}
if (isinf(x)) return 0;
if (x < n+1) return 1 -_series(n, x) *exp(n *log(x) -x -logGamma(n));
return _cfrac(n, x) *exp(n *log(x) -x -logGamma(n));
} /* GammaQ() */
Expand Down
21 changes: 21 additions & 0 deletions pytensor/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,13 @@ def __eq__(self, other):
def __hash__(self):
return hash(type(self))

def c_code_cache_version(self):
v = super().c_code_cache_version()
if v:
return (2, *v)
else:
return v


chi2sf = Chi2SF(upgrade_to_float64, name="chi2sf")

Expand Down Expand Up @@ -677,6 +684,13 @@ def __eq__(self, other):
def __hash__(self):
return hash(type(self))

def c_code_cache_version(self):
v = super().c_code_cache_version()
if v:
return (2, *v)
else:
return v


gammainc = GammaInc(upgrade_to_float, name="gammainc")

Expand Down Expand Up @@ -723,6 +737,13 @@ def __eq__(self, other):
def __hash__(self):
return hash(type(self))

def c_code_cache_version(self):
v = super().c_code_cache_version()
if v:
return (2, *v)
else:
return v


gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")

Expand Down
20 changes: 20 additions & 0 deletions tests/scalar/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ def test_gammainc_nan_c():
assert np.isnan(test_func(-1, -1))


def test_gammainc_inf_c():
x1 = pt.dscalar()
x2 = pt.dscalar()
y = gammainc(x1, x2)
test_func = make_function(CLinker().accept(FunctionGraph([x1, x2], [y])))
assert np.isclose(test_func(np.inf, 1), sp.gammainc(np.inf, 1))
assert np.isclose(test_func(1, np.inf), sp.gammainc(1, np.inf))
assert np.isnan(test_func(np.inf, np.inf))


def test_gammaincc_python():
x1 = pt.dscalar()
x2 = pt.dscalar()
Expand All @@ -59,6 +69,16 @@ def test_gammaincc_nan_c():
assert np.isnan(test_func(-1, -1))


def test_gammaincc_inf_c():
x1 = pt.dscalar()
x2 = pt.dscalar()
y = gammaincc(x1, x2)
test_func = make_function(CLinker().accept(FunctionGraph([x1, x2], [y])))
assert np.isclose(test_func(np.inf, 1), sp.gammaincc(np.inf, 1))
assert np.isclose(test_func(1, np.inf), sp.gammaincc(1, np.inf))
assert np.isnan(test_func(np.inf, np.inf))


def test_gammal_nan_c():
x1 = pt.dscalar()
x2 = pt.dscalar()
Expand Down

0 comments on commit d8868cc

Please sign in to comment.