# Import Packages

In [1]:
import arviz as az
import numpy as np
import pyjags
import xarray as xr
xr.set_options(display_style="html");

This notebook illustrates the modeling of the eight schools dataset which is a classical example in Bayesian hierarchical analysis. For each school, the given data are the estimated treatment effect and the standard error of the treatment effect of an experiment conducted in the late 1970s concerning SAT-V scores.

# Construct Data Dictionary

In [2]:
# eight_schools_data = {
#     "J": 8, 
#     "y": np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
#     "sigma": np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
# }

eight_schools_data = {
    "J": 8, 
    "y": np.array([28.39, 7.94, -2.75 , 6.82, -0.64, 0.63, 18.01, 12.16]),
    "sigma": np.array([14.9, 10.2, 16.3, 11.0, 9.4, 11.4, 10.4, 17.6]),
}

# JAGS Model Code

## Prior Model

In [3]:
eight_school_prior_model_code = ''' 
model {
    mu ~ dnorm(0.0, 1.0/25)
    tau ~ dt(0.0, 1.0/25, 1.0) T(0, )
    for (j in 1:J) {
        theta_tilde[j] ~ dnorm(0.0, 1.0)
    }
}
'''

## Posterior Model

In [4]:
eight_school_posterior_model_code = ''' 
model {
    mu ~ dnorm(0.0, 1.0/25)
    tau ~ dt(0.0, 1.0/25, 1.0) T(0, )
    for (j in 1:J) {
        theta_tilde[j] ~ dnorm(0.0, 1.0)
        y[j] ~ dnorm(mu + tau * theta_tilde[j], 1.0/(sigma[j]^2))
        log_like[j] = logdensity.norm(y[j], mu + tau * theta_tilde[j], 1.0/(sigma[j]^2))
    }
}
'''

## Define Parameters

In [5]:
parameters = ['mu', 'tau', 'theta_tilde']
variables = parameters + ['log_like']

# Construct JAGS Model

## Prior Model

In [6]:
jags_prior_model = pyjags.Model(
    code=eight_school_prior_model_code, 
    data={"J": 8}, 
    chains=4, 
    threads=4,
    chains_per_thread=1
)

## Posterior Model

In [7]:
jags_posterior_model = pyjags.Model(
    code=eight_school_posterior_model_code, 
    data=eight_schools_data, 
    chains=4, 
    threads=4,
    chains_per_thread=1
)

adapting: iterations 4000 of 4000, elapsed 0:00:00, remaining 0:00:00


# Draw Samples

Draw burn-in sample of 1000 iterations

In [8]:
jags_prior_model.sample(1000, vars=[])
jags_posterior_model.sample(1000, vars=[])

sampling: iterations 4000 of 4000, elapsed 0:00:00, remaining 0:00:00
sampling: iterations 4000 of 4000, elapsed 0:00:00, remaining 0:00:00


{}

Draw actual sample of 5000 iterations

In [9]:
jags_prior_samples = jags_prior_model.sample(5000, vars=parameters)
jags_posterior_samples = jags_posterior_model.sample(5000, vars=variables)

sampling: iterations 20000 of 20000, elapsed 0:00:00, remaining 0:00:00
sampling: iterations 20000 of 20000, elapsed 0:00:00, remaining 0:00:00


Convert dictionary of samples to Arviz object

In [10]:
pyjags_data = az.from_pyjags(
    posterior=jags_posterior_samples, 
    prior=jags_prior_samples, 
    log_likelihood={'y': 'log_like'}, 
    save_warmup=True, 
    warmup_iterations=1000
)
pyjags_data

# Compute Diagnostics

## Compute Gelman-Rubin Statistic

In [11]:
az.rhat(pyjags_data)

A Gelman-Rubin substantially different from unity indicates non-convergence of the Markov chains.

## Generate Autocorrelation Plot

In [12]:
az.plot_autocorr(pyjags_data, combined=True);

## Compute effective sample size

In [13]:
az.ess(pyjags_data)

## Generate Trace Plot

In [14]:
az.plot_trace(pyjags_data);

## Compute Summary Statistics

In [15]:
az.summary(pyjags_data)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
mu,4.397,3.342,-1.987,10.575,0.029,0.021,13079.0,14122.0,1.0
tau,3.638,3.333,0.0,9.468,0.058,0.044,4567.0,4075.0,1.0
theta_tilde[0],0.331,0.991,-1.534,2.19,0.009,0.006,12439.0,14459.0,1.0
theta_tilde[1],0.101,0.94,-1.61,1.929,0.008,0.006,14380.0,14946.0,1.0
theta_tilde[2],-0.086,0.966,-1.879,1.756,0.008,0.006,14855.0,15612.0,1.0
theta_tilde[3],0.052,0.938,-1.665,1.857,0.008,0.006,15223.0,15001.0,1.0
theta_tilde[4],-0.148,0.933,-1.954,1.55,0.008,0.006,14483.0,14508.0,1.0
theta_tilde[5],-0.08,0.953,-1.89,1.708,0.008,0.006,14876.0,15198.0,1.0
theta_tilde[6],0.337,0.962,-1.505,2.147,0.008,0.006,13037.0,14890.0,1.0
theta_tilde[7],0.09,0.972,-1.752,1.911,0.008,0.005,15686.0,15463.0,1.0


