In [14]:
import warnings
warnings.filterwarnings('ignore')

# %env XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda/
%load_ext autoreload
%autoreload 2
import os
import numpy as np
from jax import random
import jax.numpy as jnp
import numpyro as npr
import tqdm as tqdm
from math import *
import numpyro.distributions as dist


npr.set_platform('gpu')

### Models

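#### True model (simulate from the exact likelihood: $Y = \sum_{l=1}^{L} X_l$, $X_l \overset{\text{iid}}{\sim} \mathrm{LogNormal}(\mu, \sigma^2)$)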

In [27]:
L = 10  # number of i.i.d. LogNormal summands making up one observation Y


def model_true(y=None, theta=None, rng_key=random.PRNGKey(1)):
    """Exact simulator: Y = sum_{l=1}^{L} X_l with X_l ~ LogNormal(mu, sqrt(sigma_sq)).

    Args:
        y: optional observed value for the 'Y' site.
        theta: optional ``(mu, sigma_sq)`` pair. When None, the parameters are
            drawn from the priors mu ~ Normal(0, 1) and sigma_sq ~ Gamma(1, 1).
        rng_key: JAX PRNG key, split into one key per sample site.

    Returns:
        ``(mu, sigma_sq)`` when ``theta`` is None and ``y`` is given; otherwise
        the value of the 'Y' site (the simulated sum, or ``y`` when observed).
    """
    _, k_mu, k_sigma, k_x, k_y = random.split(rng_key, 5)

    if theta is not None:
        mu, sigma_sq = theta
    else:
        # Prior draws for pi(mu, sigma^2).
        mu = npr.sample('mu', dist.Normal(0, 1), rng_key=k_mu)
        sigma_sq = npr.sample('sigma_sq', dist.Gamma(1, 1), rng_key=k_sigma)

    # L independent LogNormal summands; the plate gives 'X' a leading axis of size L.
    with npr.plate('L', L):
        x = npr.sample('X', dist.LogNormal(mu, jnp.sqrt(sigma_sq)), rng_key=k_x)

    # Deterministic site holding the sum over the plate axis (observed when y is given).
    y_site = npr.sample('Y', dist.Delta(x.sum(0)), rng_key=k_y, obs=y)

    return (mu, sigma_sq) if (theta is None and y is not None) else y_site


# Test
model_true()

DeviceArray(7.4827366, dtype=float32)

#### Approximate model (Fenton–Wilkinson: a single LogNormal moment-matched to the sum of $L$ LogNormals)

In [16]:
def model_abc(y=None, theta=None, rng_key=random.PRNGKey(1)):
    """Approximate model: one LogNormal moment-matched to the sum of L LogNormals.

    The sum of ``L`` i.i.d. LogNormal(mu, sigma_sq) variables is replaced by a
    single LogNormal(alpha, beta_sq) with the same mean and variance
    (Fenton–Wilkinson approximation):

        beta_sq = log(1 + (exp(sigma_sq) - 1) / L)
        alpha   = mu + log(L) + (sigma_sq - beta_sq) / 2

    Uses the notebook-level constant ``L``.

    Args:
        y: optional observed value for the 'Y' site.
        theta: optional ``(mu, sigma_sq)`` pair. When None, the parameters are
            drawn from the priors mu ~ Normal(0, 1) and sigma_sq ~ Gamma(1, 1).
        rng_key: JAX PRNG key, split into one key per sample site.

    Returns:
        ``(mu, sigma_sq)`` when ``theta`` is None and ``y`` is given; otherwise
        the value of the 'Y' site.
    """
    _, k_mu, k_sigma, k_y = random.split(rng_key, 4)

    if theta is None:
        # Sample from priors \pi(\mu, \sigma^2)
        mu = npr.sample('mu', dist.Normal(0, 1), rng_key=k_mu)
        sigma_sq = npr.sample('sigma_sq', dist.Gamma(1, 1), rng_key=k_sigma)
    else:
        mu, sigma_sq = theta

    # Approximate likelihood. expm1/log1p avoid catastrophic cancellation for
    # small sigma_sq: in float32, exp(s) - 1 rounds to exactly 0 for s below
    # ~6e-8, which would make beta_sq == 0 (a zero-scale LogNormal).
    beta_sq = jnp.log1p(jnp.expm1(sigma_sq) / L)
    alpha = mu + jnp.log(L) + 0.5 * (sigma_sq - beta_sq)

    out = npr.sample('Y', dist.LogNormal(alpha, jnp.sqrt(beta_sq)), rng_key=k_y, obs=y)

    if theta is None and y is not None:
        return (mu, sigma_sq)
    return out


# Test
model_abc()

DeviceArray(0.8266345, dtype=float32)

### Methods

In [17]:
from numpyro.infer import SVI, Trace_ELBO, MCMC, NUTS
from numpyro.infer.autoguide import *


def laplace(rng_key, model, y, n_return_samples=1, pbar=False, lr=1e-3, n_iter=5000):
    """Fit a Laplace approximation to ``model``'s posterior given ``y`` and sample it.

    Runs SVI with an ``AutoLaplaceApproximation`` guide, a
    ``ClippedAdam(step_size=lr)`` optimizer and a single-particle ``Trace_ELBO``
    for ``n_iter`` steps. It then draws from the fitted guide.

    Args:
        rng_key: JAX PRNG key; split into independent keys for fitting and sampling.
        model: NumPyro model called as ``model(y=y)``; must have latent sites
            'mu' and 'sigma_sq'.
        y: observed data passed to ``model``.
        n_return_samples: number of posterior draws to return.
        pbar: whether to show the SVI progress bar.
        lr: ClippedAdam step size.
        n_iter: number of SVI steps.

    Returns:
        ``(mu, sigma_sq)``. For ``n_return_samples == 1`` these are the two values
        of a single joint posterior draw, as returned by the guide. Otherwise they
        are two NumPy arrays with leading dimension ``n_return_samples``.
    """
    # Use fresh subkeys; the parent key is not reused once it has been split.
    fit_key, sample_key = random.split(rng_key)

    guide = AutoLaplaceApproximation(model)
    optimizer = npr.optim.ClippedAdam(step_size=lr)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO(num_particles=1))
    svi_result = svi.run(fit_key, n_iter, y=y, progress_bar=pbar)

    if n_return_samples == 1:
        # Take both parameters from ONE draw so any posterior correlation between
        # mu and sigma_sq in the fitted guide is preserved. Two separate draws
        # would pair values from independent samples.
        theta = guide.sample_posterior(sample_key, svi_result.params)
        return theta['mu'], theta['sigma_sq']

    thetas = guide.sample_posterior(sample_key, svi_result.params,
                                    sample_shape=(n_return_samples,))
    return np.array(thetas['mu']), np.array(thetas['sigma_sq'])

### Baseline: simulation-based calibration (Talts et al., 2018)

In [32]:
M = 1e4  # total-draw budget; only reported below (N * R is chosen to be close to M)
R = 31  # posterior draws per simulated dataset
N = round(10.42*R)  # number of simulated datasets (323 for R = 31)

print(f'M = {M}, R = {R}, N = {N}, N*R = {N*R}')

RESULTS_PATH = '../../results/baseline/baseline_log_normal.npy'


def run_baseline(rng_key=random.PRNGKey(1234), n_sims=None, n_draws=None):
    """Simulation-based-calibration baseline for the Laplace fit of ``model_abc``.

    For each replication:
    - draw (mu, sigma_sq) from the priors Normal(0, 1) and Gamma(1, 1);
    - simulate ``y`` from ``model_true`` with those values;
    - draw ``n_draws`` posterior samples with ``laplace`` applied to ``model_abc``.

    Args:
        rng_key: JAX PRNG key; a fresh key is split off for every replication.
        n_sims: number of replications; defaults to the global ``N`` (read at call time).
        n_draws: posterior draws per replication; defaults to the global ``R``
            (read at call time).

    Returns:
        Array of shape ``(n_sims, n_draws + 1, 2)``. ``[:, 0]`` holds the prior
        draw used to simulate and ``[:, 1:]`` the posterior draws. The last axis
        is ``(mu, sigma_sq)``.
    """
    n_sims = N if n_sims is None else n_sims
    n_draws = R if n_draws is None else n_draws
    results = np.zeros([n_sims, n_draws + 1, 2])

    for i in tqdm.trange(n_sims):
        rng_key, k_mu, k_sigma, k_sim, k_fit = random.split(rng_key, 5)

        mu_prior = dist.Normal(0, 1).sample(k_mu)
        sigma_sq_prior = dist.Gamma(1, 1).sample(k_sigma)
        results[i, 0, :] = np.array([mu_prior, sigma_sq_prior])  # (2,)

        y = model_true(y=None, theta=(mu_prior, sigma_sq_prior), rng_key=k_sim)
        mus, sigma_sqs = laplace(k_fit, model_abc, y, n_return_samples=n_draws, pbar=False)
        results[i, 1:, :] = np.stack([mus, sigma_sqs]).T  # (n_draws, 2)

    return results


# Create the output directory first, so a missing folder fails before the long
# simulation loop instead of after it.
os.makedirs(os.path.dirname(RESULTS_PATH), exist_ok=True)
results = run_baseline()
np.save(RESULTS_PATH, results)

M = 10000.0, R = 31, N = 323, N*R = 10013


100%|██████████████████████████████████████████████████████████████████████| 323/323 [02:57<00:00,  1.82it/s]
