In [1]:
%env XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda/
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import numpy as np
import jax
from jax import random
import jax.numpy as jnp
import numpyro as npr
import numpyro.distributions as dists
import matplotlib
import matplotlib.cm as cm
import matplotlib.dates as mdates
import tqdm
from math import *
import seaborn as sns
sns.set_style('white')

# Figure defaults: serif "Times" text typeset through LaTeX.
# NOTE: `text.usetex = True` needs a working LaTeX installation on this machine.
matplotlib.rcParams.update({
    'figure.figsize': (8, 5),
    'font.size': 10,
    'font.family': "serif",
    'font.serif': 'Times',
    'text.usetex': True,
    'lines.linewidth': 1,
})
# Import pyplot explicitly: `import matplotlib` alone does not load the
# `matplotlib.pyplot` submodule, so `matplotlib.pyplot` only resolved before
# because seaborn happened to import it.
import matplotlib.pyplot as plt

# Run NumPyro/JAX computations on the GPU backend.
npr.set_platform('gpu')

env: XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda/
env: XLA_PYTHON_CLIENT_PREALLOCATE=false
env: XLA_PYTHON_CLIENT_ALLOCATOR=platform


### Model

In [2]:
# Fixed hyperparameters of the model (not given priors here)
sigma = 0.09      # step scale of the Gaussian random walk on log-volatility
nu = 12           # degrees of freedom of the Student-t returns
NUM_STEPS = 100   # series length used when no `y` is supplied


def model(y, theta=None, rng_key=random.PRNGKey(1)):
    """Stochastic-volatility model following the NumPyro example
    http://num.pyro.ai/en/0.6.0/examples/stochastic_volatility.html

    The log-volatility ``theta`` is a Gaussian random walk with step scale
    ``sigma``. Each return ``y[t]`` is Student-t with ``nu`` degrees of freedom,
    location 0 and scale ``exp(theta[t])``.

    Args:
        y: observed returns, or None to simulate them. Its length sets the
            number of time steps (``NUM_STEPS`` when None).
        theta: log-volatility path to condition on; sampled when None.
        rng_key: JAX PRNG key passed explicitly to the ``'theta'`` and ``'y'``
            sample sites.

    Returns:
        The sampled log-volatility when both ``y`` and ``theta`` are None.
        Otherwise the value of the ``'y'`` site: the observed ``y`` if one was
        given, else a series simulated from ``theta``.
    """
    rng_key, theta_key = random.split(rng_key)
    n_steps = NUM_STEPS if y is None else len(y)

    log_vol = theta
    if theta is None:
        log_vol = npr.sample(
            'theta',
            dists.GaussianRandomWalk(scale=sigma, num_steps=n_steps),
            rng_key=theta_key,
        )

    _, y_key = random.split(rng_key)
    returns = npr.sample(
        'y',
        dists.StudentT(df=nu, loc=0., scale=jnp.exp(log_vol)),
        rng_key=y_key,
        obs=y,
    )

    # Unconditioned call -> hand back the latent; otherwise the 'y' site value.
    sample_latent = theta is None and y is None
    return log_vol if sample_latent else returns

### Approximate posteriors

In [3]:
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import *
from jax import jit


def _fit_autoguide(rng_key, y, guide_cls, lr, n_iter, num_particles=100):
    """Fit an autoguide to ``model`` with SVI and return one posterior draw of ``theta``.

    Args:
        rng_key: JAX PRNG key; split into a key for the SVI run and a key for
            the posterior draw.
        y: series passed to ``model`` as ``y``.
        guide_cls: NumPyro autoguide class, e.g. ``AutoDiagonalNormal``.
        lr: ClippedAdam step size.
        n_iter: number of SVI steps.
        num_particles: Monte Carlo particles per ELBO estimate.

    Returns:
        A single sample of the ``'theta'`` site from the fitted guide.
    """
    # Three-way split with the first key intentionally unused; changing the
    # split pattern would change every sampled value.
    _, fit_key, sample_key = random.split(rng_key, 3)

    guide = guide_cls(model)
    optimizer = npr.optim.ClippedAdam(step_size=lr)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO(num_particles=num_particles))
    # No progress bar: the public runners below are jitted and called once per Gibbs step.
    svi_result = svi.run(fit_key, n_iter, y=y, progress_bar=False)
    return guide.sample_posterior(sample_key, svi_result.params)['theta']


def _run_nuts(rng_key, y, num_warmup, num_samples):
    """Run NUTS on ``model`` given ``y`` and return the last ``theta`` draw."""
    mcmc = MCMC(NUTS(model), num_warmup=num_warmup, num_samples=num_samples,
                progress_bar=False)
    mcmc.run(rng_key, y)
    return mcmc.get_samples()['theta'][-1]


