In [1]:
%env XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda/
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import numpy as np
import jax
from jax import random
import jax.numpy as jnp
import numpyro as npr
import numpyro.distributions as dists
import matplotlib
import matplotlib.cm as cm
import matplotlib.dates as mdates
import tqdm
from math import *
import seaborn as sns
# Apply the seaborn base style first so the rcParams overrides below take precedence.
sns.set_style('white')

# Figure defaults: serif/Times fonts rendered through LaTeX.
matplotlib.rcParams.update({
    'figure.figsize': (8, 5),
    'font.size': 10,
    'font.family': "serif",
    'font.serif': 'Times',
    'text.usetex': True,
    'lines.linewidth': 1,
})

# Import pyplot explicitly: `import matplotlib` alone does not load the
# `pyplot` submodule, so `matplotlib.pyplot` attribute access only works when
# some other package happens to have imported it already.
import matplotlib.pyplot as plt

# Run NumPyro/JAX computations on the GPU backend.
npr.set_platform('gpu')

env: XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda/
env: XLA_PYTHON_CLIENT_PREALLOCATE=false
env: XLA_PYTHON_CLIENT_ALLOCATOR=platform




### Model

In [2]:
# Hyperparams
sigma = 0.09     # scale of the Gaussian random walk used in `model`
nu = 12          # degrees of freedom of the Student-t distribution in `model`
NUM_STEPS = 100  # series length used by `model` when no observations are given

# Root PRNG key for the notebook (`random` is `jax.random`).
rng_key = random.PRNGKey(9999)


def model(y, theta=None, rng_key=random.PRNGKey(1)):
    """Stochastic-volatility model.

    Follows the example at
    http://num.pyro.ai/en/0.6.0/examples/stochastic_volatility.html

    Args:
        y: Observed values for the ``'y'`` site, or ``None`` to draw them.
        theta: Optional latent value. When given it is used directly and the
            ``'theta'`` sample site is skipped.
        rng_key: JAX PRNG key from which the per-site keys are derived.

    Returns:
        The value of the ``'y'`` site when ``y`` is ``None``; otherwise the
        latent ``log_vol`` (either ``theta`` or the ``'theta'`` site value).
    """
    # Derive one key per sample site; the split order is kept fixed so that
    # results are reproducible for a given `rng_key`.
    carry_key, theta_key = random.split(rng_key)
    _, obs_key = random.split(carry_key)

    num_steps = NUM_STEPS if y is None else len(y)

    if theta is not None:
        log_vol = theta
    else:
        walk_prior = dists.GaussianRandomWalk(scale=sigma, num_steps=num_steps)
        log_vol = npr.sample('theta', walk_prior, rng_key=theta_key)

    likelihood = dists.StudentT(df=nu, loc=0., scale=jnp.exp(log_vol))
    returns = npr.sample('y', likelihood, rng_key=obs_key, obs=y)

    # Without observations hand back the drawn y; given y, hand back the latent.
    return returns if y is None else log_vol

### Approx. posteriors

In [3]:
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import *


def _fit_autoguide(guide_cls, rng_key, y, *, lr, n_iter, num_particles, pbar):
    """Fit an autoguide of type ``guide_cls`` to ``model`` with SVI.

    Shared implementation behind ``train_vb_diag``, ``train_vb_full`` and
    ``train_laplace``, which differ only in guide class and settings.

    Returns:
        ``(guide, params)``: the guide instance and the optimised parameters
        taken from the SVI run result.
    """
    guide = guide_cls(model)
    optimizer = npr.optim.ClippedAdam(step_size=lr)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO(num_particles=num_particles))
    svi_result = svi.run(rng_key, n_iter, y=y, progress_bar=pbar)
    return guide, svi_result.params


def train_vb_diag(rng_key, y, pbar=True, lr=1e-3, n_iter=5000, num_particles=100):
    """Fit an ``AutoDiagonalNormal`` guide to ``model`` given ``y``.

    Args:
        rng_key: JAX PRNG key passed to the SVI run.
        y: Observations forwarded to the model and guide as ``y=y``.
        pbar: Whether to show the SVI progress bar.
        lr: ClippedAdam step size.
        n_iter: Number of SVI steps.
        num_particles: Number of particles for ``Trace_ELBO``.

    Returns:
        ``(guide, params)``.
    """
    return _fit_autoguide(
        AutoDiagonalNormal, rng_key, y,
        lr=lr, n_iter=n_iter, num_particles=num_particles, pbar=pbar,
    )


def train_vb_full(rng_key, y, pbar=True, lr=5e-4, n_iter=10000, num_particles=100):
    """Fit an ``AutoMultivariateNormal`` guide to ``model`` given ``y``.

    The defaults use a smaller step size because training is unstable with a
    large one, and a larger number of iterations to compensate.

    Args:
        rng_key: JAX PRNG key passed to the SVI run.
        y: Observations forwarded to the model and guide as ``y=y``.
        pbar: Whether to show the SVI progress bar.
        lr: ClippedAdam step size.
        n_iter: Number of SVI steps.
        num_particles: Number of particles for ``Trace_ELBO``.

    Returns:
        ``(guide, params)``.
    """
    return _fit_autoguide(
        AutoMultivariateNormal, rng_key, y,
        lr=lr, n_iter=n_iter, num_particles=num_particles, pbar=pbar,
    )


