-
-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
test_special_functions.py
60 lines (42 loc) · 1.35 KB
/
test_special_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from theano import function
import theano.tensor as tt
import pymc3.distributions.special as ps
import scipy.special as ss
import numpy as np
from .checks import close_to
def test_functions():
    """Compare pymc3's Theano ``gammaln`` and ``psi`` against scipy.special.

    Every case is checked directly with ``check_vals`` instead of being
    yielded: yield-based test generators are only honoured by nose, and
    pytest >= 4.0 no longer runs them, so the comparisons would silently
    never execute there.
    """
    xvals = list(map(np.atleast_1d, [.01, .1, 2, 100, 10000]))

    x = tt.dvector('x')
    x.tag.test_value = xvals[0]

    p = tt.iscalar('p')
    p.tag.test_value = 1

    gammaln = function([x], ps.gammaln(x))
    psi = function([x], ps.psi(x))
    # Compile-only smoke test: the multigammaln graph must build, but its
    # values are not compared in this test.
    function([x, p], ps.multigammaln(x, p))

    # Use a distinct loop name so the symbolic `x` above is not shadowed.
    for xval in xvals:
        check_vals(gammaln, ss.gammaln, xval)
    # psi is only compared from the second point (0.1) upward.
    for xval in xvals[1:]:
        check_vals(psi, ss.psi, xval)
"""
scipy.special.multigammaln gives bad values if you pass a non scalar to a
In [14]:
import scipy.special
scipy.special.multigammaln([2.1], 3)
Out[14]:
array([ 1.76253257, 1.60450306, 1.66722239])
"""
def t_multigamma():
    """Yield checks of pymc3's ``multigammaln`` against scipy's.

    The name lacks the ``test`` prefix, so test runners do not collect it;
    presumably this is deliberate (NOTE(review): confirm before enabling).
    """
    points = [np.atleast_1d(value) for value in (0, .1, 2, 100)]

    x = tt.dvector('x')
    x.tag.test_value = points[0]
    p = tt.iscalar('p')
    p.tag.test_value = 1

    multigammaln = function([x, p], ps.multigammaln(x, p))

    def scipy_reference(arr, dim):
        # Pass scipy a scalar: it mis-handles a non-scalar first argument.
        return ss.multigammaln(arr[0], dim)

    for dim in (0, 1, 2, 3, 4, 100):
        for point in points:
            yield check_vals, multigammaln, scipy_reference, point, dim
def check_vals(fn1, fn2, *args):
    """Call both functions with ``*args`` and compare them via ``close_to``.

    ``fn1`` is evaluated before ``fn2``; the comparison tolerance is 1e-6.
    """
    actual = fn1(*args)
    expected = fn2(*args)
    close_to(actual, expected, 1e-6)