@jit
def vb_diag(rng_key, y):
    """Mean-field Gaussian VI (``AutoDiagonalNormal``): one draw of theta given y."""
    return _fit_autoguide(rng_key, y, AutoDiagonalNormal, lr=1e-3, n_iter=5000)


@jit
def vb_full(rng_key, y):
    """Full-covariance Gaussian VI (``AutoMultivariateNormal``): one draw of theta given y."""
    # Unstable with a larger lr, so use a smaller step and compensate with
    # more iterations.
    return _fit_autoguide(rng_key, y, AutoMultivariateNormal, lr=5e-4, n_iter=10000)


@jit
def mcmc_short(rng_key, y):
    """Very short NUTS run (5 warmup / 5 samples); returns the last theta draw."""
    return _run_nuts(rng_key, y, num_warmup=5, num_samples=5)


@jit
def mcmc_long(rng_key, y):
    """Longer NUTS run (20 warmup / 20 samples); returns the last theta draw."""
    return _run_nuts(rng_key, y, num_warmup=20, num_samples=20)

### Sampling $\pi_G$

In [4]:
# Approximate-posterior runners, in the order the driver loop processes them.
POSTERIORS = ['vb_diag', 'vb_full', 'mcmc_long', 'mcmc_short']

# Runner name -> function taking ``(rng_key, y)`` and returning one theta draw.
POSTERIOR_FUNCS = dict(
    vb_diag=vb_diag,
    vb_full=vb_full,
    mcmc_short=mcmc_short,
    mcmc_long=mcmc_long,
)


def sample_gibbs_prior(rng_key, posterior, n_chains=5, T=1000):
    """Run ``n_chains`` Gibbs-style chains alternating approximate inference and simulation.

    At every step, chain ``i`` does two things:
      1. Draws ``theta ~ q(theta | y_i)`` with the runner named ``posterior``.
      2. Replaces its own ``y_i`` with a fresh simulation from ``model`` given
         that ``theta``.

    Args:
        rng_key: JAX PRNG key driving initialization, inference and simulation.
        posterior: key of ``POSTERIOR_FUNCS`` selecting the runner.
        n_chains: number of chains.
        T: number of Gibbs steps per chain.

    Returns:
        np.ndarray of shape ``(n_chains, T, n_dim)`` holding every theta draw,
        where ``n_dim`` is the length of theta.

    Raises:
        KeyError: if ``posterior`` is not a key of ``POSTERIOR_FUNCS``.
    """
    if posterior not in POSTERIOR_FUNCS:
        raise KeyError(
            f"unknown posterior {posterior!r}; expected one of {sorted(POSTERIOR_FUNCS)}"
        )
    posterior_fn = POSTERIOR_FUNCS[posterior]

    # Overdispersed initialization (x100). NOTE(review): with y and theta both
    # None, `model` returns the latent log-volatility draw. So each chain starts
    # from 100x a prior draw of theta, not from a simulated y. Confirm intended.
    y_ts = []
    for _ in range(n_chains):
        rng_key, subkey = random.split(rng_key)
        y_ts.append(100 * model(y=None, rng_key=subkey))

    theta_samples = [[] for _ in range(n_chains)]
    for _ in tqdm.trange(T):
        for i in range(n_chains):
            rng_key, infer_key, sim_key = random.split(rng_key, 3)

            # q(theta | y_t): condition on THIS chain's current y. The previous
            # code read a stale init value, so the chains never advanced.
            theta_t_i = posterior_fn(infer_key, y=y_ts[i])
            theta_samples[i].append(np.array(theta_t_i))  # np.array already copies

            # y_{t+1} ~ p(y | theta_t)
            y_ts[i] = model(y=None, theta=theta_t_i, rng_key=sim_key)

    return np.array(theta_samples)

### Gather samples

In [5]:
RESULTS_DIR = '../../results/convergence'
# Create the output directory up front so a missing folder fails immediately,
# not after hours of sampling.
os.makedirs(RESULTS_DIR, exist_ok=True)

for posterior in POSTERIORS:
    # Same seed for every runner, so all of them start from identical initial states.
    theta_samples = sample_gibbs_prior(random.PRNGKey(9999), posterior, n_chains=5, T=500)
    np.save(os.path.join(RESULTS_DIR, f'multi_chains_{posterior}.npy'), theta_samples)

100%|██████████| 500/500 [14:35<00:00,  1.75s/it]
100%|██████████| 500/500 [1:27:27<00:00, 10.49s/it]
100%|██████████| 500/500 [11:53<00:00,  1.43s/it]
100%|██████████| 500/500 [00:55<00:00,  9.09it/s]