def train_laplace(rng_key, y, pbar=True, lr=1e-3, n_iter=5000, num_particles=1):
    """Fit an ``AutoLaplaceApproximation`` guide to ``model`` given ``y``.

    Args:
        rng_key: JAX PRNG key passed to the SVI run.
        y: Observations forwarded to the model and guide as ``y=y``.
        pbar: Whether to show the SVI progress bar.
        lr: ClippedAdam step size.
        n_iter: Number of SVI steps.
        num_particles: Number of particles for ``Trace_ELBO``.

    Returns:
        ``(guide, params)``.
    """
    return _fit_autoguide(
        AutoLaplaceApproximation, rng_key, y,
        lr=lr, n_iter=n_iter, num_particles=num_particles, pbar=pbar,
    )

### Sampling $\pi_G$

In [4]:
from numpyro.infer import MCMC, NUTS


# Names of all approximate-posterior variants compared in this notebook.
POSTERIORS = ['vb_diag', 'vb_full', 'mcmc_long', 'mcmc_short']

# Training routine for each variational variant. The 'mcmc_*' variants have no
# entry here because sample_gibbs_prior handles them itself.
# The 'laplace' variant (train_laplace) is currently disabled.
POSTERIOR_FUNCS = dict(
    vb_diag=train_vb_diag,
    vb_full=train_vb_full,
)


def sample_gibbs_prior(rng_key, posterior, n_chains=5, T=1000):
    """Run several chains that alternate approximate-posterior and likelihood draws.

    Each chain repeats, for ``T`` steps: draw ``theta`` from the approximate
    posterior q(theta | y_t), then draw a new ``y`` from ``model`` given that
    ``theta``.

    Args:
        rng_key: JAX PRNG key; all randomness is derived from it by splitting.
        posterior: Name of the posterior approximation. Names containing
            ``'mcmc'`` use NUTS (20 warmup / 20 samples if the name contains
            ``'long'``, else 5 / 5) and keep the last draw. Any other name is
            looked up in ``POSTERIOR_FUNCS``.
        n_chains: Number of independent chains.
        T: Number of steps per chain.

    Returns:
        ``np.ndarray`` of theta draws with shape ``(n_chains, T, n_dim)``.

    Raises:
        KeyError: If ``posterior`` is neither an MCMC variant nor a key of
            ``POSTERIOR_FUNCS``. This is raised before any sampling starts.
    """
    # Resolve the loop-invariant configuration once, up front, so an unknown
    # posterior name fails immediately instead of inside the sampling loop.
    use_mcmc = 'mcmc' in posterior
    if use_mcmc:
        num_warmup = num_samples = 20 if 'long' in posterior else 5
    else:
        fit_posterior = POSTERIOR_FUNCS[posterior]

    theta_samples = [[] for _ in range(n_chains)]

    # One starting y per chain, drawn from the model and scaled by 100.
    y_ts = []
    for _ in range(n_chains):
        rng_key, subkey = random.split(rng_key)
        y_ts.append(100*model(y=None, rng_key=subkey))  # Overdispersed initialization

    for _ in tqdm.trange(T):
        for i in range(n_chains):
            # Four subkeys are always drawn and addressed by fixed index so the
            # random stream does not depend on which branch is taken.
            rng_key, *subkeys = random.split(rng_key, 5)

            # Get q(theta | y_t)
            if use_mcmc:
                mcmc = MCMC(NUTS(model), num_warmup=num_warmup, num_samples=num_samples, progress_bar=False)
                mcmc.run(subkeys[2], y_ts[i])
                theta_t_i = mcmc.get_samples()['theta'][-1]
            else:
                guide_t_i, params_t_i = fit_posterior(subkeys[0], y=y_ts[i], pbar=False)
                theta_t_i = guide_t_i.sample_posterior(subkeys[1], params_t_i)['theta']

            theta_samples[i].append(np.array(theta_t_i).copy())

            # Sample y_t
            y_ts[i] = model(y=None, theta=theta_t_i, rng_key=subkeys[3])

    # Shape: (n_chains, n_samples, n_dim)
    return np.array(theta_samples)

### Gather samples

In [5]:
# Run the sampler once per posterior variant and save each result to disk.
RESULTS_DIR = '../../results/convergence'
# np.save does not create missing directories; create the target up front so a
# long sampling run is not lost at the save step.
os.makedirs(RESULTS_DIR, exist_ok=True)

for posterior in POSTERIORS:
    # Fresh subkey per variant; the notebook-level key is advanced each time.
    rng_key, rng_subkey = random.split(rng_key)

    # sample_gibbs_prior already returns an ndarray, so no re-wrapping is needed.
    theta_samples = sample_gibbs_prior(rng_subkey, posterior, n_chains=5, T=500)

    np.save(f'{RESULTS_DIR}/multi_chains_{posterior}.npy', theta_samples)