In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from functools import partial

from numba import njit
import tensorflow as tf
from tensorflow.keras.layers import LSTM
from tensorflow.keras.models import Sequential

In [None]:
from bayesflow.networks import EvidentialNetwork
from bayesflow.trainers import MultiModelTrainer
from bayesflow.losses import log_loss
from bayesflow.diagnostics import plot_confusion_matrix, plot_calibration_curves, expected_calibration_error

In [None]:
%load_ext autoreload
%autoreload 2



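This notebook contains an example of an amortized, simulation-based model comparison workflow with BayesFlow. It compares three Hodgkin–Huxley neuron models, which differ in which conductances are free parameters, using an evidential network.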

# Simulator settings

## Model prior
Implements sampling from $p(\mathcal{M})$.

In [None]:
def model_prior(batch_size, p_vals=None):
    """
    Samples model indices from the prior over models p(M).
    Assumes an equal prior over the 3 models of this notebook unless p_vals is given.
    ----------

    Arguments:
    batch_size : int  -- the number of model indices to draw
    p_vals     : sequence of float or None -- prior probability of each model;
                 its length sets the number of candidate models
    ----------

    Returns:
    m_idx : np.ndarray of shape (batch_size,), dtype int32 -- sampled model
            indices in {0, ..., len(p_vals) - 1} (integer labels, not one-hot)

    Raises:
    ValueError (from np.random.choice) if p_vals is not a valid probability vector
    """

    # Equal priors over the three models, if nothing specified
    if p_vals is None:
        p_vals = [1/3] * 3
    # Derive the number of models from p_vals instead of hard-coding 3, so a
    # prior over a different number of models also works.
    n_models = len(p_vals)
    m_idx = np.random.choice(n_models, size=batch_size, p=p_vals).astype(np.int32)
    return m_idx

## Parameter priors
Implements sampling from each $p(\theta_j\,|\,\mathcal{M}_j)$.

In [None]:
# Uniform prior bounds (low, high) for each model's parameter vector, listed in
# the same order in which the matching forward_model* unpacks `params`.
_MODEL1_PRIOR_BOUNDS = (  # HH-2pars: (gbar_Na, gbar_K)
    (1.5, 30),
    (0.3, 15),
)
_MODEL2_PRIOR_BOUNDS = (  # HH-3pars: (gbar_Na, gbar_K, gbar_M)
    (1.5, 30),
    (0.3, 15),
    (0.005, 0.3),
)
_MODEL3_PRIOR_BOUNDS = (  # HH-4pars: (gbar_l, gbar_Na, gbar_K, gbar_M)
    (0.01, 0.18),
    (1.5, 30),
    # NOTE(review): the gbar_K lower bound is 0.1 here but 0.3 in models 1 and 2;
    # confirm that this difference between the models is intentional.
    (0.1, 15),
    (0.005, 0.3),
)


def _sample_uniform_box(bounds):
    """Draws one U(low, high) sample per (low, high) pair, in order, as a 1-D float array."""
    return np.array([np.random.uniform(low=low, high=high) for low, high in bounds])


def model1_params_prior(**args):
    """
    Samples from the prior of the HH-2pars model, theta = (gbar_Na, gbar_K).
    ----------

    Arguments:
    **args : ignored -- any keyword arguments are accepted and discarded
    ----------

    Output:
    theta : np.ndarray of shape (2,) -- one draw of the parameters
    """
    return _sample_uniform_box(_MODEL1_PRIOR_BOUNDS)


def model2_params_prior(**args):
    """
    Samples from the prior of the HH-3pars model, theta = (gbar_Na, gbar_K, gbar_M).
    ----------

    Arguments:
    **args : ignored -- any keyword arguments are accepted and discarded
    ----------

    Output:
    theta : np.ndarray of shape (3,) -- one draw of the parameters
    """
    return _sample_uniform_box(_MODEL2_PRIOR_BOUNDS)


def model3_params_prior(**args):
    """
    Samples from the prior of the HH-4pars model, theta = (gbar_l, gbar_Na, gbar_K, gbar_M).
    ----------

    Arguments:
    **args : ignored -- any keyword arguments are accepted and discarded
    ----------

    Output:
    theta : np.ndarray of shape (4,) -- one draw of the parameters
    """
    return _sample_uniform_box(_MODEL3_PRIOR_BOUNDS)

## Simulators
Implements each forward model (stochastic simulator) $g_j(\theta_j,\xi)$. Uses `numba` for just-in-time compilation to speed up simulation.

In [None]:
@njit
def _simulate_hh(g_l, gbar_Na, gbar_K, gbar_M, n_obs, V0, I_input, dt):
    """
    Shared Hodgkin-Huxley simulator behind forward_model1/2/3.

    The three models run exactly this code. They differ only in which of g_l and
    gbar_M are fixed constants and which are free parameters.
    The membrane voltage is integrated with an exponential-Euler update. The
    input is a step current plus additive Gaussian current noise.

    Arguments:
    g_l     : float  -- leak conductance
    gbar_Na : float  -- maximal Na+ conductance
    gbar_K  : float  -- maximal K+ conductance
    gbar_M  : float  -- maximal conductance of the slow non-inactivating K+ current
    n_obs   : number -- duration of the current step in ms (NOT a number of samples)
    V0      : float  -- initial membrane potential (mV)
    I_input : float  -- amplitude of the current step (muA/cm2)
    dt      : float  -- integration time step (ms)

    Returns:
    np.ndarray of shape (n_steps, 1) -- the voltage trace, where
    n_steps = len(np.arange(0, n_obs + 20 + dt, dt)), i.e. about (n_obs + 20) / dt + 1
    """
    I_duration = n_obs

    # fixed parameters
    tau_max = 6e2   # ms
    Vt = -60.       # mV
    nois_fact = 0.1 # uA/cm2
    E_leak = -70.   # mV
    E_Na = 53       # mV
    E_K = -107      # mV
    C = 1

    tstep = float(dt)

    ####################################
    # Current (I) muA/cm2: a step of amplitude I_input between t_on and t_off,
    # then 10 ms without input until the trace ends at t_on + t_off.
    t_on = 10
    t_off = I_duration + 10
    t = np.arange(0, t_on+t_off+dt, dt)
    I = np.zeros_like(t)
    I[int(np.round(t_on/dt)):int(np.round(t_off/dt))] = I_input

    ####################################
    # kinetics
    def efun(z):
        # z / (exp(z) - 1), using its first-order expansion near z = 0 to avoid 0/0
        if np.abs(z) < 1e-4:
            return 1 - z/2
        else:
            return z / (np.exp(z) - 1)

    def alpha_m(x):
        v1 = x - Vt - 13.
        return 0.32*efun(-0.25*v1)/0.25

    def beta_m(x):
        v1 = x - Vt - 40
        return 0.28*efun(0.2*v1)/0.2

    def alpha_h(x):
        v1 = x - Vt - 17.
        return 0.128*np.exp(-v1/18.)

    def beta_h(x):
        v1 = x - Vt - 40.
        return 4.0/(1 + np.exp(-0.2*v1))

    def alpha_n(x):
        v1 = x - Vt - 15.
        return 0.032*efun(-0.2*v1)/0.2

    def beta_n(x):
        v1 = x - Vt - 10.
        return 0.5*np.exp(-v1/40)

    # steady-states and time constants
    def tau_n(x):
        return 1/(alpha_n(x) + beta_n(x))
    def n_inf(x):
        return alpha_n(x)/(alpha_n(x) + beta_n(x))
    def tau_m(x):
        return 1/(alpha_m(x) + beta_m(x))
    def m_inf(x):
        return alpha_m(x)/(alpha_m(x) + beta_m(x))
    def tau_h(x):
        return 1/(alpha_h(x) + beta_h(x))
    def h_inf(x):
        return alpha_h(x)/(alpha_h(x) + beta_h(x))

    # slow non-inactivating K+
    def p_inf(x):
        v1 = x + 35.
        return 1.0/(1. + np.exp(-0.1*v1))

    def tau_p(x):
        v1 = x + 35.
        return tau_max/(3.3*np.exp(0.05*v1) + np.exp(-0.05*v1))

    ####################################
    # simulation from initial point (gating variables start at steady state for V0)
    V = np.zeros_like(t) # voltage
    n = np.zeros_like(t)
    m = np.zeros_like(t)
    h = np.zeros_like(t)
    p = np.zeros_like(t)

    V[0] = float(V0)
    n[0] = n_inf(V[0])
    m[0] = m_inf(V[0])
    h[0] = h_inf(V[0])
    p[0] = p_inf(V[0])

    for i in range(1, t.shape[0]):
        tau_V_inv = ( (m[i-1]**3)*gbar_Na*h[i-1]+(n[i-1]**4)*gbar_K+g_l+gbar_M*p[i-1] )/C
        # the noise term is scaled by 1/sqrt(dt) so its effect does not vanish as dt shrinks
        V_inf = ( (m[i-1]**3)*gbar_Na*h[i-1]*E_Na+(n[i-1]**4)*gbar_K*E_K+g_l*E_leak+gbar_M*p[i-1]*E_K
                +I[i-1]+nois_fact*np.random.randn()/(tstep**0.5) )/(tau_V_inv*C)
        V[i] = V_inf + (V[i-1]-V_inf)*np.exp(-tstep*tau_V_inv)
        n[i] = n_inf(V[i])+(n[i-1]-n_inf(V[i]))*np.exp(-tstep/tau_n(V[i]))
        m[i] = m_inf(V[i])+(m[i-1]-m_inf(V[i]))*np.exp(-tstep/tau_m(V[i]))
        h[i] = h_inf(V[i])+(h[i-1]-h_inf(V[i]))*np.exp(-tstep/tau_h(V[i]))
        p[i] = p_inf(V[i])+(p[i-1]-p_inf(V[i]))*np.exp(-tstep/tau_p(V[i]))

    return np.expand_dims(V, -1)


@njit
def forward_model1(params, n_obs, V0=-70, I_input=3, dt=0.2):
    """
    HH-2pars simulator: params = [gbar_Na, gbar_K]; g_l = 0.1 and gbar_M = 0.07 are fixed.

    n_obs is the duration (ms) of the injected current, not a sample count.
    Returns the voltage trace with shape (n_steps, 1).
    """
    gbar_Na, gbar_K = params
    return _simulate_hh(0.1, gbar_Na, gbar_K, 0.07, n_obs, V0, I_input, dt)


@njit
def forward_model2(params, n_obs, V0=-70, I_input=3, dt=0.2):
    """
    HH-3pars simulator: params = [gbar_Na, gbar_K, gbar_M]; g_l = 0.1 is fixed.

    n_obs is the duration (ms) of the injected current, not a sample count.
    Returns the voltage trace with shape (n_steps, 1).
    """
    gbar_Na, gbar_K, gbar_M = params
    return _simulate_hh(0.1, gbar_Na, gbar_K, gbar_M, n_obs, V0, I_input, dt)


@njit
def forward_model3(params, n_obs, V0=-70, I_input=3, dt=0.2):
    """
    HH-4pars simulator: params = [gbar_l, gbar_Na, gbar_K, gbar_M]; nothing is fixed.

    n_obs is the duration (ms) of the injected current, not a sample count.
    Returns the voltage trace with shape (n_steps, 1).
    """
    g_l, gbar_Na, gbar_K, gbar_M = params
    return _simulate_hh(g_l, gbar_Na, gbar_K, gbar_M, n_obs, V0, I_input, dt)

# An example amortized model comparison

## Prior predictive checks

In [None]:
# TODO: add prior predictive checks (e.g., simulate and plot a few voltage traces from each model).

## Train an amortized estimator

In [None]:
class SequenceNet(tf.keras.Model):
    """
    Custom summary network. A stack of 1D convolutions and a stack of LSTMs
    are applied in parallel to the same input, and their outputs are concatenated.
    """

    def __init__(self, n_filters=64, kernel_size=3, strides=3, n_conv_layers=3,
                 lstm_units=(32, 64), activation='elu', **kwargs):
        """
        Creates the summary network. The defaults reproduce the original architecture.

        Arguments:
        n_filters     : int -- number of filters in each Conv1D layer
        kernel_size   : int -- Conv1D kernel size
        strides       : int -- Conv1D stride (each conv layer shortens the sequence
                        roughly `strides`-fold)
        n_conv_layers : int -- number of stacked Conv1D layers
        lstm_units    : non-empty sequence of int -- units of the stacked LSTM
                        layers; all but the last return full sequences
        activation    : str -- activation of the Conv1D layers
        **kwargs      : forwarded to tf.keras.Model (e.g. name)

        Raises:
        ValueError if lstm_units is empty
        """
        super().__init__(**kwargs)
        if len(lstm_units) == 0:
            raise ValueError("lstm_units must contain at least one layer size")

        conv_layers = [
            tf.keras.layers.Conv1D(n_filters, kernel_size, strides, activation=activation)
            for _ in range(n_conv_layers)
        ]
        # Global average pooling collapses the time axis into one vector per sequence.
        self.conv_part = tf.keras.Sequential(conv_layers + [tf.keras.layers.GlobalAveragePooling1D()])

        # Intermediate LSTMs must return sequences so the next LSTM receives a
        # time series; the last one returns only its final state.
        lstm_layers = [LSTM(units, return_sequences=True) for units in lstm_units[:-1]]
        lstm_layers.append(LSTM(lstm_units[-1]))
        self.lstm_part = Sequential(lstm_layers)

    def call(self, x):
        """
        Performs a forward pass.

        x is assumed to be (batch, time, features); presumably (batch, n_steps, 1)
        as produced by the forward models (TODO: confirm against the trainer).
        Returns a tensor of shape (batch, n_filters + lstm_units[-1]).
        """
        conv_out = self.conv_part(x)
        lstm_out = self.lstm_part(x)
        out = tf.concat((conv_out, lstm_out), axis=-1)
        return out

In [None]:
summary_net = SequenceNet()

# Settings of the evidential network head. 'n_models' must agree with the
# number of priors/simulators defined below (three HH variants).
evidential_meta = dict(
    n_models=3,
    out_activation='softplus',
    n_dense=3,
    dense_args=dict(kernel_initializer='glorot_uniform', activation='relu', units=128),
)

In [None]:
# Combine the evidential-head settings with the custom summary network; the
# resulting network is the one handed to the trainer below.
evidential_net = EvidentialNetwork(evidential_meta, summary_net)

In [None]:
# Pair each parameter prior with the simulator it parameterizes, so the two
# lists cannot drift out of order (presumably index j is model j from model_prior).
priors, simulators = map(list, zip(
    (model1_params_prior, forward_model1),
    (model2_params_prior, forward_model2),
    (model3_params_prior, forward_model3),
))
# The number of candidate models is hard-coded in several places; fail fast if they disagree.
assert len(priors) == len(simulators) == evidential_meta['n_models'], \
    "number of priors/simulators must match evidential_meta['n_models']"

In [None]:
# lambd=0 presumably switches off the regularization term of log_loss;
# TODO confirm against bayesflow.losses.log_loss.
trainer = MultiModelTrainer(
    loss=partial(log_loss, lambd=0),
    network=evidential_net,
    model_prior=model_prior, priors=priors, simulators=simulators,
)

### Online training
A short demo run (10 epochs of 500 iterations, batch size 32) meant to be fast rather than to fully train the network.

In [28]:
%%time
losses = trainer.train_online(epochs=10, iterations_per_epoch=500, batch_size=32, n_obs=300)

Training epoch 1:   0%|          | 0/500 [00:00<?, ?it/s]

Training epoch 2:   0%|          | 0/500 [00:00<?, ?it/s]

Training epoch 3:   0%|          | 0/500 [00:00<?, ?it/s]

Training epoch 4:   0%|          | 0/500 [00:00<?, ?it/s]

Training epoch 5:   0%|          | 0/500 [00:00<?, ?it/s]

Training epoch 6:   0%|          | 0/500 [00:00<?, ?it/s]

Training epoch 7:   0%|          | 0/500 [00:00<?, ?it/s]

Training epoch 8:   0%|          | 0/500 [00:00<?, ?it/s]

Training epoch 9:   0%|          | 0/500 [00:00<?, ?it/s]

Training epoch 10:   0%|          | 0/500 [00:00<?, ?it/s]

Wall time: 11min 44s


### Offline training

In [30]:
# TODO

### Round-based training

In [29]:
# TODO

## Performance and calibration checks
Calibration