In [1]:
import warnings
warnings.filterwarnings('ignore')

# %env XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda/
%load_ext autoreload
%autoreload 2
import os
import numpy as np
from jax import random
import jax.numpy as jnp
import numpyro as npr
import tqdm as tqdm
from math import *
import numpyro.distributions as dist


npr.set_platform('gpu')

### Models

#### True model

In [2]:
L = 10

def model_true(y=None, theta=None, rng_key=random.PRNGKey(1)):
    key, *subkeys = random.split(rng_key, 5)  # subkeys
    
    if theta is None:
        # Sample from priors \pi(\mu, \sigma^2)
        mu = npr.sample('mu', dist.Normal(0, 1), rng_key=subkeys[0])
        sigma_sq = npr.sample('sigma_sq', dist.Gamma(1, 1), rng_key=subkeys[1])
    else:
        mu, sigma_sq = theta
    
    # The true likelihood, sum of LogNormal rvs.
    with npr.plate('L', L):
        x = npr.sample('X', dist.LogNormal(mu, jnp.sqrt(sigma_sq)), rng_key=subkeys[2])
        
    out = npr.sample('Y', dist.Delta(x.sum(0)), rng_key=subkeys[3], obs=y)
        
    if theta is None and y is not None:
        return (mu, sigma_sq)
    else:
        return out
    
    
# Test
model_true()

DeviceArray(7.4827366, dtype=float32)

#### Approximate model

In [3]:
def model_abc(y=None, theta=None, rng_key=random.PRNGKey(1)):
    key, *subkeys = random.split(rng_key, 4)  # subkeys
    
    if theta is None:
        # Sample from priors \pi(\mu, \sigma^2)
        mu = npr.sample('mu', dist.Normal(0, 1), rng_key=subkeys[0])
        sigma_sq = npr.sample('sigma_sq', dist.Gamma(1, 1), rng_key=subkeys[1])
    else:
        mu, sigma_sq = theta
    
    # Approximate likelihood
    beta_sq = jnp.log((jnp.exp(sigma_sq)-1)/L + 1)
    alpha = mu + jnp.log(L) + 0.5*(sigma_sq - beta_sq)
        
    out = npr.sample('Y', dist.LogNormal(alpha, jnp.sqrt(beta_sq)), rng_key=subkeys[2], obs=y)
    
    if theta is None and y is not None:
        return (mu, sigma_sq)
    else:
        return out
    
    
# Test
model_abc()

DeviceArray(0.8266345, dtype=float32)

### Methods

In [4]:
from numpyro.infer import SVI, Trace_ELBO, MCMC, NUTS
from numpyro.infer.autoguide import *
from jax import jit


@jit
def laplace(rng_key, y):    
    key, *subkeys = random.split(rng_key, 4)
    
    guide = AutoLaplaceApproximation(model_abc)
    lr = 1e-3
    n_iter = 5000

    optimizer = npr.optim.ClippedAdam(step_size=lr)
    svi = SVI(model_abc, guide, optimizer, loss=Trace_ELBO(num_particles=1))
    svi_result = svi.run(rng_key, n_iter, y=y, progress_bar=False)
    
    mu = guide.sample_posterior(subkeys[1], svi_result.params)['mu']
    sigma_sq = guide.sample_posterior(subkeys[2], svi_result.params)['sigma_sq']

    return mu, sigma_sq

### Gibbs-prior sampler

In [5]:
def sample_gibbs_prior(rng_key, T=10000):    
    theta_samples = []
    
    rng_key, rng_subkey = random.split(rng_key)
    y_t  = model_true(y=None, theta=None, rng_key=rng_subkey)
    
    pbar = tqdm.trange(T)
    for t in pbar:     
        rng_key, *subkeys = random.split(rng_key, 3)
        
        # Get q(theta | y_t)
        theta_t = laplace(subkeys[0], y_t)
        theta_samples.append(np.array(theta_t).copy())

        # Sample y_t, always using the true model
        y_t  = model_true(y=None, theta=theta_t, rng_key=subkeys[1])
        
        # Sometimes y_t == inf due to precision overflow
        # In this case, just sample iid again
        while y_t == inf:
            rng_key, subkey = random.split(rng_key)
            y_t  = model_true(y=None, theta=theta_t, rng_key=subkey)
        
        if t % 100 == 0:
            pbar.set_description(f'[y_t: {y_t:.3f}, theta_t: {np.array(theta_t)}]')
        
    return np.array(theta_samples)

### Sample $\pi_G$ for the ABC model

In [6]:
rng_key = random.PRNGKey(1234)

# Laplace
thetas_laplace = sample_gibbs_prior(rng_key, T=10000)
np.save('../../results/log_normal/laplace.npy', thetas_laplace)

[y_t: 13.843, theta_t: [0.05588176 0.18867423]]: 100%|██████████| 10000/10000 [15:08<00:00, 11.00it/s]


### Sampling from the prior

In [7]:
def sample_prior(n_samples):
    rng_keys = random.split(random.PRNGKey(1), 2)
    mu = dist.Normal(0, 1).sample(rng_keys[0], (n_samples, 1))
    sigma_sq = dist.Gamma(1, 1).sample(rng_keys[1], (n_samples, 1))
    theta = np.concatenate((mu, sigma_sq), axis=-1)
    return theta

thetas_prior = sample_prior(n_samples=10000)
np.save('../../results/log_normal/prior_log_normal.npy', thetas_prior)