In [1]:
import warnings
warnings.filterwarnings('ignore')

%env XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda/
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import numpy as np
import jax
from jax import random
import jax.numpy as jnp
import numpyro as npr
import matplotlib
import matplotlib.cm as cm
import matplotlib.dates as mdates
import tqdm
from math import *
import numpyro.distributions as dist
import seaborn as sns
# Plot styling: white seaborn theme with serif text rendered through LaTeX.
sns.set_style('white')

matplotlib.rcParams['figure.figsize'] = (8,5)
matplotlib.rcParams['font.size'] = 10
matplotlib.rcParams['font.family'] = "serif"
matplotlib.rcParams['font.serif'] = 'Times'
# NOTE: usetex=True needs a working LaTeX installation on the machine.
matplotlib.rcParams['text.usetex'] = True
matplotlib.rcParams['lines.linewidth'] = 1

# Import pyplot explicitly. `matplotlib.pyplot` is a submodule, so reading it
# as an attribute of `matplotlib` only works if something else has already
# imported it (presumably seaborn did here); do not depend on that.
import matplotlib.pyplot as plt

# Select the GPU backend for NumPyro.
npr.set_platform('gpu')

env: XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda/
env: XLA_PYTHON_CLIENT_PREALLOCATE=false
env: XLA_PYTHON_CLIENT_ALLOCATOR=platform


### Model

In [2]:
# Number of LogNormal terms that are summed to form one observation Y
# (used as the plate size in model_true and in the approximation in model_abc).
L = 10

def model_true(y=None, theta=None, rng_key=random.PRNGKey(1)):
    """Generative model whose observation is a sum of L LogNormal draws.

    Args:
        y: Optional observed value for the 'Y' site.
        theta: Optional pair ``(mu, sigma_sq)``. When ``None`` both are drawn
            from their priors, Normal(0, 1) and Gamma(1, 1).
        rng_key: PRNG key, split internally into one key per sample site.

    Returns:
        ``(mu, sigma_sq)`` when ``theta`` is ``None`` and ``y`` is given;
        otherwise the value of the 'Y' site.
    """
    # First row of the split is discarded; the rest feed the four sites.
    _, mu_key, sigma_key, x_key, y_key = random.split(rng_key, 5)

    draw_from_prior = theta is None
    if draw_from_prior:
        # Priors \pi(\mu, \sigma^2)
        mu = npr.sample('mu', dist.Normal(0, 1), rng_key=mu_key)
        sigma_sq = npr.sample('sigma_sq', dist.Gamma(1, 1), rng_key=sigma_key)
    else:
        mu, sigma_sq = theta

    # True likelihood: L LogNormal terms, summed along the plate axis.
    component_dist = dist.LogNormal(mu, jnp.sqrt(sigma_sq))
    with npr.plate('L', L):
        x = npr.sample('X', component_dist, rng_key=x_key)

    total = npr.sample('Y', dist.Delta(x.sum(0)), rng_key=y_key, obs=y)

    if draw_from_prior and y is not None:
        return (mu, sigma_sq)
    return total
    
    
def model_abc(y=None, theta=None, rng_key=random.PRNGKey(1)):
    """Surrogate model: a single LogNormal standing in for the sum of L terms.

    The surrogate LogNormal(alpha, sqrt(beta_sq)) is parameterised so that its
    mean and variance equal those of a sum of L i.i.d.
    LogNormal(mu, sqrt(sigma_sq)) variables.

    Args:
        y: Optional observed value for the 'Y' site.
        theta: Optional pair ``(mu, sigma_sq)``. When ``None`` both are drawn
            from their priors, Normal(0, 1) and Gamma(1, 1).
        rng_key: PRNG key, split internally into one key per sample site.

    Returns:
        ``(mu, sigma_sq)`` when ``theta`` is ``None`` and ``y`` is given;
        otherwise the value of the 'Y' site.
    """
    # First row of the split is discarded; the rest feed the three sites.
    _, mu_key, sigma_key, y_key = random.split(rng_key, 4)

    draw_from_prior = theta is None
    if draw_from_prior:
        # Priors \pi(\mu, \sigma^2)
        mu = npr.sample('mu', dist.Normal(0, 1), rng_key=mu_key)
        sigma_sq = npr.sample('sigma_sq', dist.Gamma(1, 1), rng_key=sigma_key)
    else:
        mu, sigma_sq = theta

    # Log-scale variance and location of the approximating LogNormal.
    beta_sq = jnp.log((jnp.exp(sigma_sq)-1)/L + 1)
    alpha = mu + jnp.log(L) + 0.5*(sigma_sq - beta_sq)
    approx_likelihood = dist.LogNormal(alpha, jnp.sqrt(beta_sq))

    out = npr.sample('Y', approx_likelihood, rng_key=y_key, obs=y)

    if draw_from_prior and y is not None:
        return (mu, sigma_sq)
    return out

