In [1]:
%env XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda/
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import numpy as np
import jax
from jax import random
import jax.numpy as jnp
import numpyro as npr
import matplotlib
import matplotlib.cm as cm
import matplotlib.dates as mdates
import tqdm
from math import *
import numpyro.distributions as dist
import seaborn as sns
sns.set_style('white')

matplotlib.rcParams['figure.figsize'] = (8,5)
matplotlib.rcParams['font.size'] = 10
matplotlib.rcParams['font.family'] = "serif"
matplotlib.rcParams['font.serif'] = 'Times'
matplotlib.rcParams['text.usetex'] = True
matplotlib.rcParams['lines.linewidth'] = 1
plt = matplotlib.pyplot

npr.set_platform('gpu')

env: XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda/
env: XLA_PYTHON_CLIENT_PREALLOCATE=false
env: XLA_PYTHON_CLIENT_ALLOCATOR=platform


### Model

In [2]:
# Hyperparams
sigma = 0.09
nu = 12
NUM_STEPS = 100


def model(y, theta=None, rng_key=random.PRNGKey(1)):
    """ PyMC3 example http://num.pyro.ai/en/0.6.0/examples/stochastic_volatility.html """
    rng_key, rng_subkey = random.split(rng_key)

    num_steps = len(y) if y is not None else NUM_STEPS
    
    if theta is None:
        log_vol = npr.sample(
            'theta', dist.GaussianRandomWalk(scale=sigma, num_steps=num_steps), rng_key=rng_subkey
        )
    else:
        log_vol = theta
    
    rng_key, rng_subkey = random.split(rng_key)
    returns = npr.sample('y', dist.StudentT(df=nu, loc=0., scale=jnp.exp(log_vol)),
                         rng_key=rng_subkey, obs=y)
    
    if theta is None and y is None:  
        return log_vol  # Sample latent
    else:  
        return returns  # Given latent, sample y

### Methods

In [3]:
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import *
from jax import jit


@jit
def vb_diag(rng_key, y):    
    key, *subkeys = random.split(rng_key, 3)
    
    guide = AutoDiagonalNormal(model)
    lr = 1e-3
    n_iter = 5000

    optimizer = npr.optim.ClippedAdam(step_size=lr)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO(num_particles=100))
    svi_result = svi.run(subkeys[0], n_iter, y=y, progress_bar=False)
    
    theta = guide.sample_posterior(subkeys[1], svi_result.params)['theta']
    return theta

@jit
def vb_full(rng_key, y):    
    key, *subkeys = random.split(rng_key, 3)
    
    guide = AutoMultivariateNormal(model)
    lr = 5e-4  # Unstable with large lr
    n_iter = 10000  # Compensate with larger num. of iterations
    
    optimizer = npr.optim.ClippedAdam(step_size=lr)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO(num_particles=100))
    svi_result = svi.run(subkeys[0], n_iter, y=y, progress_bar=False)
    
    theta = guide.sample_posterior(subkeys[1], svi_result.params)['theta']
    return theta

@jit
def mcmc_short(rng_key, y):    
    mcmc = MCMC(NUTS(model), num_warmup=5, num_samples=5, progress_bar=False)
    mcmc.run(rng_key, y)

    theta = mcmc.get_samples()['theta'][-1]
    return theta


@jit
def mcmc_long(rng_key, y):    
    mcmc = MCMC(NUTS(model), num_warmup=20, num_samples=20, progress_bar=False)
    mcmc.run(rng_key, y)

    theta = mcmc.get_samples()['theta'][-1]
    return theta

### Sampling $\pi_G$

In [4]:
POSTERIORS = ['vb_diag', 'vb_full', 'mcmc_long', 'mcmc_short']
POSTERIOR_FUNCS = {
    'vb_diag': vb_diag, 
    'vb_full': vb_full, 
    'mcmc_short': mcmc_short,
    'mcmc_long': mcmc_long
}


def sample_gibbs_prior(rng_key, posterior, T=10000):
    assert posterior in POSTERIORS
    
    theta_samples = []
    
    rng_key, rng_subkey = random.split(rng_key)
    y_t  = model(y=None, rng_key=rng_subkey)
    
    pbar = tqdm.trange(T)
    for t in pbar:     
        rng_key, *subkeys = random.split(rng_key, 5)
        
        # Get q(theta | y_t)
        theta_t = POSTERIOR_FUNCS[posterior](subkeys[0], y=y_t)
        theta_samples.append(np.array(theta_t).copy())
        
        # Sample y_t
        y_t  = model(y=None, theta=theta_t, rng_key=subkeys[3])
        
        if t % 100 == 0:
            pbar.set_description(f'[mean y_t: {y_t.mean():.3f}, mean theta_t: {np.mean(theta_t):.3f}]')
        
    return theta_samples

#### Get samples

In [5]:
for post in POSTERIORS:
    thetas = sample_gibbs_prior(random.PRNGKey(9999), post, T=10000)
    np.save(f'../../results/volatility/{post}.npy', thetas)

[mean y_t: 0.213, mean theta_t: 0.128]: 100%|██████████| 10000/10000 [58:46<00:00,  2.84it/s] 
[mean y_t: 0.260, mean theta_t: 0.334]: 100%|██████████| 10000/10000 [5:49:55<00:00,  2.10s/it]   
[mean y_t: 0.828, mean theta_t: 0.990]: 100%|██████████| 10000/10000 [1:04:16<00:00,  2.59it/s] 
[mean y_t: 0.235, mean theta_t: 0.159]: 100%|██████████| 10000/10000 [03:30<00:00, 47.40it/s] 


### Prior

In [6]:
rng_key = random.PRNGKey(42)
rng_key, rng_subkey = random.split(rng_key)
thetas_prior = dist.GaussianRandomWalk(scale=sigma, num_steps=NUM_STEPS).sample(rng_subkey, (10000,))
np.save(f'../../results/volatility/prior_volatility.npy', thetas_prior)