## Hodgkin-Huxley + Bayesian Inference with Synthetic Data

In [1]:
import torch
import numpy as np
import sys

sys.path.append("pyHH/src")

from pyhh import HHModel, Simulation

#  Hodgkin-Huxley Model Simulation
def run_HH_model(parameters):
    """Run one Hodgkin-Huxley simulation and return its voltage trace.

    Args:
        parameters: Two-element sequence unpacked as
            ``(V_rest, stim_amplitude)``.

    Returns:
        The ``Vm`` attribute of the finished ``Simulation`` object
        (the membrane potential trace).
    """
    resting_potential, amplitude = parameters

    # Stimulus is zero everywhere except a square pulse of height
    # `amplitude` on samples 7000..12999 of a 20000-sample waveform.
    waveform = np.zeros(20000)
    waveform[7000:13000] = amplitude

    neuron = HHModel()
    # NOTE(review): V_rest is assigned as a plain attribute after the model
    # has been constructed; confirm that HHModel actually reads it during
    # the simulation, otherwise this parameter has no effect on the trace.
    neuron.V_rest = resting_potential

    # Step size is passed in milliseconds (per the keyword name).
    runner = Simulation(neuron)
    runner.Run(stimulusWaveform=waveform, stepSizeMs=0.01)
    return runner.Vm



# Calculate Summary Statistics from Membrane Potential
def calculate_summary_statistics(membrane_potential):
    """Reduce a voltage trace to three summary statistics.

    Args:
        membrane_potential: Array-like voltage trace.

    Returns:
        A float32 tensor of length 3 holding, in order, the maximum,
        the mean and the standard deviation of the trace.
    """
    peak = np.max(membrane_potential)
    average = np.mean(membrane_potential)
    variability = np.std(membrane_potential)
    return torch.tensor([peak, average, variability], dtype=torch.float32)


# Wrapping Function for Simulation Compatibility with Torch
def simulation_wrapper_torch(params):
    """Simulate one parameter set and return its summary statistics.

    Args:
        params: Parameter pair for ``run_HH_model``. A ``torch.Tensor`` is
            detached, moved to the CPU and converted to a NumPy array first;
            any other input is passed through unchanged.

    Returns:
        The tensor produced by ``calculate_summary_statistics`` for the
        simulated voltage trace.
    """
    simulator_input = (
        params.detach().cpu().numpy()
        if isinstance(params, torch.Tensor)
        else params
    )
    return calculate_summary_statistics(run_HH_model(simulator_input))

## 🎲 Define the Prior

A uniform prior is defined over the two parameters: `V_rest` is constrained to the range [40, 60] and `stim_amplitude` to the range [3, 7].



In [2]:
from sbi import utils as utils
# Uniform box prior over (V_rest, stim_amplitude); each tensor lists the
# bound for V_rest first and the bound for the stimulus amplitude second.
lower_bounds = torch.tensor([40.0, 3.0])
upper_bounds = torch.tensor([60.0, 7.0])
prior = utils.BoxUniform(low=lower_bounds, high=upper_bounds)

## 🔁 Inference Setup and Simulations to Train the Posterior Model

Inference uses Sequential Neural Posterior Estimation (SNPE), a simulation-based inference method, to learn the posterior distribution over the parameters from simulated data. We draw 500 parameter sets from the prior, run one simulation per set, and use the resulting (parameter, summary-statistic) pairs to train the posterior.

In [3]:
from sbi.inference import SNPE
from sbi.utils import BoxUniform

# SNPE inference object that will be trained on simulations drawn from the prior.
inference = SNPE(prior=prior)

# Number of (parameter, summary-statistic) training pairs to simulate.
num_simulations = 500

# Draw all parameter sets in one batched call instead of one call per
# simulation; each row of theta is one (V_rest, stim_amplitude) pair.
theta = prior.sample((num_simulations,))

# Simulate every parameter set and stack the summary vectors so that
# row i of x corresponds to row i of theta.
x = torch.stack([simulation_wrapper_torch(params) for params in theta])

# Register the simulations, train the density estimator, and build the
# posterior object that later cells sample from.
inference = inference.append_simulations(theta, x)
density_estimator = inference.train()
posterior = inference.build_posterior(density_estimator)

simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 20000 time points...
simulation complete
simulating 2

## 🎯 Inference on Synthetic Data

In [4]:
import matplotlib.pyplot as plt
import seaborn as sns


def infer_and_plot_interactive(V_rest=50.0, stim_amp=5.0):
    """Simulate an observation for the given parameters and plot the posterior.

    A voltage trace is simulated at ``(V_rest, stim_amp)``, reduced to
    summary statistics, and used to condition the module-level ``posterior``.
    The resulting samples are drawn as a 2-D density with the generating
    parameters marked by a red star.

    Args:
        V_rest: First simulator parameter used to generate the observation.
        stim_amp: Stimulus amplitude used to generate the observation.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    ground_truth = torch.tensor([V_rest, stim_amp])

    # The "observed" data is a simulation at the chosen parameters.
    observed_stats = calculate_summary_statistics(
        run_HH_model(ground_truth.numpy())
    )

    # 1000 posterior draws conditioned on the observation; column 0 is
    # plotted as V_rest and column 1 as the stimulus amplitude.
    draws = posterior.sample((1000,), x=observed_stats).numpy()
    v_rest_draws, stim_draws = draws[:, 0], draws[:, 1]
    sns.kdeplot(x=v_rest_draws, y=stim_draws, fill=True, cmap='Blues')

    # Overlay the generating parameters for visual comparison.
    plt.plot(
        ground_truth[0],
        ground_truth[1],
        'r*',
        label='True Parameters',
        markersize=12,
    )
    plt.xlabel("V_rest")
    plt.ylabel("Stim Amplitude")
    plt.title("Posterior Distribution")
    plt.legend()
    plt.show()

## 🎛️ Try It Yourself!

How it works:
- The sliders let you change the values of `V_rest` (resting membrane potential) and `stim_amplitude` (stimulus amplitude).
- The notebook then generates synthetic membrane voltage data from those parameters and computes summary statistics.
- Using these statistics, Bayesian inference is performed to sample the posterior distribution of the parameters.
- The true parameters are displayed on the plot as a red star for comparison.

In [5]:
from ipywidgets import FloatSlider, interact

# One slider per simulator parameter.
v_rest_slider = FloatSlider(min=40, max=60, step=1, value=50)     # resting potential
stim_amp_slider = FloatSlider(min=3, max=7, step=0.5, value=5)    # stimulus amplitude

# Re-run inference and redraw the plot whenever a slider moves. The trailing
# semicolon keeps the notebook from echoing the call's return value.
interact(
    infer_and_plot_interactive,
    V_rest=v_rest_slider,
    stim_amp=stim_amp_slider,
);

interactive(children=(FloatSlider(value=50.0, description='V_rest', max=60.0, min=40.0, step=1.0), FloatSlider…