In [1]:
%env XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda/
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import numpy as np
import jax
from jax import random
import jax.numpy as jnp
import numpyro as npr
import matplotlib
import matplotlib.cm as cm
import matplotlib.dates as mdates
import tqdm
from math import *
import numpyro.distributions as dist
import seaborn as sns
sns.set_style('white')

# Global figure defaults for every plot in the notebook.
# NOTE: text.usetex=True renders all text through LaTeX, so a LaTeX install is required.
matplotlib.rcParams.update({
    'figure.figsize': (8, 5),
    'font.size': 10,
    'font.family': 'serif',
    'font.serif': 'Times',
    'text.usetex': True,
    'lines.linewidth': 1,
})

# Explicit submodule import: `import matplotlib` alone does not load pyplot, so
# `matplotlib.pyplot` only resolved before because %matplotlib / seaborn imported it.
import matplotlib.pyplot as plt

# Run NumPyro/JAX on the GPU (per NumPyro docs, only effective before JAX starts computing).
npr.set_platform('gpu')

env: XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda/
env: XLA_PYTHON_CLIENT_PREALLOCATE=false
env: XLA_PYTHON_CLIENT_ALLOCATOR=platform


### Model: stochastic volatility (Gaussian random-walk log-volatility, Student-t returns)

In [2]:
# Model hyperparameters used by `model`:
#   sigma     - step scale of the Gaussian random walk on the log-volatility
#   nu        - degrees of freedom of the Student-t distribution of the returns
#   NUM_STEPS - series length used when no observed `y` is supplied
sigma, nu, NUM_STEPS = 0.09, 12, 100


def model(y, theta=None, rng_key=None):
    """Stochastic-volatility model: Gaussian random-walk log-volatility, Student-t returns.

    Adapted from the NumPyro stochastic-volatility example (a port of the PyMC3 one):
    http://num.pyro.ai/en/0.6.0/examples/stochastic_volatility.html

    Modes of use:
      * ``model(y)`` under NumPyro inference (MCMC / SVI): latent ``'theta'`` is
        inferred with ``y`` observed.
      * ``model(None, theta=path, rng_key=key)``: simulate returns given a latent path.
      * ``model(None, rng_key=key)``: draw a latent log-volatility path from the prior.

    Args:
        y: observed returns or None. Its length sets the number of steps; when None,
            ``NUM_STEPS`` steps are used.
        theta: optional fixed log-volatility path. When given, no ``'theta'`` site is sampled.
        rng_key: optional JAX PRNG key. When None (default), each sample site leaves its key
            unset so the enclosing NumPyro ``seed`` handler supplies it, as MCMC/SVI do.
            Pass a key explicitly when calling the model outside any NumPyro handler.

    Returns:
        The log-volatility path when both ``y`` and ``theta`` are None, otherwise the
        value of the ``'y'`` site (the observed ``y``, or returns simulated given the path).
    """
    if rng_key is None:
        # Do not pin site keys: an explicit rng_key takes precedence over the seed handler,
        # so a fixed default key would make every MCMC chain / SVI fit see identical site
        # keys no matter which key was passed to mcmc.run / svi.run.
        theta_key = y_key = None
    else:
        # Same split sequence as before: first subkey for 'theta', second for 'y'.
        rng_key, theta_key = random.split(rng_key)
        rng_key, y_key = random.split(rng_key)

    num_steps = len(y) if y is not None else NUM_STEPS

    if theta is None:
        log_vol = npr.sample(
            'theta', dist.GaussianRandomWalk(scale=sigma, num_steps=num_steps), rng_key=theta_key
        )
    else:
        log_vol = theta

    returns = npr.sample('y', dist.StudentT(df=nu, loc=0., scale=jnp.exp(log_vol)),
                         rng_key=y_key, obs=y)

    if theta is None and y is None:
        return log_vol  # prior draw of the latent path
    return returns      # observed y, or y simulated given the latent path

#### Posterior approximations: variational Bayes (diagonal / full-rank Gaussian) and short-run NUTS

In [3]:
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import *
from jax import jit


import math

# Calibration budget.
M = 10_000            # target total number of posterior draws, i.e. N * R
R = 31                # posterior draws per simulated data set
N = math.ceil(M / R)  # number of simulated data sets; smallest N with N * R >= M

print(f'M = {M}, R = {R}, N = {N}, N*R = {N*R}')


def _svi_posterior_draws(rng_key, y, guide_cls, lr, n_iter):
    """Fit an autoguide to p(theta | y) with SVI and return R draws of theta from it.

    Shared body of ``vb_diag`` / ``vb_full``; traced inside their ``jit``.

    Args:
        rng_key: JAX PRNG key; split into one key for fitting and one for drawing.
        y: observed returns passed to ``model``.
        guide_cls: NumPyro autoguide class, e.g. ``AutoDiagonalNormal``.
        lr: ClippedAdam step size.
        n_iter: number of SVI steps.

    Returns:
        The ``'theta'`` draws from ``guide.sample_posterior`` with ``sample_shape=(R,)``.
    """
    _, fit_key, draw_key = random.split(rng_key, 3)

    guide = guide_cls(model)
    optimizer = npr.optim.ClippedAdam(step_size=lr)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO(num_particles=100))
    # No progress bar: this runs inside a jit-compiled function.
    svi_result = svi.run(fit_key, n_iter, y=y, progress_bar=False)

    return guide.sample_posterior(draw_key, svi_result.params, sample_shape=(R,))['theta']


