## Why Hodgkin–Huxley and Bayesian Inference?

This notebook shows how to use **Bayesian inference** to estimate parameters of a **Hodgkin–Huxley (HH)** neuron model.

The HH model describes how neurons generate electrical impulses (action potentials) through the dynamics of their ion channels. Its equations are biologically detailed and nonlinear, so **small changes in parameters** (like resting potential or stimulus amplitude) can produce **very different voltage traces**, for example a spike versus no spike at all.

However, in real experiments we rarely know the exact parameter values; we only observe the **voltage recordings**. So we turn to **Bayesian inference** to estimate the *probable values* of the hidden parameters given the observed data.

We’ll use **simulation-based inference (SBI)**, which lets us:
- Simulate artificial (“synthetic”) data from our model,
- Train a neural network to approximate the posterior distribution $p(\theta \mid x)$,
- And then infer parameters for new, unseen data.

This approach is especially useful when:
- The model is **complex or non-linear**,
- The **likelihood function** is hard to write down analytically,
- But we can easily **simulate** data from the model.
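The core idea of inferring parameters using only a simulator can be illustrated without any neural network. The sketch below uses rejection sampling (rejection ABC), a simpler and much less efficient relative of the neural method used in this notebook: draw parameters from the prior, simulate, and keep only the draws whose simulated output lands close to the observation. The toy simulator `toy_simulator` (a noisy linear map), the Uniform(0, 10) prior and the tolerance `eps` are hypothetical choices for illustration only.

```python
import numpy as np


def toy_simulator(theta, rng):
    """Hypothetical stand-in simulator: output = 2 * theta + Gaussian noise."""
    return 2.0 * theta + rng.normal(0.0, 0.5, size=theta.shape)


def rejection_abc(x_obs, n_draws, eps, rng):
    """Approximate p(theta | x_obs) without ever evaluating a likelihood.

    Draw theta from a Uniform(0, 10) prior, simulate once per draw, and keep
    the draws whose simulated output lies within eps of the observation.
    """
    theta = rng.uniform(0.0, 10.0, size=n_draws)  # samples from the prior
    x_sim = toy_simulator(theta, rng)             # one simulation per draw
    accepted = np.abs(x_sim - x_obs) < eps        # compare outputs, not likelihoods
    return theta[accepted]


rng = np.random.default_rng(0)
posterior_samples = rejection_abc(x_obs=8.0, n_draws=100_000, eps=0.1, rng=rng)
print(len(posterior_samples), posterior_samples.mean(), posterior_samples.std())
```

Most draws are rejected, and with an expensive simulator like the HH model this quickly becomes impractical. That is the motivation for training a neural network on a fixed budget of simulations instead.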

<br> 

---


## Simulating the Hodgkin–Huxley Model
We start by defining a function to run the Hodgkin–Huxley (HH) model.  
The HH model mathematically describes how the membrane potential of a neuron evolves over time due to ion channel dynamics.

Here, the key parameters are:

- `V_rest`: The resting membrane potential (baseline voltage, mV)
- `stim_amplitude`: The amplitude of the external current stimulus (μA/cm²)

This function will serve as our **data generator**: it takes parameters and returns a simulated voltage trace.
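pyHH handles the HH equations internally. To show what a data generator of this kind computes, here is a minimal from-scratch sketch of the classic Hodgkin–Huxley equations integrated with forward Euler. It is not pyHH's implementation: it assumes the original squid-axon constants and the 1952 convention in which voltages are measured relative to rest (so rest sits near 0 mV), and the function name `simulate_hh` is hypothetical. The pulse amplitude of 10 μA/cm² is above this notebook's prior range and is chosen only so the sketch spikes clearly.

```python
import numpy as np

# Classic HH constants (mV relative to rest; conductances in mS/cm²)
E_NA, E_K, E_L = 115.0, -12.0, 10.6
G_NA, G_K, G_L = 120.0, 36.0, 0.3
C_M = 1.0  # membrane capacitance (μF/cm²)


def vtrap(x, y):
    """x / (exp(x / y) - 1), with the removable singularity at x = 0 handled."""
    if abs(x / y) < 1e-6:
        return y * (1.0 - x / (2.0 * y))
    return x / np.expm1(x / y)


def rates(v):
    """Opening (alpha) and closing (beta) rates of the n, m, h gates, in 1/ms."""
    a_n, b_n = 0.01 * vtrap(10.0 - v, 10.0), 0.125 * np.exp(-v / 80.0)
    a_m, b_m = 0.1 * vtrap(25.0 - v, 10.0), 4.0 * np.exp(-v / 18.0)
    a_h, b_h = 0.07 * np.exp(-v / 20.0), 1.0 / (np.exp((30.0 - v) / 10.0) + 1.0)
    return (a_n, b_n), (a_m, b_m), (a_h, b_h)


def simulate_hh(stim, dt=0.01):
    """Integrate the HH equations with forward Euler; stim holds μA/cm² per step."""
    v = 0.0
    # Start each gate at its steady-state value for the resting voltage
    n, m, h = (a / (a + b) for a, b in rates(v))
    trace = np.empty(len(stim))
    for i, i_stim in enumerate(stim):
        (a_n, b_n), (a_m, b_m), (a_h, b_h) = rates(v)
        i_na = G_NA * m**3 * h * (v - E_NA)  # sodium current
        i_k = G_K * n**4 * (v - E_K)         # potassium current
        i_l = G_L * (v - E_L)                # leak current
        v += dt * (i_stim - i_na - i_k - i_l) / C_M
        n += dt * (a_n * (1.0 - n) - b_n * n)
        m += dt * (a_m * (1.0 - m) - b_m * m)
        h += dt * (a_h * (1.0 - h) - b_h * h)
        trace[i] = v
    return trace


# Same protocol as the notebook: 200 ms at 0.01 ms steps, pulse from 70 ms to 130 ms
stim = np.zeros(20000)
stim[7000:13000] = 10.0
vm = simulate_hh(stim)
```

Before the pulse `vm` should stay close to 0 mV, and during the pulse it should show action potentials peaking far above rest. Absolute voltages will differ from pyHH's output if pyHH uses another voltage convention.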


```python
import sys

import numpy as np
import torch

# pyHH is expected to be available at ./pyHH
sys.path.append("pyHH/src")

from pyhh import HHModel, Simulation


# Hodgkin-Huxley model simulation
def run_HH_model(parameters):
    """Simulate the HH model and return the membrane potential trace."""

    # 1. Parameter assignment and units
    V_rest, stim_amplitude = parameters
    # V_rest: resting potential (mV)
    # stim_amplitude: amplitude of the applied current stimulus (μA/cm²)

    model = HHModel()
    # NOTE: this assignment only matters if HHModel actually reads a `V_rest`
    # attribute during the simulation. Check the pyHH source: if it does not,
    # V_rest never influences the trace and cannot be inferred from it.
    model.V_rest = V_rest

    # 2. Stimulus waveform: a 200 ms trace (20000 points at 0.01 ms per step)
    stim = np.zeros(20000)

    # Current pulse from t = 70 ms to t = 130 ms (60 ms duration),
    # chosen to reliably trigger an action potential.
    stim[7000:13000] = stim_amplitude

    # 3. Run the simulation
    sim = Simulation(model)
    # stepSizeMs=0.01 is a 10 μs time step, small enough for numerical stability.
    sim.Run(stimulusWaveform=stim, stepSizeMs=0.01)

    return sim.Vm  # membrane potential trace, one value per time step


# Summary statistics of the membrane potential
def calculate_summary_statistics(membrane_potential):
    """
    Reduce the detailed voltage trace to a low-dimensional vector of summary statistics.
    SNPE learns the posterior conditional on these statistics instead of the raw trace,
    so they must capture the features that carry information about the parameters.
    """
    return torch.tensor([
        np.max(membrane_potential),   # peak potential (action potential height, mV)
        np.mean(membrane_potential),  # average potential (often near the resting potential, mV)
        np.std(membrane_potential),   # variability (proxy for the size of the spike/response, mV)
    ], dtype=torch.float32)


# Wrapper so the simulator accepts torch tensors (as sampled from the prior)
def simulation_wrapper_torch(params):
    """Run the HH model for one parameter vector and return its summary statistics."""
    if isinstance(params, torch.Tensor):
        params = params.detach().cpu().numpy()
    result = run_HH_model(params)
    return calculate_summary_statistics(result)
```
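Max, mean and standard deviation are coarse: a trace with one spike and a trace with five spikes can have a similar peak. A common refinement, sketched here as a hypothetical extension rather than part of this notebook's pipeline, is to add a spike count, computed as the number of upward crossings of a voltage threshold. The function names and the default `threshold=0.0` are assumptions; the right threshold depends on the model's voltage convention.

```python
import numpy as np


def count_spikes(v, threshold):
    """Count upward crossings of `threshold` (one per action potential)."""
    above = v >= threshold
    # A crossing is a sample at or above threshold whose predecessor was below it
    return int(np.count_nonzero(above[1:] & ~above[:-1]))


def summary_statistics_with_spikes(v, threshold=0.0):
    """Hypothetical extended summary: [max, mean, std, spike count]."""
    return np.array(
        [v.max(), v.mean(), v.std(), count_spikes(v, threshold)], dtype=np.float32
    )


# Synthetic check: a flat -65 mV baseline with three rectangular "spikes"
v = np.full(1000, -65.0)
for start in (100, 400, 700):
    v[start:start + 20] = 30.0
stats = summary_statistics_with_spikes(v, threshold=0.0)
print(stats)
```

To feed this into sbi, the array can be converted with `torch.from_numpy(stats)` in place of the three-element tensor.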

## Define the Prior Distribution

In Bayesian inference, we express our **initial beliefs** about parameters before seeing any data using a *prior distribution*.

We assume each of the two parameters is uniformly distributed within a biologically reasonable range:

- $V_{\text{rest}}$ (resting potential, mV) in $[40, 60]$
- $\text{stim\_amplitude}$ (stimulus amplitude, μA/cm²) in $[3, 7]$

This means we consider all values in that range equally likely before seeing the data.
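A box-uniform prior factorizes into independent uniforms, so its log-density is constant inside the box (minus the log of the box volume) and $-\infty$ outside. The NumPy sketch below spells this out for the ranges above. It illustrates what `BoxUniform` represents but is not sbi's implementation, and `box_uniform_sample` and `box_uniform_log_prob` are hypothetical names.

```python
import numpy as np

LOW = np.array([40.0, 3.0])   # [V_rest min, stim_amplitude min]
HIGH = np.array([60.0, 7.0])  # [V_rest max, stim_amplitude max]


def box_uniform_sample(n, rng):
    """Draw n parameter vectors, each coordinate uniform in [LOW, HIGH)."""
    return rng.uniform(LOW, HIGH, size=(n, len(LOW)))


def box_uniform_log_prob(theta):
    """log p(theta): -log(box volume) inside the box, -inf outside."""
    theta = np.atleast_2d(theta)
    inside = np.all((theta >= LOW) & (theta <= HIGH), axis=1)
    log_volume = np.sum(np.log(HIGH - LOW))  # log(20 * 4) = log(80)
    return np.where(inside, -log_volume, -np.inf)


rng = np.random.default_rng(1)
samples = box_uniform_sample(5, rng)
print(box_uniform_log_prob(samples))        # every entry equals -log(80)
print(box_uniform_log_prob([[65.0, 5.0]]))  # outside the box: -inf
```

Because the density is flat inside the box, the posterior's shape inside the box is set entirely by the likelihood, and the posterior is exactly zero outside the box.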



```python
import torch
from sbi import utils

# Independent uniform prior on each parameter
prior = utils.BoxUniform(
    low=torch.tensor([40.0, 3.0]),   # [V_rest min, stim_amplitude min]
    high=torch.tensor([60.0, 7.0]),  # [V_rest max, stim_amplitude max]
)
```

## Generating Synthetic Data

Now we create **synthetic data**: simulated experiments for which we know the true parameters.

Why synthetic data?
- In many biological systems, it's difficult or impossible to measure the “true” parameter values directly.
- By simulating data under known parameters, we can teach our inference algorithm what different parameter combinations “look like” in terms of their outputs.

Each simulated data point consists of:
- The parameters $\theta = [V_{\text{rest}}, \text{stim\_amplitude}]$
- The summary statistics $x$ computed from the resulting voltage trace

These pairs $(\theta, x)$ form the **training dataset** for the Bayesian inference model.
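Building the training set amounts to "sample from the prior, simulate, stack". The sketch below shows that bookkeeping and the resulting array shapes in NumPy; `toy_summary_simulator` is a hypothetical cheap stand-in for `simulation_wrapper_torch` so the example runs in milliseconds, and `build_training_set` is likewise a hypothetical helper.

```python
import numpy as np


def toy_summary_simulator(theta, rng):
    """Hypothetical stand-in: maps (V_rest, stim_amplitude) to 3 noisy 'summary statistics'."""
    v_rest, amp = theta
    return np.array([v_rest + 10.0 * amp, v_rest, amp]) + rng.normal(0.0, 0.1, size=3)


def build_training_set(simulator, low, high, n, rng):
    """Sample n parameter sets from a box-uniform prior and simulate each one."""
    theta = rng.uniform(low, high, size=(n, len(low)))  # shape (n, num_params)
    x = np.stack([simulator(t, rng) for t in theta])    # shape (n, num_stats)
    return theta, x


rng = np.random.default_rng(2)
theta, x = build_training_set(
    toy_summary_simulator, np.array([40.0, 3.0]), np.array([60.0, 7.0]), 500, rng
)
print(theta.shape, x.shape)  # (500, 2) (500, 3)
```

Row `i` of `theta` is paired with row `i` of `x`; that row-wise pairing is what `append_simulations(theta, x)` expects.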


```python
from sbi.inference import SNPE

# Neural posterior estimation with our box-uniform prior
inference = SNPE(prior=prior)

theta = []
x = []

# Sample 500 parameter sets from the prior and simulate each one
# (pyHH prints a progress message for every simulation)
for _ in range(500):
    param = prior.sample((1,)).squeeze(0)  # shape (2,)
    theta.append(param)
    x.append(simulation_wrapper_torch(param))

# Stack into tensors: theta has shape (500, 2), x has shape (500, 3)
theta = torch.stack(theta)
x = torch.stack(x)

# Train the density estimator and build the posterior
inference = inference.append_simulations(theta, x)
density_estimator = inference.train()
posterior = inference.build_posterior(density_estimator)
```


## Learning the Posterior with Simulation-Based Inference (SNPE)

We use **Sequential Neural Posterior Estimation (SNPE)**, a method that trains a neural network to approximate the *posterior distribution* $p(\theta \mid x)$. In this notebook it runs for a single round on simulations drawn from the prior, so the trained posterior is *amortized*: it can be evaluated for any new observation without retraining.

Here’s the intuition:
1. We simulate many examples of (parameters → data).
2. The network learns the relationship between observed data and parameters.
3. After training, we can feed it new (real or synthetic) data and obtain a full posterior distribution over the parameters that could have generated it, not just a single best estimate.

This approach works even when the likelihood $p(x \mid \theta)$ is intractable, which is common in neuroscience models.
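To make the training objective concrete: neural posterior estimation fits a conditional density $q_\phi(\theta \mid x)$ by maximizing $\sum_i \log q_\phi(\theta_i \mid x_i)$ over the simulated pairs. The sketch below swaps the neural network for the simplest possible conditional density, a Gaussian whose mean is linear in $x$ with a fixed covariance; for that family the maximum-likelihood fit is ordinary least squares plus the residual covariance. It illustrates the objective only, not what `SNPE` does internally, and the toy simulator and function names are hypothetical.

```python
import numpy as np


def fit_linear_gaussian_posterior(theta, x):
    """Maximum-likelihood fit of q(theta | x) = N(A x + b, Sigma) on (theta, x) pairs."""
    X = np.hstack([x, np.ones((len(x), 1))])       # append a bias column
    W, *_ = np.linalg.lstsq(X, theta, rcond=None)  # least squares = ML for the mean
    residuals = theta - X @ W
    sigma = residuals.T @ residuals / len(theta)   # ML estimate of the covariance
    return W, sigma


def sample_posterior(W, sigma, x_obs, n, rng):
    """Draw n samples from the fitted q(theta | x_obs)."""
    mean = np.append(x_obs, 1.0) @ W
    return rng.multivariate_normal(mean, sigma, size=n)


# Hypothetical toy problem: 2 parameters observed with noise, x = theta + noise
rng = np.random.default_rng(3)
theta = rng.uniform([40.0, 3.0], [60.0, 7.0], size=(5000, 2))
x = theta + rng.normal(0.0, 0.5, size=theta.shape)

W, sigma = fit_linear_gaussian_posterior(theta, x)
samples = sample_posterior(W, sigma, x_obs=np.array([50.0, 5.0]), n=2000, rng=rng)
print(samples.mean(axis=0), np.sqrt(np.diag(sigma)))
```

A neural network plays the same role with a far more flexible family, which is what lets SNPE represent skewed, bounded or multimodal posteriors that a single Gaussian cannot.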

```python
import matplotlib.pyplot as plt
import seaborn as sns


def infer_and_plot_interactive(V_rest=50.0, stim_amp=5.0):
    """Simulate synthetic data, infer parameters, and plot the posterior."""
    # Generate synthetic "observed" data with the chosen parameters
    true_params = torch.tensor([V_rest, stim_amp])
    observed_voltage = run_HH_model(true_params.numpy())             # voltage trace
    observed_stats = calculate_summary_statistics(observed_voltage)  # summary statistics

    # Draw 1000 samples from the posterior conditioned on the observed statistics
    samples = posterior.sample((1000,), x=observed_stats)

    # Plot the joint posterior as a 2D kernel density estimate
    samples_np = samples.numpy()
    sns.kdeplot(x=samples_np[:, 0], y=samples_np[:, 1], fill=True, cmap="Blues")

    # Mark the true parameters for reference (plain floats, not tensors)
    plt.plot(V_rest, stim_amp, "r*", label="True Parameters", markersize=12)
    plt.xlabel("V_rest (mV)")
    plt.ylabel("Stim Amplitude (μA/cm²)")
    plt.title("Posterior Distribution")
    plt.legend()
    plt.show()
```

## Try It Yourself!

How it works:
- The sliders let you change `V_rest` (resting membrane potential) and `stim_amp` (stimulus amplitude).
- The function then simulates a synthetic membrane voltage trace with those parameters and computes its summary statistics.
- Using these statistics, the trained posterior is sampled to obtain plausible parameter values.
- The true parameters are displayed on the plot as a red star for comparison.

```python
from ipywidgets import FloatSlider, interact

# Each slider change re-runs the simulation and the posterior sampling
interact(
    infer_and_plot_interactive,
    V_rest=FloatSlider(min=40, max=60, step=1, value=50),   # resting potential (prior range)
    stim_amp=FloatSlider(min=3, max=7, step=0.5, value=5),  # stimulus amplitude (prior range)
);
```


## Interpreting the Results
After training, we can infer the posterior distribution of the parameters for any observed data.

Each posterior sample represents one plausible combination of $(V_{\text{rest}}, \text{stim\_amplitude})$ that could have produced the observed membrane voltage trace.

By examining the **posterior mean**, **variance**, and **shape**, we can:
- Quantify uncertainty about parameter estimates,
- Identify which parameters the data constrain well (narrow marginals) and which remain poorly identified (broad marginals),
- And gain insight into how stimulation and resting potential interact: an elongated, tilted posterior indicates that a change in one can be compensated by the other.
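Given an array of posterior samples (for example `posterior.sample((1000,), x=observed_stats).numpy()`), these quantities take a few lines of NumPy. The helper `summarize_posterior` below is a hypothetical name; it reports the mean, standard deviation and 95% central credible interval of each parameter plus the correlation between the two, and is checked here on synthetic Gaussian samples rather than on real inference output.

```python
import numpy as np


def summarize_posterior(samples, names=("V_rest", "stim_amplitude")):
    """Print and return mean, std, 95% central interval and correlation of 2D samples."""
    mean = samples.mean(axis=0)
    std = samples.std(axis=0)
    lo, hi = np.percentile(samples, [2.5, 97.5], axis=0)  # central 95% credible interval
    corr = np.corrcoef(samples, rowvar=False)[0, 1]       # negative => parameters trade off
    for i, name in enumerate(names):
        print(f"{name}: mean={mean[i]:.2f}, std={std[i]:.2f}, "
              f"95% CI=[{lo[i]:.2f}, {hi[i]:.2f}]")
    print(f"correlation: {corr:.2f}")
    return mean, std, (lo, hi), corr


# Synthetic stand-in for posterior samples with a built-in negative correlation (-0.6)
rng = np.random.default_rng(4)
cov = np.array([[4.0, -0.6], [-0.6, 0.25]])
samples = rng.multivariate_normal([50.0, 5.0], cov, size=5000)
mean, std, (lo, hi), corr = summarize_posterior(samples)
```

A standard deviation that is small relative to the prior width (20 mV for $V_{\text{rest}}$, 4 μA/cm² for the stimulus amplitude) means the data constrain that parameter well; a strong correlation means the two parameters can compensate for each other.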