# Inference on Hodgkin-Huxley model: tutorial

In this tutorial, we use `sbi` to do inference on a Hodgkin-Huxley model (Hodgkin and Huxley, 1952) with two parameters ($\bar g_{Na}$,$\bar g_K$), given a current-clamp recording (synthetically generated).

First we are going to import basic packages.

In [None]:
# GENERAL
import inspect
from typing import List, Dict, Optional, Union, Callable, Tuple
from warnings import warn

# MATH
import numpy as np
import torch

# VISUALIZATION
import matplotlib as mpl
import matplotlib.pyplot as plt
from cycler import cycler

# sbi
import sbi
import sbi.utils as utils
from sbi.inference.snpe.snpe_c import SnpeC

%load_ext autoreload
%autoreload 2

## Different required components

Before running inference, let us define the different required components:

1. observed data
2. model
3. summary statistics
4. prior over model parameters

## Observed data

Let us assume we current-clamped a neuron and recorded the following voltage trace:


<img src="https://raw.githubusercontent.com/mackelab/delfi/master/docs/docs/tutorials/observed_voltage_trace.png" width="480">
<br>

In fact, this voltage trace was not measured experimentally but synthetically generated by simulating a Hodgkin-Huxley model with particular parameters ($\bar g_{Na}$,$\bar g_K$). We will come back to this point later in the tutorial.

## Model

We would like to infer the posterior over the two parameters ($\bar g_{Na}$,$\bar g_K$) of a Hodgkin-Huxley model, given the observed electrophysiological recording above. The model has channel kinetics as in Pospischil et al. 2008, and is defined by the following set of differential equations (parameters of interest highlighted in orange):

$$
\scriptsize
\begin{align}
C_m\frac{dV}{dt}&=g_{\text{l}}\left(E_{\text{l}}-V\right)+
                    \color{orange}{\bar{g}_{Na}}m^3h\left(E_{Na}-V\right)+
                    \color{orange}{\bar{g}_{K}}n^4\left(E_K-V\right)+
                    \bar{g}_Mp\left(E_K-V\right)+
                    I_{inj}+
                    \sigma\eta\left(t\right)\\
\frac{dq}{dt}&=\frac{q_\infty\left(V\right)-q}{\tau_q\left(V\right)},\;q\in\{m,h,n,p\}
\end{align}
$$

where $V$ is the membrane potential, $C_m$ is the membrane capacitance, $g_{\text{l}}$ is the leak conductance, $E_{\text{l}}$ is the leak reversal potential, $\bar{g}_c$ is the maximal conductance (per unit membrane area) of channels of type $c$ ($\text{Na}^+$, $\text{K}^+$, M), $E_c$ is the reversal potential of $c$, ($m$, $h$, $n$, $p$) are the respective channel gating variables, and $\sigma \eta(t)$ is the intrinsic neural noise. The right-hand side of the voltage dynamics is composed of a leak current, a voltage-dependent $\text{Na}^+$ current, a delayed-rectifier $\text{K}^+$ current, a slow voltage-dependent $\text{K}^+$ current (M-current) responsible for spike-frequency adaptation, and an injected current $I_{\text{inj}}$. Each gating variable $q$ relaxes towards its voltage-dependent steady state $q_{\infty}(V)$ with time constant $\tau_{q}(V)$, so its dynamics are fully determined by the membrane potential $V$ (details in Pospischil et al. 2008).
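In the simulator, each gating variable is specified through voltage-dependent opening and closing rates $\alpha_q(V)$ and $\beta_q(V)$, from which $q_\infty = \alpha_q/(\alpha_q+\beta_q)$ and $\tau_q = 1/(\alpha_q+\beta_q)$. The minimal sketch below copies the delayed-rectifier $n$ rates used in `HHsimulator` (with $V_T = -60$ mV) to show that $n_\infty$ is a sigmoid that increases with depolarization; the voltage grid is an arbitrary choice for illustration.

```python
import numpy as np

Vt = -60.0  # mV, same value as in HHsimulator

def efun(z):
    # z / (exp(z) - 1), replaced by its Taylor expansion near z = 0 to avoid 0/0
    if np.abs(z) < 1e-4:
        return 1 - z / 2
    return z / (np.exp(z) - 1)

def alpha_n(v):
    # opening rate of the delayed-rectifier K+ gate (1/ms)
    return 0.032 * efun(-0.2 * (v - Vt - 15.0)) / 0.2

def beta_n(v):
    # closing rate of the delayed-rectifier K+ gate (1/ms)
    return 0.5 * np.exp(-(v - Vt - 10.0) / 40.0)

def n_inf(v):
    return alpha_n(v) / (alpha_n(v) + beta_n(v))

def tau_n(v):
    return 1.0 / (alpha_n(v) + beta_n(v))

vs = np.linspace(-90.0, 40.0, 14)          # mV, illustrative voltage grid
ninf = np.array([n_inf(v) for v in vs])    # steady-state activation
taus = np.array([tau_n(v) for v in vs])    # time constants (ms)
```

Because $\alpha_n$ grows and $\beta_n$ shrinks with $V$, $n_\infty$ rises monotonically from almost 0 at hyperpolarized potentials to almost 1 at depolarized ones.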

The input current $I_{\text{inj}}$ is a step current, switched on at `t_on` and off at `t_off`, and is defined as follows:

In [None]:
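def syn_current(duration=120, dt=0.01, t_on=10,
                curr_level=5e-4, seed=None):
    """Step current of amplitude curr_level (muA), on from t_on to duration - t_on (ms).

    Returns the current density I (muA/cm2), t_on, t_off, dt, the time axis t (ms)
    and the soma area A_soma (cm2). `seed` is unused: the current is deterministic.
    """
    t_off = duration - t_on
    t = np.arange(0, duration+dt, dt)

    # external current, converted to a density by dividing by the soma area
    A_soma = np.pi*((70.*1e-4)**2)  # cm2
    I = np.zeros_like(t)
    I[int(np.round(t_on/dt)):int(np.round(t_off/dt))] = curr_level/A_soma  # muA/cm2

    return I, t_on, t_off, dt, t, A_soma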

I, t_on, t_off, dt, t, A_soma = syn_current()
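`syn_current` returns a current density rather than a current. As a quick check of the units, the short sketch below repeats the conversion for the default amplitude: 5e-4 µA (0.5 nA) divided by the area term $\pi \cdot (70\,\mu\text{m})^2$ used in the code gives roughly 3.25 µA/cm², and multiplying by `A_soma * 1e3` (as done for the plots) recovers 0.5 nA.

```python
import numpy as np

curr_level = 5e-4                       # uA, default amplitude in syn_current (0.5 nA)
A_soma = np.pi * ((70.0 * 1e-4) ** 2)   # cm^2, same expression as in syn_current
I_density = curr_level / A_soma         # uA/cm^2, the value HHsimulator receives
I_nA = I_density * A_soma * 1e3         # converted back to nA, as in the plots
```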

The Hodgkin-Huxley simulator is given by:

In [None]:
def HHsimulator(V0, params, dt, t, I, seed=None):
    """Simulates the Hodgkin-Huxley model for a specified time duration and current

        Parameters
        ----------
        V0 : float
            Voltage at first time step
        params : np.array, 2d of shape (1, 2)
            Parameter vector (gbar_Na, gbar_K) in mS/cm2
        dt : float
            Timestep
        t : array
            Numpy array with the time steps
        I : array
            Numpy array with the input current
        seed : int, optional
            Seed of the random number generator for the noise term

        Returns
        -------
        np.array, 2d of shape (len(t), 1) with the membrane voltage (mV)
        """

    gbar_Na = float(params[0,0])  # mS/cm2
    gbar_K = float(params[0,1])   # mS/cm2

    # fixed parameters
    g_leak = 0.1    # mS/cm2
    gbar_M = 0.07   # mS/cm2
    tau_max = 6e2   # ms
    Vt = -60.       # mV
    nois_fact = 0.1 # uA/cm2
    E_leak = -70.   # mV
    C = 1.          # uF/cm2
    E_Na = 53       # mV
    E_K = -107      # mV

    tstep = float(dt)

    if seed is not None:
        rng = np.random.RandomState(seed=seed)
    else:
        rng = np.random.RandomState()

    ####################################
    # kinetics
    def efun(z):
        if np.abs(z) < 1e-4:
            return 1 - z/2
        else:
            return z / (np.exp(z) - 1)

    def alpha_m(x):
        v1 = x - Vt - 13.
        return 0.32*efun(-0.25*v1)/0.25

    def beta_m(x):
        v1 = x - Vt - 40
        return 0.28*efun(0.2*v1)/0.2

    def alpha_h(x):
        v1 = x - Vt - 17.
        return 0.128*np.exp(-v1/18.)

    def beta_h(x):
        v1 = x - Vt - 40.
        return 4.0/(1 + np.exp(-0.2*v1))

    def alpha_n(x):
        v1 = x - Vt - 15.
        return 0.032*efun(-0.2*v1)/0.2

    def beta_n(x):
        v1 = x - Vt - 10.
        return 0.5*np.exp(-v1/40)

    # steady-states and time constants
    def tau_n(x):
        return 1/(alpha_n(x) + beta_n(x))
    def n_inf(x):
        return alpha_n(x)/(alpha_n(x) + beta_n(x))
    def tau_m(x):
        return 1/(alpha_m(x) + beta_m(x))
    def m_inf(x):
        return alpha_m(x)/(alpha_m(x) + beta_m(x))
    def tau_h(x):
        return 1/(alpha_h(x) + beta_h(x))
    def h_inf(x):
        return alpha_h(x)/(alpha_h(x) + beta_h(x))

    # slow non-inactivating K+
    def p_inf(x):
        v1 = x + 35.
        return 1.0/(1. + np.exp(-0.1*v1))

    def tau_p(x):
        v1 = x + 35.
        return tau_max/(3.3*np.exp(0.05*v1) + np.exp(-0.05*v1))


    ####################################
    # simulation from initial point
    V = np.zeros_like(t) # voltage
    n = np.zeros_like(t)
    m = np.zeros_like(t)
    h = np.zeros_like(t)
    p = np.zeros_like(t)

    V[0] = float(V0)
    n[0] = n_inf(V[0])
    m[0] = m_inf(V[0])
    h[0] = h_inf(V[0])
    p[0] = p_inf(V[0])

    for i in range(1, t.shape[0]):
        tau_V_inv = ( (m[i-1]**3)*gbar_Na*h[i-1]+(n[i-1]**4)*gbar_K+g_leak+gbar_M*p[i-1] )/C
        V_inf = ( (m[i-1]**3)*gbar_Na*h[i-1]*E_Na+(n[i-1]**4)*gbar_K*E_K+g_leak*E_leak+gbar_M*p[i-1]*E_K
                +I[i-1]+nois_fact*rng.randn()/(tstep**0.5) )/(tau_V_inv*C)
        V[i] = V_inf + (V[i-1]-V_inf)*np.exp(-tstep*tau_V_inv)
        n[i] = n_inf(V[i])+(n[i-1]-n_inf(V[i]))*np.exp(-tstep/tau_n(V[i]))
        m[i] = m_inf(V[i])+(m[i-1]-m_inf(V[i]))*np.exp(-tstep/tau_m(V[i]))
        h[i] = h_inf(V[i])+(h[i-1]-h_inf(V[i]))*np.exp(-tstep/tau_h(V[i]))
        p[i] = p_inf(V[i])+(p[i-1]-p_inf(V[i]))*np.exp(-tstep/tau_p(V[i]))

    return np.array(V).reshape(-1,1)
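Within each time step, the simulator holds the gating variables fixed, which makes the voltage equation linear in $V$; it then jumps to the exact solution of that linear equation, $V_i = V_\infty + (V_{i-1} - V_\infty)\,e^{-\Delta t/\tau_V}$ (an exponential-Euler scheme), and updates the gating variables in the same way. The sketch below assumes a leak-only membrane with constant input, so that an analytic solution exists, and illustrates why this scheme is attractive: the update reproduces the exact trajectory even with a time step of 0.5 ms.

```python
import numpy as np

def exp_euler_passive(V0, E_leak, g_leak, C, I, dt, n_steps):
    """Exponential-Euler integration of C dV/dt = g_leak (E_leak - V) + I."""
    tau_inv = g_leak / C                                # 1 / membrane time constant
    V_inf = (g_leak * E_leak + I) / (tau_inv * C)       # steady-state voltage
    V = np.empty(n_steps + 1)
    V[0] = V0
    for i in range(1, n_steps + 1):
        # exact relaxation towards V_inf over one step of length dt
        V[i] = V_inf + (V[i - 1] - V_inf) * np.exp(-dt * tau_inv)
    return V

V0, E_leak, g_leak, C, I = -70.0, -70.0, 0.1, 1.0, 3.25   # mV, mV, mS/cm2, uF/cm2, uA/cm2
dt, n_steps = 0.5, 200                                     # ms
V = exp_euler_passive(V0, E_leak, g_leak, C, I, dt, n_steps)

# analytic solution of the same linear ODE
t = np.arange(n_steps + 1) * dt
V_inf = E_leak + I / g_leak
V_exact = V_inf + (V0 - V_inf) * np.exp(-t * g_leak / C)
```

In the full model the conductances change with the gating variables, so the scheme is no longer exact, but it stays stable for the stiff sodium dynamics at the 0.01 ms step used here.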

Putting the input current and the simulator together:

In [None]:
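def run_HH_model(params):
    """Simulate the HH model for one parameter vector (gbar_Na, gbar_K) with the step current."""

    # input current, time step
    I, t_on, t_off, dt, t, A_soma = syn_current()

    # time axis with exactly one entry per sample of I
    t = np.arange(0, len(I), 1)*dt

    # initial voltage (mV)
    V0 = -70

    params = np.asarray(params)
    assert params.ndim == 1, 'params.ndim must be 1'

    # fixed noise seed: the same parameters always yield the same trace
    states = HHsimulator(V0, params.reshape(1, -1), dt, t, I, seed=0)

    return {'data': states.reshape(-1),
            'time': t,
            'dt': dt,
            'I': I.reshape(-1)}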

To get an idea of the output of the Hodgkin-Huxley model, let us generate some voltage traces for different parameters ($\bar g_{Na}$,$\bar g_K$), given the input current $I_{\text{inj}}$:

In [None]:
params = np.array([[50., 1.],[4., 1.5],[20., 15.]])

num_samples = len(params[:,0])
sim_samples = np.zeros((num_samples,len(I)))
for i in range(num_samples):
    sim_samples[i,:] = run_HH_model(params=params[i,:])['data']

In [None]:
# colors for traces
col_min = 2
num_colors = num_samples+col_min
cm1 = mpl.cm.Blues
col1 = [cm1(1.*i/num_colors) for i in range(col_min,num_colors)]

# plotting
%matplotlib inline

fig = plt.figure(figsize=(7,5))
gs = mpl.gridspec.GridSpec(2, 1, height_ratios=[4, 1])
ax = plt.subplot(gs[0])
for i in range(num_samples):
    plt.plot(t,sim_samples[i,:],color=col1[i],lw=2)
plt.ylabel('voltage (mV)')
ax.set_xticks([])
ax.set_yticks([-80, -20, 40])

ax = plt.subplot(gs[1])
plt.plot(t,I*A_soma*1e3,'k', lw=2)
plt.xlabel('time (ms)')
plt.ylabel('input (nA)')

ax.set_xticks([0, max(t)/2, max(t)])
ax.set_yticks([0, 1.1*np.max(I*A_soma*1e3)])
ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.2f'))

As can be seen, the voltage traces can be quite diverse for different parameter values. How can we infer the parameters ($\bar g_{Na}$,$\bar g_K$) underlying the particular observed voltage trace? `sbi` will allow us to solve this problem.

`sbi` accepts any Python callable as a simulator. Thus, `sbi` can also use simulators built on external packages, e.g., Brian (http://briansimulator.org/), NEST (https://www.nest-simulator.org/), or NEURON (https://neuron.yale.edu/neuron/). External simulators do not even need to be Python-based, as long as they store their outputs in a format that can be read from Python. All that may be necessary is to wrap your external simulator of choice into a Python callable that takes batches of parameters as input and returns batches of simulation results.
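The wrapping step described above can be as simple as looping a per-parameter simulator over the rows of a parameter batch. In the sketch below, `toy_simulator` and `batch_wrapper` are hypothetical names used only for illustration; for an external simulator, the body of `toy_simulator` would instead launch the simulation and read its results back into Python.

```python
import numpy as np

def toy_simulator(theta):
    """Hypothetical stand-in simulator: maps one (g_Na, g_K) pair to 3 features."""
    g_na, g_k = theta
    return np.array([g_na + g_k, g_na - g_k, g_na * g_k])

def batch_wrapper(simulator, thetas):
    """Map a per-parameter simulator over a batch of parameter vectors.

    thetas has shape (batch, n_params); the result has shape (batch, n_features).
    """
    thetas = np.atleast_2d(np.asarray(thetas, dtype=float))
    return np.stack([simulator(theta) for theta in thetas])

x = batch_wrapper(toy_simulator, [[50.0, 5.0], [20.0, 15.0]])
```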

## Summary statistics
Often, we are not interested in matching the exact trace, but only certain features thereof. For the Hodgkin-Huxley model, we use seven summary features: the number of spikes, the mean resting potential, the standard deviation of the resting potential, and the first four moments of the voltage during stimulation (mean, variance, skewness and kurtosis). In the function `calculate_summary_statistics()` below, we compute these statistics from the output of the Hodgkin-Huxley simulator.
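The moment features are obtained by dividing the central moments of order 2, 3 and 4 by 1, $\sigma^3$ and $\sigma^4$. The sketch below checks, on a synthetic skewed sample (a stand-in for a voltage trace, not simulator output), that this construction yields the variance, the skewness and the (non-excess) kurtosis as computed by `numpy` and `scipy.stats`.

```python
import numpy as np
from scipy import stats as spstats

rng = np.random.default_rng(0)
# skewed synthetic sample standing in for the voltage during stimulation (mV)
v = rng.gamma(shape=2.0, scale=3.0, size=5000) - 65.0

n_mom = 4
std_pw = np.concatenate(([1.0], np.std(v) ** np.arange(3, n_mom + 1)))  # 1, std^3, std^4
moments = spstats.moment(v, np.arange(2, n_mom + 1)) / std_pw            # orders 2, 3, 4

variance, skewness, kurtosis = moments
```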

In [None]:
from scipy import stats as spstats

def calculate_summary_statistics(x):
    """Calculate summary statistics

    Parameters
    ----------
    x : dict
        Output of `run_HH_model`, with keys 'data', 'time', 'dt' and 'I'

    Returns
    -------
    np.array, 1d of length n_summary (7): number of spikes, resting potential,
    its standard deviation, mean voltage, variance, skewness and kurtosis
    """
    I, t_on, t_off, dt, t, A_soma = syn_current()
    
    n_mom = 4
    n_summary = 7
    
    n_summary = np.minimum(n_summary,n_mom + 3)

    t = x['time']
    dt = x['dt']

    # copy of the voltage trace, modified in place for spike detection
    v = np.array(x['data'])

    # put everything to -10 that is below -10 or has negative slope
    ind = np.where(v < -10)
    v[ind] = -10
    ind = np.where(np.diff(v) < 0)
    v[ind] = -10

    # remaining negative slopes are at spike peaks
    ind = np.where(np.diff(v) < 0)
    spike_times = np.array(t)[ind]
    spike_times_stim = spike_times[(spike_times > t_on) & (spike_times < t_off)]

    # number of spikes
    if spike_times_stim.shape[0] > 0:
        spike_times_stim = spike_times_stim[np.append(1, np.diff(spike_times_stim))>0.5]

    # resting potential and std
    rest_pot = np.mean(x['data'][t<t_on])
    rest_pot_std = np.std(x['data'][int(.9*t_on/dt):int(t_on/dt)])

    # moments during stimulation: variance, skewness and kurtosis
    # (central moments of order 2, 3, 4, divided by 1, std**3, std**4)
    std_pw = np.power(np.std(x['data'][(t > t_on) & (t < t_off)]),
                      np.arange(3, n_mom + 1))
    std_pw = np.concatenate((np.ones(1),std_pw))
    moments = spstats.moment(x['data'][(t > t_on) & (t < t_off)],
                             np.arange(2, n_mom + 1))/std_pw

    # concatenation of summary statistics
    sum_stats_vec = np.concatenate((
            np.array([spike_times_stim.shape[0]]),
            np.array([rest_pot,rest_pot_std,np.mean(x['data'][(t > t_on) & (t < t_off)])]),
            moments
        ))
    sum_stats_vec = sum_stats_vec[0:n_summary]

    return sum_stats_vec
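The spike-counting logic above clips the trace at -10 mV, flattens every sample that is followed by a decrease, and then reads off the single remaining negative slope just before each peak. The sketch below extracts that logic into a standalone function (`count_spikes` is a hypothetical helper, not part of the notebook) and applies it to a synthetic trace with three Gaussian "spikes" inside the stimulation window.

```python
import numpy as np

def count_spikes(v, t, t_on, t_off, thresh=-10.0, min_isi=0.5):
    """Hypothetical helper mirroring the spike count in calculate_summary_statistics."""
    v = np.array(v, dtype=float)            # work on a copy of the trace
    t = np.asarray(t)
    v[v < thresh] = thresh                  # clip everything below threshold
    v[np.where(np.diff(v) < 0)] = thresh    # flatten samples followed by a decrease
    peak_idx = np.where(np.diff(v) < 0)     # one negative slope left per peak
    spike_times = t[peak_idx]
    spike_times = spike_times[(spike_times > t_on) & (spike_times < t_off)]
    if spike_times.shape[0] > 0:
        # drop detections closer than min_isi (ms) to the previous one
        spike_times = spike_times[np.append(1, np.diff(spike_times)) > min_isi]
    return spike_times.shape[0]

t = np.arange(0.0, 100.0, 0.1)  # ms
v = -70.0 + sum(100.0 * np.exp(-((t - c) ** 2) / 2.0) for c in (20.0, 40.0, 60.0))
n_spikes = count_spikes(v, t, t_on=10.0, t_off=90.0)
print(n_spikes)  # 3
```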

Lastly, we define a function that performs all of the above steps at once. The function `simulation_wrapper` takes in conductance values, runs the Hodgkin-Huxley model, and returns the summary statistics as a `torch.Tensor`. This function is the only thing you need to provide to SNPE; how the summary statistics are computed from the simulation inside this function does not matter to SNPE.

In [None]:
def simulation_wrapper(params):
    """
    Takes in conductance values and then first runs the Hodgkin Huxley model and then returns the summary statistics as torch.Tensor
    """
    obs = run_HH_model(params)
    summstats = torch.as_tensor(calculate_summary_statistics(obs))
    return summstats

## Prior over model parameters

Now that we have the simulator, we need to define a prior over the model parameters ($\bar g_{Na}$,$\bar g_K$), which in this case is chosen to be a uniform distribution over a box (bounds in mS/cm²):

In [None]:
# bounds of the uniform prior over (gbar_Na, gbar_K), in mS/cm2
prior_min = [.5,1e-4]
prior_max = [80.,15.]
prior = utils.torchutils.BoxUniform(low=torch.as_tensor(prior_min), high=torch.as_tensor(prior_max))
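A box-uniform prior assigns the same density to every point inside the box and zero outside. Assuming `BoxUniform` behaves like independent uniform distributions per dimension, the sketch below builds the same prior from plain `torch.distributions` and checks that its log-density inside the box equals $-\sum_j \log(\text{high}_j - \text{low}_j)$, i.e. the log of one over the box volume.

```python
import torch
from torch.distributions import Independent, Uniform

low = torch.tensor([0.5, 1e-4])    # same bounds as prior_min (mS/cm2)
high = torch.tensor([80.0, 15.0])  # same bounds as prior_max (mS/cm2)

# independent uniform per dimension; Independent(..., 1) sums log-probs over the last axis
box = Independent(Uniform(low, high), 1)

samples = box.sample((1000,))                     # shape (1000, 2)
log_p = box.log_prob(torch.tensor([50.0, 5.0]))   # scalar: point inside the box
expected = -torch.log(high - low).sum()           # log of 1 / box volume
```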

## Inference
Now that we have all the required components, we can run inference with SNPE, using the `SnpeC` class imported at the top of the notebook.

First, for our purposes, we simulate a ground-truth observation.

### Inspect observed data

In [None]:
# true parameters and respective labels
true_params = np.array([50., 5.])       
labels_params = [r'$g_{Na}$', r'$g_{K}$']

# observed data: simulation given true parameters
observation_summary_statistics = simulation_wrapper(true_params)

In [None]:
observation = run_HH_model(true_params)

As already shown above, the observed voltage trace looks as follows:

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline

fig = plt.figure(figsize=(7,5))
gs = mpl.gridspec.GridSpec(2, 1, height_ratios=[4, 1])
ax = plt.subplot(gs[0])
plt.plot(observation['time'],observation['data'])
plt.ylabel('voltage (mV)')
plt.title('observed data')
ax.set_xticks([])
ax.set_yticks([-80, -20, 40])

ax = plt.subplot(gs[1])
plt.plot(observation['time'],I*A_soma*1e3,'k', lw=2)
plt.xlabel('time (ms)')
plt.ylabel('input (nA)')

ax.set_xticks([0, max(observation['time'])/2, max(observation['time'])])
ax.set_yticks([0, 1.1*np.max(I*A_soma*1e3)])
ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.2f'))

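We now want to use SNPE to identify parameters whose simulated activity matches this trace. To do so, we build the density estimator (a masked autoregressive flow, `model="maf"`) and collect the arguments for the SNPE object...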

In [None]:
neural_net = utils.posterior_nn(model="maf", prior=prior, x_o=observation_summary_statistics)

snpe_common_args = dict(
    simulator=simulation_wrapper,
    x_o=observation_summary_statistics,
    density_estimator=neural_net,
    prior=prior,
    simulation_batch_size=1,
)

... and run the inference.

In [None]:
infer = SnpeC(sample_with_mcmc=False, num_pilot_samples=2000, **snpe_common_args)

# Run inference.
num_rounds, num_simulations_per_round = 1, 2000
posterior = infer(
    num_rounds=num_rounds, num_simulations_per_round=num_simulations_per_round,
)

Note that we chose the algorithm `SNPE-C`, but other algorithms are available within `sbi` (see http://www.mackelab.org/sbi). 
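Conceptually, in a single round (as here), SNPE draws parameters from the prior, simulates data for each of them, and trains a conditional density estimator $q(\theta \mid x)$ on the resulting pairs, which is then evaluated at the observation $x_o$. The toy sketch below illustrates this principle under strong simplifying assumptions: a one-dimensional Gaussian prior, a simulator that only adds Gaussian noise, and a linear-Gaussian density estimator fitted by least squares in place of the masked autoregressive flow used above. It is not how `SnpeC` is implemented, but for this conjugate toy problem the exact posterior is known, so the estimate can be checked.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 0.5        # noise level of the toy simulator (assumption)
n_sims = 20_000

# 1. draw parameters from the prior theta ~ N(0, 1)
theta = rng.standard_normal(n_sims)
# 2. simulate data: x = theta + sigma * noise
x = theta + sigma * rng.standard_normal(n_sims)

# 3. fit q(theta | x) = N(a * x + b, s2) by least squares on the simulated pairs
A = np.column_stack([x, np.ones_like(x)])
(a, b), *_ = np.linalg.lstsq(A, theta, rcond=None)
s2 = np.mean((theta - (a * x + b)) ** 2)

# 4. evaluate the learned conditional density at the observation
x_o = 1.0
post_mean, post_var = a * x_o + b, s2

# exact posterior of this conjugate Gaussian toy problem, for comparison
true_mean = x_o / (1 + sigma**2)
true_var = sigma**2 / (1 + sigma**2)
```

Because $q(\theta \mid x)$ is learned for all $x$ at once, it can be evaluated at any new observation without further simulations; a flexible density estimator such as a MAF plays the role of the linear-Gaussian fit when the posterior is not Gaussian.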

## Analysis of the results

After running the inference algorithm, let us inspect and analyse the results. We start by drawing samples from the inferred posterior and plotting their one- and two-dimensional marginals:

In [None]:
samples = posterior.sample(10000)

In [None]:
fig, axes = utils.samples_nd(utils.tensor2numpy(samples),
                       limits=np.asarray([prior_min, prior_max]).T,
                       ticks=np.asarray([prior_min, prior_max]).T,
                       fig_size=(5,5),
                       diag='kde',
                       upper='kde',
                       hist_diag={'bins': 50},
                       hist_offdiag={'bins': 50},
                       kde_diag={'bins': 50, 'color': 'g'},
                       kde_offdiag={'bins': 50},
                       points=[true_params],
                       points_offdiag={'markersize': 5},
                       points_colors='g',
                       title='');

As can be seen, the inferred posterior contains the ground-truth parameters (green) in a high-probability region. Now, let us sample parameters from the posterior distribution, simulate the Hodgkin-Huxley model for each of those samples and compare the simulations with the observed data:
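The visual check above can be complemented by a simple numerical one: for each parameter, test whether the ground-truth value lies inside the central 95% interval of the marginal posterior samples. The sketch below uses synthetic Gaussian samples as a stand-in for `samples`; with the real posterior one would pass `utils.tensor2numpy(samples)` and `true_params` instead. `inside_central_interval` is a hypothetical helper.

```python
import numpy as np

def inside_central_interval(samples, true_params, level=0.95):
    """For each parameter, check whether the true value lies inside the
    central `level` interval of the marginal posterior samples."""
    alpha = (1.0 - level) / 2.0
    lo = np.quantile(samples, alpha, axis=0)
    hi = np.quantile(samples, 1.0 - alpha, axis=0)
    true_params = np.asarray(true_params)
    return (true_params >= lo) & (true_params <= hi)

rng = np.random.default_rng(1)
# synthetic stand-in for posterior samples of (g_Na, g_K)
fake_samples = rng.normal(loc=[50.0, 5.0], scale=[3.0, 0.5], size=(10_000, 2))
result = inside_central_interval(fake_samples, [50.0, 5.0])
```

A single such check is only a sanity test; calibration of the posterior as a whole would require repeating it over many simulated ground truths.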

In [None]:
y_obs = observation['data']
t = observation['time']
duration = np.max(t)

num_samp = 2

# sample from posterior
x_samp = posterior.sample(num_samples=num_samp)
x_samp = x_samp.detach().numpy()

In [None]:
fig = plt.figure(figsize=(7,5))

# reject samples for which prior is zero
ind = (x_samp > prior_min) & (x_samp < prior_max)
params = x_samp[np.prod(ind,axis=1)==1]

num_samp = len(params[:,0])

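# colors for the posterior-sample traces (shades of blue)
cm_samp = mpl.cm.Blues
col_samp = [cm_samp((i + 2.) / (num_samp + 2.)) for i in range(num_samp)]

# simulate and plot samples
V = np.zeros((len(t),num_samp))
for i in range(num_samp):
    x = run_HH_model(params[i,:])
    V[:,i] = x['data']
    plt.plot(t, V[:, i], color=col_samp[i], lw=2, label='sample '+str(num_samp-i))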

# plot observation
plt.plot(t, y_obs, '--',lw=2, label='observation')
plt.xlabel('time (ms)')
plt.ylabel('voltage (mV)')

ax = plt.gca()
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], bbox_to_anchor=(1.3, 1), loc='upper right')

ax.set_xticks([0, duration/2, duration])
ax.set_yticks([-80, -20, 40]);

As can be seen, the samples from the inferred posterior lead to simulations that closely resemble the observed data, indicating that `SNPE-C` identified parameters consistent with the observation.

## References


A. L. Hodgkin and A. F. Huxley. A quantitative description of membrane current and its application to conduction and excitation in nerve. The Journal of Physiology, 117(4):500–544, 1952.

M. Pospischil, M. Toledo-Rodriguez, C. Monier, Z. Piwkowska, T. Bal, Y. Frégnac, H. Markram, and A. Destexhe. Minimal Hodgkin-Huxley type models for different classes of cortical and thalamic neurons. Biological Cybernetics, 99(4-5), 2008.

This example, including the notebook itself as well as the modules HH_simulator.py, HH_stimuli.py, and HH_statistics.py, is based on the quickstart [Hodgkin-Huxley example in delfi, the predecessor of sbi](https://github.com/mackelab/delfi/blob/master/docs/docs/tutorials/quickstart.ipynb).