## Generate Posterior Plot

In [16]:
az.plot_posterior(pyjags_data);

Compute WAIC

In [17]:
az.waic(pyjags_data);

# Do Everything in PyMC3 for Comparison

## Import Packages and Set Simulation Parameters

In [18]:
import pymc3 as pm

draws = 5000
chains = 4

# eight_schools_data = {
#     "J": 8,
#     "y": np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
#     "sigma": np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
# }

## Construct PyMC3 Model and Generate Posterior Samples

In [19]:
with pm.Model() as model:
    mu = pm.Normal("mu", mu=0, sd=5)
    tau = pm.HalfCauchy("tau", beta=5)
    theta_tilde = pm.Normal("theta_tilde", mu=0, sd=1, shape=eight_schools_data["J"])
    theta = pm.Deterministic("theta", mu + tau * theta_tilde)
    pm.Normal(
        "obs", mu=theta, sd=eight_schools_data["sigma"], observed=eight_schools_data["y"]
    )

    trace = pm.sample(draws, chains=chains)
    prior = pm.sample_prior_predictive()
#     posterior_predictive = pm.sample_posterior_predictive(trace)

    pymc3_data = az.from_pymc3(
        trace=trace,
        prior=prior,
#         posterior_predictive=posterior_predictive,
        coords={"school": np.arange(eight_schools_data["J"])},
        dims={"theta": ["school"], "theta_tilde": ["school"]},
    )
pymc3_data

  trace = pm.sample(draws, chains=chains)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [theta_tilde, tau, mu]


Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 7 seconds.
There were 4 divergences after tuning. Increase `target_accept` or reparameterize.
There were 8 divergences after tuning. Increase `target_accept` or reparameterize.
There were 43 divergences after tuning. Increase `target_accept` or reparameterize.
There were 36 divergences after tuning. Increase `target_accept` or reparameterize.


## Generate Trace Plot

In [20]:
az.plot_trace(pymc3_data);

## Compute Summary Statistics

In [21]:
az.summary(pymc3_data)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
mu,4.415,3.337,-1.928,10.616,0.025,0.018,18070.0,11139.0,1.0
theta_tilde[0],0.332,0.985,-1.465,2.215,0.007,0.007,20657.0,13552.0,1.0
theta_tilde[1],0.095,0.943,-1.733,1.838,0.006,0.007,23206.0,13774.0,1.0
theta_tilde[2],-0.073,0.961,-1.824,1.775,0.006,0.007,22894.0,14402.0,1.0
theta_tilde[3],0.058,0.926,-1.649,1.82,0.006,0.008,22736.0,13074.0,1.0
theta_tilde[4],-0.149,0.93,-1.914,1.586,0.006,0.007,21147.0,13156.0,1.0
theta_tilde[5],-0.079,0.949,-1.828,1.749,0.006,0.007,22698.0,14636.0,1.0
theta_tilde[6],0.342,0.948,-1.419,2.142,0.007,0.007,19376.0,12937.0,1.0
theta_tilde[7],0.086,0.975,-1.75,1.909,0.007,0.008,21701.0,12909.0,1.0
tau,3.62,3.164,0.001,9.234,0.03,0.023,10238.0,8313.0,1.0


In [22]:
az.summary(pyjags_data)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
mu,4.397,3.342,-1.987,10.575,0.029,0.021,13079.0,14122.0,1.0
tau,3.638,3.333,0.0,9.468,0.058,0.044,4567.0,4075.0,1.0
theta_tilde[0],0.331,0.991,-1.534,2.19,0.009,0.006,12439.0,14459.0,1.0
theta_tilde[1],0.101,0.94,-1.61,1.929,0.008,0.006,14380.0,14946.0,1.0
theta_tilde[2],-0.086,0.966,-1.879,1.756,0.008,0.006,14855.0,15612.0,1.0
theta_tilde[3],0.052,0.938,-1.665,1.857,0.008,0.006,15223.0,15001.0,1.0
theta_tilde[4],-0.148,0.933,-1.954,1.55,0.008,0.006,14483.0,14508.0,1.0
theta_tilde[5],-0.08,0.953,-1.89,1.708,0.008,0.006,14876.0,15198.0,1.0
theta_tilde[6],0.337,0.962,-1.505,2.147,0.008,0.006,13037.0,14890.0,1.0
theta_tilde[7],0.09,0.972,-1.752,1.911,0.008,0.005,15686.0,15463.0,1.0
