In [3]:
import matplotlib.pyplot as plt
plt.style.use('ggplot')
%matplotlib inline
%precision 4

import pyjags
import numpy as np

In [4]:
sample = [-27.020,3.570,8.191,9.898,9.603,9.945,10.056]
n = len(sample)

code = """
# bugs model for seven scientists

model{
        for (i in 1:n) {
                x[i] ~ dnorm(mu, lambda[i])
        }
        mu ~ dnorm(0,.001)
        for (i in 1:n) {
                sigma[i] ~ dunif(0.1,40)
                lambda[i] <- 1 / (sigma[i] * sigma[i])
        }
}
"""

myvars=['mu', 'lambda', 'sigma']

%time model = pyjags.Model(code, data=dict(x=sample, n=n), chains=4)
%time samples = model.sample(20000, vars=myvars)
samples.keys()

adapting: iterations 4000 of 4000, elapsed 0:00:00, remaining 0:00:00
CPU times: user 96 ms, sys: 224 ms, total: 320 ms
Wall time: 105 ms
sampling: iterations 75584 of 80000, elapsed 0:00:01, remaining 0:00:00
sampling: iterations 80000 of 80000, elapsed 0:00:01, remaining 0:00:00
CPU times: user 644 ms, sys: 48 ms, total: 692 ms
Wall time: 642 ms


dict_keys(['sigma', 'mu', 'lambda'])

In [5]:
def summary(samples, varname, p=95, burnin=0, thin=1):
    values = samples[varname]
    (nb, iter, chains) = values.shape
    
    for i in range(nb):
        data = values[i][burnin::thin]

        ci = np.percentile(data, [100-p, p])
        
        print('{:<6}[{}] mean = {:>12.4f} (std = {:>10.4f}), {}% credible interval [{:>5.1f} {:>5.1f}]'.format(
          varname, i, np.mean(data), np.std(data), p, *ci))

burnin = 5000
        
for varname in myvars:
    summary(samples, varname, burnin=burnin, thin=4)

mu    [0] mean =       8.8661 (std =     2.4904), 95% credible interval [  4.0  11.5]
lambda[0] mean =       0.0015 (std =     0.0015), 95% credible interval [  0.0   0.0]
lambda[1] mean =       0.1005 (std =     1.7060), 95% credible interval [  0.0   0.1]
lambda[2] mean =       0.6523 (std =     4.2909), 95% credible interval [  0.0   1.5]
lambda[3] mean =       1.8087 (std =     8.0567), 95% credible interval [  0.0   7.7]
lambda[4] mean =       1.6206 (std =     7.3929), 95% credible interval [  0.0   6.7]
lambda[5] mean =       1.7237 (std =     7.7004), 95% credible interval [  0.0   7.6]
lambda[6] mean =       1.6402 (std =     7.4845), 95% credible interval [  0.0   6.8]
sigma [0] mean =      28.8654 (std =     6.9913), 95% credible interval [ 16.6  39.0]
sigma [1] mean =      16.1146 (std =    10.4303), 95% credible interval [  3.1  36.1]
sigma [2] mean =      11.7285 (std =    10.6330), 95% credible interval [  0.8  34.3]
sigma [3] mean =      10.5593 (std =    10.5955), 95% 

In [6]:
samples['lambda'][1][4000:].shape

(16000, 4)

In [7]:
np.mean(sample)

3.4633

In [8]:
sample2 = [-27.020,3.570,8.191,9.898,9.603,9.945,10.056]
n = len(sample2)

code2 = """
# bugs model for seven scientists

model{
        for (i in 1:n) {
                x[i] ~ dnorm(mu, lambda[i])
                
                x_rep[i] ~ dnorm(mu, lambda[i])
                p[i] <- step(x[i] - x_rep[i])
        }
        mu ~ dnorm(0,.001)
        for (i in 1:n) {
                lambda[i] ~ dgamma(.01, .01)
                sigma[i] <- 1/sqrt(lambda[i])
        }
}

"""

burnin = 20000

myvars2=['mu', 'lambda', 'sigma', 'p']

%time model2 = pyjags.Model(code2, data=dict(x=sample2, n=n), chains=4)
%time samples2 = model2.sample(80000, vars=myvars2)

print (samples2['mu'].shape)

for varname in myvars2:
    summary(samples2, varname, burnin=burnin, thin=4)

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 1.78 ms
sampling: iterations 240464 of 320000, elapsed 0:00:01, remaining 0:00:00
sampling: iterations 320000 of 320000, elapsed 0:00:01, remaining 0:00:00
CPU times: user 1.36 s, sys: 8 ms, total: 1.37 s
Wall time: 1.37 s
(1, 80000, 4)
mu    [0] mean =       9.8598 (std =     0.2368), 95% credible interval [  9.5  10.1]
lambda[0] mean =       0.0007 (std =     0.0010), 95% credible interval [  0.0   0.0]
lambda[1] mean =       0.0262 (std =     0.0622), 95% credible interval [  0.0   0.1]
lambda[2] mean =       0.4787 (std =     2.9538), 95% credible interval [  0.0   1.5]
lambda[3] mean =      31.4206 (std =    51.4238), 95% credible interval [  0.1 132.1]
lambda[4] mean =      15.0338 (std =    31.1242), 95% credible interval [  0.0  64.5]
lambda[5] mean =      30.4616 (std =    50.4231), 95% credible interval [  0.1 128.2]
lambda[6] mean =      23.3338 (std =    42.0002), 95% credible interval [  0.1 101.9]
sigma [0] mean =   