In [1]:
import matplotlib.pyplot as plt
plt.style.use('ggplot')
%matplotlib inline
%precision 4

import pyjags
import numpy as np

In [19]:
sample = [-27.020,3.570,8.191,9.898,9.603,9.945,10.056]
n = len(sample)

code = """
# bugs model for seven scientists

model{
        for (i in 1:n) {
                x[i] ~ dnorm(mu, lambda[i])
        }
        mu ~ dnorm(0,.001)
        for (i in 1:n) {
                sigma[i] ~ dunif(0.1,40)
                lambda[i] <- 1 / (sigma[i] * sigma[i])
        }
}
"""

myvars=['mu', 'lambda', 'sigma']

%time model = pyjags.Model(code, data=dict(x=sample, n=n), chains=4)
%time samples = model.sample(20000, vars=myvars)
samples.keys()

adapting: iterations 4000 of 4000, elapsed 0:00:00, remaining 0:00:00
CPU times: user 29.2 ms, sys: 421 µs, total: 29.6 ms
Wall time: 29.2 ms
sampling: iterations 80000 of 80000, elapsed 0:00:01, remaining 0:00:00
sampling: iterations 80000 of 80000, elapsed 0:00:01, remaining 0:00:00
CPU times: user 556 ms, sys: 1.02 ms, total: 557 ms
Wall time: 555 ms


dict_keys(['lambda', 'sigma', 'mu'])

In [20]:
def summary(samples, varname, p=95, burnin=0, thin=1):
    values = samples[varname]
    (nb, iter, chains) = values.shape
    
    for i in range(nb):
        data = values[i][burnin::thin]

        ci = np.percentile(data, [100-p, p])
        
        print('{:<6}[{}] mean = {:>12.4f} (std = {:>10.4f}), {}% credible interval [{:>5.1f} {:>5.1f}]'.format(
          varname, i, np.mean(data), np.std(data), p, *ci))

burnin = 5000
        
for varname in myvars:
    summary(samples, varname, burnin=burnin, thin=4)

mu    [0] mean =       8.9514 (std =     2.4275), 95% credible interval [  4.1  11.6]
lambda[0] mean =       0.0015 (std =     0.0011), 95% credible interval [  0.0   0.0]
lambda[1] mean =       0.0795 (std =     1.1251), 95% credible interval [  0.0   0.1]
lambda[2] mean =       0.6792 (std =     4.6330), 95% credible interval [  0.0   1.5]
lambda[3] mean =       1.9033 (std =     8.2346), 95% credible interval [  0.0   9.0]
lambda[4] mean =       1.6225 (std =     7.4375), 95% credible interval [  0.0   6.7]
lambda[5] mean =       1.9381 (std =     8.4027), 95% credible interval [  0.0   8.8]
lambda[6] mean =       1.7623 (std =     8.0225), 95% credible interval [  0.0   7.4]
sigma [0] mean =      28.8989 (std =     6.9718), 95% credible interval [ 16.7  38.9]
sigma [1] mean =      16.1981 (std =    10.4464), 95% credible interval [  3.2  36.3]
sigma [2] mean =      11.7016 (std =    10.6201), 95% credible interval [  0.8  34.0]
sigma [3] mean =      10.2340 (std =    10.5299), 95% 

In [5]:
samples['lambda'][1][4000:].shape

(16000, 4)

In [6]:
np.mean(sample)

3.4633

In [21]:
sample2 = [-27.020,3.570,8.191,9.898,9.603,9.945,10.056]
n = len(sample2)

code2 = """
# bugs model for seven scientists

model{
        for (i in 1:n) {
                x[i] ~ dnorm(mu, lambda[i])
                
                x_rep[i] ~ dnorm(mu, lambda[i])
                p[i] <- step(x[i] - x_rep[i])
        }
        mu ~ dnorm(0,.001)
        for (i in 1:n) {
                lambda[i] ~ dgamma(.01, .01)
                sigma[i] <- 1/sqrt(lambda[i])
        }
}

"""

burnin = 20000

myvars2=['mu', 'lambda', 'sigma', 'p']

%time model2 = pyjags.Model(code2, data=dict(x=sample2, n=n), chains=4)
%time samples2 = model2.sample(80000, vars=myvars2)

print (samples2['mu'].shape)

for varname in myvars2:
    summary(samples2, varname, burnin=burnin, thin=4)

CPU times: user 1.93 ms, sys: 282 µs, total: 2.21 ms
Wall time: 1.83 ms
sampling: iterations 160888 of 320000, elapsed 0:00:01, remaining 0:00:01
sampling: iterations 320000 of 320000, elapsed 0:00:01, remaining 0:00:00
sampling: iterations 320000 of 320000, elapsed 0:00:01, remaining 0:00:00
CPU times: user 1.23 s, sys: 9.86 ms, total: 1.24 s
Wall time: 1.24 s
(1, 80000, 4)
mu    [0] mean =       9.8605 (std =     0.2315), 95% credible interval [  9.5  10.1]
lambda[0] mean =       0.0007 (std =     0.0010), 95% credible interval [  0.0   0.0]
lambda[1] mean =       0.0258 (std =     0.0359), 95% credible interval [  0.0   0.1]
lambda[2] mean =       0.4846 (std =     2.6583), 95% credible interval [  0.0   1.5]
lambda[3] mean =      30.9149 (std =    50.7218), 95% credible interval [  0.1 129.8]
lambda[4] mean =      14.7713 (std =    30.0646), 95% credible interval [  0.0  63.6]
lambda[5] mean =      30.8313 (std =    50.8860), 95% credible interval [  0.1 129.4]
lambda[6] mean =    