@jit
def vb_diag(rng_key, y):
    """Mean-field VB (diagonal Gaussian guide): R draws of theta ~ q(theta | y)."""
    return _svi_posterior_draws(rng_key, y, AutoDiagonalNormal, lr=1e-3, n_iter=5000)


@jit
def vb_full(rng_key, y):
    """Full-rank VB (multivariate Gaussian guide): R draws of theta ~ q(theta | y)."""
    # Unstable with a large lr, so use a smaller step and compensate with more iterations.
    return _svi_posterior_draws(rng_key, y, AutoMultivariateNormal, lr=5e-4, n_iter=10000)

def _mcmc_last_draws(rng_key, y, num_warmup, num_samples):
    """Run R vectorized NUTS chains on p(theta | y) and return the last draw of each chain.

    Shared body of ``mcmc_short`` / ``mcmc_long``; traced inside their ``jit``.

    Args:
        rng_key: JAX PRNG key passed to ``MCMC.run``.
        y: observed returns passed to ``model``.
        num_warmup: NUTS warmup iterations per chain.
        num_samples: NUTS sampling iterations per chain (only the last one is kept).

    Returns:
        Array indexed [chain, step]: the final ``'theta'`` draw of each of the R chains.
    """
    mcmc = MCMC(NUTS(model), num_warmup=num_warmup, num_samples=num_samples,
                progress_bar=False, num_chains=R, chain_method='vectorized')
    mcmc.run(rng_key, y)
    # group_by_chain=True -> leading axes (chain, draw); keep the last draw of every chain.
    return mcmc.get_samples(group_by_chain=True)['theta'][:, -1, :]


@jit
def mcmc_short(rng_key, y):
    """Deliberately short NUTS (5 warmup + 5 samples): last draw from each of R chains."""
    return _mcmc_last_draws(rng_key, y, num_warmup=5, num_samples=5)


@jit
def mcmc_long(rng_key, y):
    """Longer NUTS (20 warmup + 20 samples): last draw from each of R chains."""
    return _mcmc_last_draws(rng_key, y, num_warmup=20, num_samples=20)

M = 10000.0, R = 31, N = 323, N*R = 10013


In [4]:
# Posterior approximations to calibrate, keyed by the name used in result filenames.
# Dict order is the run order; POSTERIORS is derived from it so the two cannot drift apart.
POSTERIOR_FUNCS = {
    'vb_diag': vb_diag,
    'vb_full': vb_full,
    'mcmc_long': mcmc_long,
    'mcmc_short': mcmc_short,
}
POSTERIORS = list(POSTERIOR_FUNCS)

### Baseline: simulation-based calibration (Talts et al., 2018)

In [5]:
def run_baseline(posterior, rng_key=random.PRNGKey(1234)):
    """Simulation-based-calibration run for a single posterior approximation.

    Each of the N iterations draws a log-volatility path from the prior, simulates
    returns from it with ``model``, and draws R samples from the approximation
    ``POSTERIOR_FUNCS[posterior]`` given those returns.

    Args:
        posterior: key into ``POSTERIOR_FUNCS`` selecting the approximation.
        rng_key: base PRNG key. Prior draws and simulated data depend only on this key,
            so runs with the same key share them across approximations.

    Returns:
        np.ndarray of shape (N, R + 1, NUM_STEPS): ``[:, 0]`` holds the prior draw and
        ``[:, 1:]`` the R approximate-posterior draws for the same simulation.
    """
    results = np.zeros([N, R + 1, NUM_STEPS])
    prior = dist.GaussianRandomWalk(scale=sigma, num_steps=NUM_STEPS)

    for sim in tqdm.trange(N):
        rng_key, prior_key, data_key, post_key = random.split(rng_key, 4)

        # theta ~ p(theta)
        theta_true = prior.rsample(prior_key)
        results[sim, 0] = theta_true

        # y ~ p(y | theta), then R draws theta ~ q(theta | y)
        y_sim = model(y=None, theta=theta_true, rng_key=data_key)
        theta_post = POSTERIOR_FUNCS[posterior](post_key, y=y_sim)
        results[sim, 1:] = np.array(theta_post).squeeze()

    return results

    
# Output directory for the calibration arrays. np.save does not create missing
# directories, so make sure it exists before the (long) runs start.
RESULTS_DIR = '../../results/baseline'
os.makedirs(RESULTS_DIR, exist_ok=True)

for posterior in POSTERIORS:
    results = run_baseline(posterior=posterior)  # (N, R + 1, NUM_STEPS)
    np.save(f'{RESULTS_DIR}/baseline_{posterior}.npy', results)

100%|██████████| 323/323 [01:51<00:00,  2.89it/s]
100%|██████████| 323/323 [11:21<00:00,  2.11s/it]
100%|██████████| 323/323 [07:33<00:00,  1.40s/it]
100%|██████████| 323/323 [00:23<00:00, 14.00it/s]
