In [21]:
import blackjax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
import os

import flax.linen as nn
from flax.core.frozen_dict import freeze, unfreeze

import jax
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import optax
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

import matplotlib.pyplot as plt
import seaborn as sns

from bijax.mcmc import MCMC

from functools import partial
import regdata as rd

init = 1

import blackjax

In [22]:
num_schools = 8  # number of schools
treatment_effects = jnp.array(
    [28, 8, -3, 7, -1, 1, 18, 12], dtype=jnp.float32)  # treatment effects
treatment_stddevs = jnp.array(
    [15, 10, 16, 11, 9, 11, 10, 18], dtype=jnp.float32)  # treatment SE

In [23]:
prior = {"mu": tfd.Normal(0, 10), 
         "tou": tfd.LogNormal(5,1), 
         "theta": lambda mu, tou: tfd.MultivariateNormalDiag(mu.repeat(8), tou.repeat(8))}

def log_likelihood_fn(latent_sample, data, inputs, **kwargs):
    return tfd.Normal(latent_sample["theta"], inputs["treatment_stddevs"]).log_prob(data).sum()

In [24]:
model = MCMC(prior, log_likelihood_fn)

In [25]:
params = {
    'mu' : jnp.array(10.0),
    'tou': jnp.array(15.0),
    'theta': jnp.ones(treatment_effects.shape[0])*8
}

In [26]:
import arviz as az
import jax.numpy as jnp
import jax

def arviz_trace_from_states(states, info, burn_in=0):
    """
    args:
    ...........
    states: contains samples returned by blackjax model (i.e HMCState)
    info: conatins the meta info returned by blackjax model (i.e HMCinfo)
    
    returns:
    ...........
    trace: arviz trace object
    """
    if isinstance(states.position, jnp.DeviceArray):  #if states.position is array of samples 
        samples = {"samples":jnp.swapaxes(states.position,0,1)}
        divergence = jnp.swapaxes(info.is_divergent, 0, 1)
     
    else: # if states.position is dict 
        samples = {}        
        for param in states.position.keys():
            ndims = len(states.position[param].shape)
            if ndims >= 2:
                samples[param] = jnp.swapaxes(states.position[param], 0, 1)[:, burn_in:]  # swap n_samples and n_chains
                divergence = jnp.swapaxes(info.is_divergent[burn_in:], 0, 1)

            if ndims == 1:
                divergence = info.is_divergent
                samples[param] = states.position[param]
                
    trace_posterior = az.convert_to_inference_data(samples)
    trace_sample_stats = az.convert_to_inference_data({"diverging": divergence}, group="sample_stats")
    trace = az.concat(trace_posterior, trace_sample_stats)
    return trace

def inference_loop_multiple_chains(rng_key, kernel, initial_states, num_samples, num_chains):
    '''
    returns dict: {"states": states, "info": info}
    Visit this page for more info: https://blackjax-devs.github.io/blackjax/examples/Introduction.html
    '''
    def one_step(states, rng_key):
        keys = jax.random.split(rng_key, num_chains)
        states, infos = jax.vmap(kernel)(keys, states) 
        return states, (states, infos)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_states, keys)

    return (states, infos)

def inference_loop(rng_key, kernel, initial_state, num_samples):
    '''
    returns (states, info)
    Visit this page for more info: https://blackjax-devs.github.io/blackjax/examples/Introduction.html
    '''
    @jax.jit
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return (states, infos)

In [27]:
log_post_fn = partial(model.log_joint, inputs = {'treatment_stddevs': treatment_stddevs}, batch=treatment_effects)

import blackjax
rng_key1 = jax.random.PRNGKey(0)
rng_key2,rng_key3 = jax.random.split(rng_key1)

num_chains = 3

nuts = blackjax.nuts(log_post_fn,step_size= 0.09,inverse_mass_matrix = jnp.ones(10)*1)

params_re = {
    'mu' : jax.random.uniform(key=rng_key1,shape=(num_chains,)),
    'tou': jax.random.uniform(key=rng_key2,shape=(num_chains,)),
    'theta': jax.random.uniform(key=rng_key3,shape=(num_chains,num_schools,)),
}

initial_states = jax.vmap(nuts.init)(params_re)
states_cent,info_cent = inference_loop_multiple_chains(rng_key1,kernel= nuts.step,initial_states= initial_states, num_samples=1500, num_chains=num_chains)

In [28]:
trace_centered = arviz_trace_from_states(states=states_cent,info=info_cent)
smry = az.summary(trace_centered)
smry

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
mu,5.791,5.636,-4.318,16.964,0.1,0.076,3161.0,2996.0,1.0
theta[0],14.975,10.494,-4.447,34.381,0.229,0.166,2135.0,2667.0,1.0
theta[1],7.149,8.216,-8.468,22.419,0.128,0.109,4110.0,3317.0,1.0
theta[2],2.54,10.729,-18.022,22.555,0.214,0.181,2540.0,2186.0,1.0
theta[3],6.562,8.293,-9.924,21.543,0.133,0.11,3855.0,2804.0,1.0
theta[4],1.891,7.341,-11.477,15.606,0.111,0.111,4324.0,3251.0,1.0
theta[5],3.403,8.685,-12.627,20.166,0.145,0.123,3640.0,3192.0,1.0
theta[6],12.97,8.17,-1.345,29.492,0.154,0.115,2863.0,2763.0,1.0
theta[7],8.251,11.029,-12.731,29.199,0.217,0.184,2618.0,2296.0,1.0
tou,13.17,6.765,2.929,25.474,0.197,0.139,1134.0,1685.0,1.0
