In [42]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from scipy.stats import levy_stable

from numba import njit
from numba.typed import List
import tensorflow as tf

In [3]:
from bayesflow.networks import InvertibleNetwork, InvariantNetwork
from bayesflow.amortizers import SingleModelAmortizer
from bayesflow.trainers import ParameterEstimationTrainer
from bayesflow.diagnostics import *
from bayesflow.models import GenerativeModel

In [4]:
%load_ext autoreload
%autoreload 2

# Simulator settings

In [50]:
def prior(batch_size):
    """
    Samples from the prior 'batch_size' times.
    ----------
    
    Arguments:
    batch_size : int -- the number of samples to draw from the prior
    ----------
    
    Output:
    theta : np.ndarray of shape (batch_size, theta_dim) -- the samples batch of parameters
    """
    
    # Prior ranges for the simulator 
    # v_c ~ U(-7.0, 7.0)
    # a_c ~ U(0.1, 4.0)
    # t0 ~ U(0.1, 3.0)
    
    a = np.random.gamma(2,2, size=batch_size)
    zr = np.random.beta(5,5, size=batch_size)
    v1 = np.random.normal(0,5, size=batch_size)
    v2 = np.random.normal(0,5, size=batch_size)
    v3 = np.random.normal(0,5, size=batch_size)
    v4 = np.random.normal(0,5, size=batch_size)
    t0 = np.random.gamma(2,2, size=batch_size)
    alpha = 2*np.random.beta(2,1, size=batch_size)
    
    p_samples = np.c_[
        a, zr, v1, v2, v3, v4, t0, alpha
    ]
    
    return p_samples.astype(np.float32)


@njit
def diffusion_trial(v, a, ndt, zr, dt, max_steps):
    """
    Simulate one trial of a drift-diffusion process on the interval (0, a).

    The accumulator starts at a * zr and takes Euler-Maruyama steps of
    v * dt + sqrt(dt) * N(0, 1) until it leaves (0, a) or `max_steps` steps
    have been taken.

    Returns the signed response time: (steps * dt + ndt) if the final state is
    above 0, otherwise -(steps * dt + ndt). A trial stopped by `max_steps` while
    still inside (0, a) is therefore reported as a positive response.
    """
    evidence = a * zr
    steps_taken = 0.
    sqrt_dt = np.sqrt(dt)

    # Random walk until a boundary is crossed or the step budget is exhausted
    while 0 < evidence and evidence < a and steps_taken < max_steps:
        evidence += v * dt + sqrt_dt * np.random.normal()
        steps_taken += 1.0

    decision_time = steps_taken * dt
    if evidence > 0.:
        return decision_time + ndt
    return -decision_time - ndt



@njit
def diffusion_2_conds(params, n_trials, dt=0.005, max_steps=1e4, zr=0.5):
    """
    Simulates a diffusion process for 2 conditions with 5 parameters (v1, v2, a1, a2, ndt).

    Each trial is produced by `diffusion_trial`; condition 1 uses (v1, a1),
    condition 2 uses (v2, a2), and both share ndt and the starting point zr.

    Arguments:
    params    : sequence of 5 floats -- (v1, v2, a1, a2, ndt)
    n_trials  : sequence of 2 ints -- number of trials in condition 1 and condition 2
    dt        : float -- Euler step size
    max_steps : float -- maximum number of steps per trial
    zr        : float -- relative starting point for both conditions (default 0.5)

    Output:
    rts : 1D array of length n_trials[0] + n_trials[1] -- signed RTs of
          condition 1 followed by condition 2
    """
    v1, v2, a1, a2, ndt = params
    n_c1 = n_trials[0]
    n_c2 = n_trials[1]

    # Fill one preallocated array instead of concatenating per-condition results
    rts = np.empty(n_c1 + n_c2)
    for i in range(n_c1):
        rts[i] = diffusion_trial(v1, a1, ndt, zr, dt, max_steps)
    for j in range(n_c2):
        rts[n_c1 + j] = diffusion_trial(v2, a2, ndt, zr, dt, max_steps)
    return rts

@njit
def levy_trial(noise, a=1, zr=0.5, v=0, ndt=1, alpha=2, dt=0.001, max_steps=1e4):
    """
    Simulates a trial from the levy model.

    The accumulator starts at a * zr and is updated as
        x += v * dt + dt ** (1 / alpha) * noise[step]
    until it leaves (0, a) or the step budget is used up.

    Arguments:
    noise     : 1D array -- pre-drawn noise increments, one per step
                (presumably alpha-stable draws, see `batch_simulator`)
    a         : upper boundary of the accumulator
    zr        : relative starting point (start = a * zr)
    v         : drift rate
    ndt       : non-decision time added to the decision time
    alpha     : stability index; sets the dt ** (1 / alpha) noise scaling
    dt        : step size
    max_steps : maximum number of steps; the effective budget is additionally
                capped at len(noise)

    Output:
    signed RT -- rt + ndt if the final state is > 0, else -rt - ndt. A trial
    that runs out of steps inside (0, a) is therefore reported as positive.
    """
    # Numba does not bounds-check indexing by default, so never step past the
    # supplied noise (an undersized array would otherwise read stray memory).
    step_limit = min(max_steps, len(noise))

    # Loop-invariant terms of the update equation
    drift = v * dt
    noise_scale = dt ** (1 / alpha)

    n_steps = 0
    x = a * zr

    # Simulate a single path until a boundary is crossed or steps run out
    while (x > 0 and x < a and n_steps < step_limit):
        x += drift + noise_scale * noise[n_steps]
        n_steps += 1

    rt = n_steps * dt
    return rt + ndt if x > 0. else -rt - ndt


@njit
def levy_condition(n_trials, noise, a=1, zr=0.5, v=0, ndt=1, alpha=2, dt=0.001, max_steps=1e4):
    """
    Simulate `n_trials` Levy-model trials that share one parameter set.

    Row k of `noise` provides the per-step noise for trial k, so `noise` must
    have at least `n_trials` rows. Returns a float array of length `n_trials`
    with the signed response times produced by `levy_trial`.
    """
    signed_rts = np.empty(n_trials)
    trial = 0
    while trial < n_trials:
        signed_rts[trial] = levy_trial(noise[trial], a, zr, v, ndt, alpha, dt, max_steps)
        trial += 1
    return signed_rts

@njit
def levy_4_conds(params, n_trials, noise, dt=0.001, max_steps=1e4):
    """
    Simulates a levy process for 4 conditions with 8 parameters (a, zr, v1, v2, v3, v4, t0, alpha).

    All conditions share a, zr, t0 (non-decision time) and alpha; condition k
    uses its own drift rate v_k.

    Arguments:
    params    : array of 8 floats ordered (a, zr, v1, v2, v3, v4, t0, alpha)
    n_trials  : sequence of 4 ints -- number of trials per condition
    noise     : sequence of 4 2D arrays; noise[k] holds one row of per-step noise
                for each trial of condition k (at least n_trials[k] rows)
    dt        : step size, forwarded to `levy_condition`
    max_steps : step budget per trial, forwarded to `levy_condition`

    Output:
    rts : 1D array of length sum(n_trials) -- signed RTs of conditions 1..4 in order
    """
    a, zr, v1, v2, v3, v4, ndt, alpha = params

    rt_c1 = levy_condition(n_trials[0], noise[0], a, zr, v1, ndt, alpha, dt=dt, max_steps=max_steps)
    rt_c2 = levy_condition(n_trials[1], noise[1], a, zr, v2, ndt, alpha, dt=dt, max_steps=max_steps)
    rt_c3 = levy_condition(n_trials[2], noise[2], a, zr, v3, ndt, alpha, dt=dt, max_steps=max_steps)
    rt_c4 = levy_condition(n_trials[3], noise[3], a, zr, v4, ndt, alpha, dt=dt, max_steps=max_steps)

    return np.concatenate((rt_c1, rt_c2, rt_c3, rt_c4))


def batch_simulator(prior_samples, n_obs, dt=0.001, s=1.0, max_steps=1e4):
    """
    Simulate multiple Levy-model datasets with 4 conditions each.

    Arguments:
    prior_samples : np.ndarray of shape (n_sim, 8) -- parameters ordered
                    (a, zr, v1, v2, v3, v4, t0, alpha), e.g. the output of `prior`
    n_obs         : int -- trials per dataset; conditions 1-3 get n_obs // 4 trials,
                    condition 4 gets the remainder
    dt            : float -- simulation step size
    s             : float -- currently unused; kept for interface compatibility
    max_steps     : float -- maximum number of steps per trial (also the width of
                    the pre-drawn noise)

    Output:
    sim_data : np.ndarray of shape (n_sim, n_obs, 2) -- [..., 0] holds the signed
               RTs, [..., 1] holds the condition label (0, 1, 2, 3)
    """
    n_sim = prior_samples.shape[0]
    sim_data = np.zeros((n_sim, n_obs), dtype=np.float32)
    
    n1 = n2 = n3 = n_obs // 4
    n4 = n_obs - 3 * n1
    n_obs_tuple = (n1, n2, n3, n4)
    n_steps = int(max_steps)
    
    for i in range(n_sim):
        
        # Precompute alpha-stable noise (alpha is the last parameter column);
        # each array in the list has shape (n_obs_cond, n_steps).
        # NOTE(review): scipy's levy_stable with unit scale and alpha=2 is Gaussian
        # with variance 2, not 1 -- confirm this noise scale is intended.
        noise = List([
            levy_stable.rvs(alpha=prior_samples[i, -1], beta=0, size=(n, n_steps))
            for n in n_obs_tuple
        ])
        
        # Forward dt and max_steps so the simulator uses the same step budget
        # the noise was drawn for.
        sim_data[i] = levy_4_conds(prior_samples[i], n_obs_tuple, noise, dt=dt, max_steps=max_steps)
        
    # Condition labels: repeat 0..3 by each condition's size. n4 can differ from
    # n1..n3 when n_obs is not divisible by 4.
    cond_labels = np.repeat(np.arange(4, dtype=np.float64), n_obs_tuple)
    cond_arr = np.tile(cond_labels, (n_sim, 1))
    sim_data = np.stack((sim_data, cond_arr), axis=-1)
    
    return sim_data

In [53]:
# %%time
# params = prior(10)
# batch_simulator(params, 100)

# An example Bayesian workflow (with BayesFlow)

Towards a principled Bayesian workflow for cognitive modeling:

https://betanalpha.github.io/assets/case_studies/principled_bayesian_workflow.html

https://arxiv.org/abs/1904.12765

## Prior predictive checks

In [17]:
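# TODO: prior predictive checks -- simulate datasets from `prior` with `batch_simulator` and inspect the resulting RT distributions per condition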

## Train an amortized parameter estimation network
Here, we use an invariant summary network and an invertible inference network with default settings.

We connect the networks through a *SingleModelAmortizer* instance.

In [55]:
# Default-configured networks; 'n_params' matches the 8 parameter columns
# returned by `prior` (a, zr, v1, v2, v3, v4, t0, alpha).
summary_net = InvariantNetwork()
inference_net = InvertibleNetwork({'n_params': 8})
amortizer = SingleModelAmortizer(inference_net, summary_net)

We connect the prior and simulator through a *GenerativeModel* class, which takes care of forward simulation (drawing parameters from the prior and simulating data from them).

In [56]:
# Pairs the prior with the simulator; called below as generative_model(n_sim, n_obs)
generative_model = GenerativeModel(prior, batch_simulator)

In [None]:
# Trainer wiring the amortizer to the generative model; the cells below use its
# offline, online, round-based and experience-replay training methods.
trainer = ParameterEstimationTrainer(
    network=amortizer, 
    generative_model=generative_model,
)

### Offline training

#### Using pre-simulated data

In [None]:
# Pre-simulated data (could be loaded from somewhere else)
# NOTE: batch_simulator draws n_obs * max_steps (1e4 by default) alpha-stable values
# per dataset, so simulating 5000 datasets here is expensive.
n_sim = 5000
n_obs = 100
true_params, x = generative_model(n_sim, n_obs)

In [None]:
%%time
# Offline training on the (true_params, x) pairs simulated in the previous cell
losses = trainer.train_offline(epochs=1, batch_size=64, params=true_params, sim_data=x)

#### Using internally simulated data

In [None]:
%%time
# Let the trainer simulate its own 1000 training datasets (n_obs trials each) and train offline on them
losses = trainer.simulate_and_train_offline(n_sim=1000, epochs=2, batch_size=32, n_obs=n_obs)

### Online training

In [None]:
# Fixed n_obs


In [None]:
%%time
# Online training with a fixed number of observations (n_obs) per dataset
losses = trainer.train_online(epochs=2, iterations_per_epoch=100, batch_size=32, n_obs=n_obs)

In [None]:
# Variable n_obs
def prior_N(n_min=60, n_max=300):
    """
    Prior over the number of observations per dataset (called internally at each backprop step).

    Returns an integer drawn uniformly from the closed range [n_min, n_max],
    using the global NumPy random state.
    """
    upper_exclusive = n_max + 1  # randint's upper bound is exclusive
    return np.random.randint(low=n_min, high=upper_exclusive)

In [None]:
%%time
# Online training with the number of observations drawn from prior_N for each batch
losses = trainer.train_online(epochs=2, iterations_per_epoch=100, batch_size=32, n_obs=prior_N)

### Round-based training

In [None]:
%%time
# Round-based training: 5 rounds with 200 new simulations per round
losses = trainer.train_rounds(epochs=1, rounds=5, sim_per_round=200, batch_size=32, n_obs=n_obs)

### Experience-replay training

In [None]:
%%time
# Experience-replay training; capacity=100 presumably bounds the replay memory
# (TODO confirm in BayesFlow docs). n_obs is drawn from prior_N.
losses = trainer.train_experience_replay(epochs=3, 
                                         batch_size=32, 
                                         iterations_per_epoch=100, 
                                         capacity=100,
                                         n_obs=prior_N)

## Custom networks

In [None]:
# Custom summary-network settings
sum_meta = {
    'n_dense_s1': 2,
    'n_dense_s2': 2,
    'n_dense_s3': 2,
    'n_equiv':    2,
    'dense_s1_args': {'activation': 'relu', 'units': 64},
    'dense_s2_args': {'activation': 'relu', 'units': 64},
    'dense_s3_args': {'activation': 'relu', 'units': 64},
}

# Custom inference-network settings
bf_meta = {
    'n_coupling_layers': 4,
    's_args': {
        'units': [64, 64, 64],
        'activation': 'elu',
        'initializer': 'glorot_uniform',
    },
    't_args': {
        'units': [64, 64, 64],
        'activation': 'elu',
        'initializer': 'glorot_uniform',
    },
    # Must match the number of parameter columns returned by `prior`
    # (a, zr, v1, v2, v3, v4, t0, alpha)
    'n_params': 8,
    # NOTE(review): presumably the coupling-layer soft-clamping constant -- confirm
    'alpha': 1.9,
    'permute': True
}

In [None]:
# Build the custom networks under new names. Rebinding `amortizer` here would
# replace the network trained above with an untrained one, and every diagnostic
# below calls `amortizer.sample`. To train this variant, pass `custom_amortizer`
# to a new ParameterEstimationTrainer.
custom_summary_net = InvariantNetwork(sum_meta)
custom_inference_net = InvertibleNetwork(bf_meta)
custom_amortizer = SingleModelAmortizer(custom_inference_net, custom_summary_net)

## Computational faithfulness
(Via simulation-based calibration)

In [None]:
# Simulation-based calibration: draw parameters from the prior, simulate one
# dataset (100 trials) per draw, then sample the approximate posterior for each.
# NOTE: 5000 datasets * 100 trials * 1e4 noise draws makes this cell expensive.
n_sbc = 5000
n_post_samples_sbc = 250
params_sbc = prior(n_sbc)
x_sbc = batch_simulator(params_sbc, 100)
# Sample in 10 chunks of datasets (presumably to bound memory). Chunks are joined
# along axis 1, assuming axis 0 indexes posterior draws and axis 1 indexes
# datasets, as the `.mean(axis=0)` calls below also assume -- TODO confirm.
param_samples = np.concatenate([amortizer.sample(x, n_post_samples_sbc) 
                                for x in tf.split(x_sbc, 10, axis=0)], axis=1)

In [None]:
# Parameter names in the column order returned by `prior`
f = plot_sbc(param_samples, params_sbc, param_names=['a', 'zr', 'v1', 'v2', 'v3', 'v4', 't0', 'alpha'])

## Model sensitivity/adequacy

### Quick and dirty

In [None]:
# Validate (quick and dirty): posterior means vs. true parameters
param_names = ['a', 'zr', 'v1', 'v2', 'v3', 'v4', 't0', 'alpha']  # column order of `prior`
true_params = prior(300)
# batch_simulator requires n_obs and returns float64; cast to float32
x = batch_simulator(true_params, n_obs).astype(np.float32)
param_samples = amortizer.sample(x, n_samples=1000)
param_means = param_samples.mean(axis=0)
true_vs_estimated(true_params, param_means, param_names)

### A Bayesian eyechart

In [None]:
# Simulate
param_names = ['a', 'zr', 'v1', 'v2', 'v3', 'v4', 't0', 'alpha']  # column order of `prior`
n_sim_s = 500
n_samples_posterior = 1000
true_params = prior(n_sim_s)
x = batch_simulator(true_params, n_obs)

# Sample from posterior
param_samples = amortizer.sample(x, n_samples_posterior)

### Posterior z-score
# Compute posterior means and stds
post_means = param_samples.mean(0)
post_stds = param_samples.std(0)
post_vars = param_samples.var(0)

# Compute posterior z score
post_z_score = (post_means - true_params) / post_stds

### Posterior contraction, i.e., 1 - post_var / prior_var
# The priors are Gamma/Beta/Normal rather than uniform, so estimate each marginal
# prior variance from a large prior sample; this stays correct if `prior` changes.
prior_vars = prior(100_000).var(axis=0, dtype=np.float64)
post_cont = 1 - post_vars / prior_vars

# Plotting time: one panel per parameter
n_cols = 4
n_rows = (len(param_names) + n_cols - 1) // n_cols
f, axarr = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows))
panels = np.ravel(axarr)
for i, (p, ax) in enumerate(zip(param_names, panels)):
    ax.scatter(post_cont[:, i], post_z_score[:, i], color='#8f2727', alpha=0.7)
    ax.set_title(p, fontsize=20)
    sns.despine(ax=ax)
    ax.set_xlim([-0.1, 1.05])
    ax.set_ylim([-3.5, 3.5])
    ax.grid(color='black', alpha=0.1)
    ax.set_xlabel('Posterior contraction', fontsize=14)
    # Label only the first column of each row
    if i % n_cols == 0:
        ax.set_ylabel('Posterior z-score', fontsize=14)
# Hide panels left over when the grid has more cells than parameters
for ax in panels[len(param_names):]:
    ax.set_visible(False)
f.tight_layout()

## Posterior postdictive/predictive checks