### Approx. posteriors

In [3]:
from numpyro.infer import SVI, Trace_ELBO, MCMC, NUTS
from numpyro.infer.autoguide import *


def laplace(rng_key, model, y, pbar=False):
    """Fit a Laplace approximation to ``model`` given ``y`` and draw one sample.

    The guide is an ``AutoLaplaceApproximation`` optimised with SVI
    (ClippedAdam, step size 1e-3, 5000 iterations, single-particle ELBO).

    Args:
        rng_key: PRNG key; split into one key for optimisation and one for
            the posterior draw so that no key is used twice.
        model: NumPyro model callable accepting ``y`` as a keyword argument
            and exposing sites named 'mu' and 'sigma_sq'.
        y: Observed value passed to the model.
        pbar: Whether ``svi.run`` shows a progress bar.

    Returns:
        Tuple ``(mu, sigma_sq)`` taken from a single joint draw of the fitted
        guide.
    """
    # Separate keys for fitting and sampling: the parent key must not be
    # consumed again once it has been split.
    fit_key, draw_key = random.split(rng_key)

    guide = AutoLaplaceApproximation(model)
    lr = 1e-3
    n_iter = 5000

    optimizer = npr.optim.ClippedAdam(step_size=lr)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO(num_particles=1))
    svi_result = svi.run(fit_key, n_iter, y=y, progress_bar=pbar)

    # One joint draw: taking mu and sigma_sq from two independent draws would
    # discard the dependence between them encoded in the approximation.
    posterior_draw = guide.sample_posterior(draw_key, svi_result.params)

    return posterior_draw['mu'], posterior_draw['sigma_sq']

### Sampling $\pi_G$

In [4]:
def sample_gibbs_prior(rng_key, n_chains=5, T=1000):
    """Run ``n_chains`` independent chains alternating theta and y updates.

    Each iteration draws theta from the Laplace approximation fitted to
    ``model_abc`` at the chain's current y, then simulates a new y from
    ``model_true`` at that theta.

    Args:
        rng_key: PRNG key; split repeatedly so every draw gets a fresh key.
        n_chains: Number of chains.
        T: Number of iterations per chain.

    Returns:
        ``np.ndarray`` of the theta draws, indexed as (chain, iteration,
        component) with components ordered (mu, sigma_sq) -- i.e. shape
        (n_chains, T, 2) assuming ``laplace`` returns scalar values.
    """
    theta_samples = [[] for _ in range(n_chains)]
    y_ts = []

    # Starting point of each chain: a prior-predictive draw scaled by 100.
    for _ in range(n_chains):
        rng_key, subkey = random.split(rng_key)
        y_init = model_true(y=None, theta=None, rng_key=subkey)
        y_ts.append(100*y_init)  # Overdispersed initialization

    for _ in tqdm.trange(T):
        for i in range(n_chains):
            rng_key, laplace_key, sim_key = random.split(rng_key, 3)

            # Get q(theta | y_t)
            theta_t_i = laplace(laplace_key, model_abc, y=y_ts[i])
            # np.array already materialises a fresh host-side array.
            theta_samples[i].append(np.array(theta_t_i))

            # Sample y_t
            y_ts[i] = model_true(y=None, theta=theta_t_i, rng_key=sim_key)

            # Redraw with a fresh key while the simulated sum is +inf.
            while y_ts[i] == jnp.inf:
                rng_key, subkey = random.split(rng_key)
                y_ts[i] = model_true(y=None, theta=theta_t_i, rng_key=subkey)

    # Shape: (n_chains, n_samples, n_dim)
    return np.array(theta_samples)

### Gather samples

In [6]:
# Run 5 chains of 500 iterations each from a fixed seed.
theta_samples = sample_gibbs_prior(random.PRNGKey(1234), n_chains=5, T=500)

# Make sure the output directory exists before saving, so a long run is not
# lost to a FileNotFoundError raised by np.save.
out_path = '../../results/convergence/multi_chains_laplace.npy'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
np.save(out_path, theta_samples)

(5, 1000, 2)
