In [1]:
import matplotlib.pyplot as plt
plt.style.use('ggplot')
%matplotlib inline
%precision 4

import pyjags
import numpy as np

In [2]:
sample = [-27.020,3.570,8.191,9.898,9.603,9.945,10.056]
n = len(sample)

code = """
# bugs model for seven scientists

model{
        for (i in 1:n) {
                x[i] ~ dnorm(mu, lambda[i])
        }
        mu ~ dnorm(0,.001)
        for (i in 1:n) {
                sigma[i] ~ dunif(0.1,40)
                lambda[i] <- 1 / (sigma[i] * sigma[i])
        }
}
"""

myvars=['mu', 'lambda', 'sigma']

%time model = pyjags.Model(code, data=dict(x=sample, n=n), chains=4)
%time samples = model.sample(20000, vars=myvars)
samples.keys()

adapting: iterations 4000 of 4000, elapsed 0:00:00, remaining 0:00:00
CPU times: user 100 ms, sys: 340 ms, total: 440 ms
Wall time: 114 ms
sampling: iterations 71772 of 80000, elapsed 0:00:01, remaining 0:00:00
sampling: iterations 80000 of 80000, elapsed 0:00:01, remaining 0:00:00
CPU times: user 720 ms, sys: 20 ms, total: 740 ms
Wall time: 727 ms


dict_keys(['sigma', 'mu', 'lambda'])

In [3]:
def summary(samples, varname, p=95, burnin=0, thin=1):
    values = samples[varname]
    (nb, iter, chains) = values.shape
    
    for i in range(nb):
        data = values[i][burnin::thin]

        ci = np.percentile(data, [100-p, p])
        
        print('{:<6}[{}] mean = {:>12.4f} (std = {:>10.4f}), {}% credible interval [{:>5.1f} {:>5.1f}]'.format(
          varname, i, np.mean(data), np.std(data), p, *ci))

burnin = 5000
        
for varname in myvars:
    summary(samples, varname, burnin=burnin, thin=4)

mu    [0] mean =       8.9248 (std =     2.4111), 95% credible interval [  4.2  11.6]
lambda[0] mean =       0.0015 (std =     0.0011), 95% credible interval [  0.0   0.0]
lambda[1] mean =       0.1152 (std =     1.7023), 95% credible interval [  0.0   0.1]
lambda[2] mean =       0.7373 (std =     4.9515), 95% credible interval [  0.0   1.5]
lambda[3] mean =       1.9164 (std =     8.1970), 95% credible interval [  0.0   8.8]
lambda[4] mean =       1.6728 (std =     7.4455), 95% credible interval [  0.0   7.2]
lambda[5] mean =       2.0237 (std =     8.5688), 95% credible interval [  0.0   9.4]
lambda[6] mean =       1.6147 (std =     7.3730), 95% credible interval [  0.0   6.9]
sigma [0] mean =      28.9428 (std =     6.9503), 95% credible interval [ 16.8  38.9]
sigma [1] mean =      16.2200 (std =    10.3956), 95% credible interval [  3.2  36.1]
sigma [2] mean =      11.6421 (std =    10.5403), 95% credible interval [  0.8  34.0]
sigma [3] mean =      10.2735 (std =    10.5037), 95% 

In [4]:
samples['lambda'][1][4000:].shape

(16000, 4)

In [5]:
np.mean(sample)

3.4633

In [6]:
sample2 = [-27.020,3.570,8.191,9.898,9.603,9.945,10.056]
n = len(sample2)

code2 = """
# bugs model for seven scientists

model{
        for (i in 1:n) {
                x[i] ~ dnorm(mu, lambda[i])
                
                x_rep[i] ~ dnorm(mu, lambda[i])
                p[i] <- step(x[i] - x_rep[i])
        }
        mu ~ dnorm(0,.001)
        for (i in 1:n) {
                lambda[i] ~ dgamma(.01, .01)
                sigma[i] <- 1/sqrt(lambda[i])
        }
}

"""

burnin = 20000

myvars2=['mu', 'lambda', 'sigma', 'p']

%time model2 = pyjags.Model(code2, data=dict(x=sample2, n=n), chains=4)
%time samples2 = model2.sample(80000, vars=myvars2)

print (samples2['mu'].shape)

for varname in myvars2:
    summary(samples2, varname, burnin=burnin, thin=4)

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 1.2 ms
sampling: iterations 155556 of 320000, elapsed 0:00:01, remaining 0:00:01
sampling: iterations 285324 of 320000, elapsed 0:00:01, remaining 0:00:00
sampling: iterations 320000 of 320000, elapsed 0:00:01, remaining 0:00:00
CPU times: user 1.3 s, sys: 12 ms, total: 1.31 s
Wall time: 1.31 s
(1, 80000, 4)
mu    [0] mean =       9.8595 (std =     0.2437), 95% credible interval [  9.5  10.1]
lambda[0] mean =       0.0007 (std =     0.0010), 95% credible interval [  0.0   0.0]
lambda[1] mean =       0.0260 (std =     0.0367), 95% credible interval [  0.0   0.1]
lambda[2] mean =       0.4902 (std =     3.1438), 95% credible interval [  0.0   1.5]
lambda[3] mean =      31.0347 (std =    51.1958), 95% credible interval [  0.1 130.5]
lambda[4] mean =      14.9909 (std =    30.7486), 95% credible interval [  0.0  64.3]
lambda[5] mean =      30.7824 (std =    51.0817), 95% credible interval [  0.1 129.8]
lambda[6] mean =      23.2234 (s