A Python library for estimating time-varying natural parameters of spike-train interactions using EM-based inference in a state-space log-linear model.
```shell
# Using conda
conda env create -f environment.yml
conda activate ssll

# Or using pip
pip install -r requirements.txt
```

Dependencies: numpy, scipy, matplotlib, tqdm
```python
import numpy
import __init__ as ssll
import synthesis
import transforms

# Generate synthetic data (3 neurons, pairwise interactions)
N, O, T, R = 3, 2, 50, 200
transforms.initialise(N, O)
theta = synthesis.generate_thetas(N, O, T, seed=42)
p = numpy.array([transforms.compute_p(theta[t]) for t in range(T)])
spikes = synthesis.generate_spikes(p, R, seed=42)

# Run EM inference
emd = ssll.run(spikes, order=2, window=1, param_est='exact', param_est_eta='exact')
```

Spike patterns are modelled by an exponential-family distribution over binary vectors $\mathbf{x} \in \{0, 1\}^N$:

$$p(\mathbf{x} \mid \boldsymbol{\theta}) = \exp\left(\sum_{i} \theta_i x_i + \sum_{i<j} \theta_{ij} x_i x_j + \cdots - \psi(\boldsymbol{\theta})\right)$$

where the order parameter O sets the highest interaction order retained in the expansion.

Example (N=3, O=2): The model has D = 6 natural parameters: three first-order ($\theta_1, \theta_2, \theta_3$) and three second-order ($\theta_{12}, \theta_{13}, \theta_{23}$).
- O=1: Independent model — each neuron fires independently (no interactions).
- O=2: Pairwise Ising model — captures pairwise correlations (the most common choice).
- O=3: Adds triplet interactions for higher-order correlations beyond pairwise.
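The dimension D follows directly from the order: one parameter for each subset of 1 to O neurons. A minimal sketch (the helper name is illustrative, not part of the library):

```python
from math import comb

def num_params(N, O):
    """Number of natural parameters D for N neurons at interaction order O:
    one parameter for each neuron subset of size 1 through O."""
    return sum(comb(N, k) for k in range(1, O + 1))

num_params(3, 2)  # three rates + three pairwise couplings = 6
```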
The sufficient statistics $\mathbf{F}(\mathbf{x})$ collect the corresponding monomials ($x_i$, $x_i x_j$, ...) into a D-dimensional feature vector, so the model can be written compactly as $p(\mathbf{x} \mid \boldsymbol{\theta}) = \exp(\boldsymbol{\theta}^\top \mathbf{F}(\mathbf{x}) - \psi(\boldsymbol{\theta}))$.
The natural parameters evolve over time as a state-space model with a linear state equation and a log-linear observation model.

State equation:

$$\boldsymbol{\theta}_t = \mathbf{F} \boldsymbol{\theta}_{t-1} + \boldsymbol{\xi}_t, \qquad \boldsymbol{\xi}_t \sim \mathcal{N}(\mathbf{0}, \mathbf{Q})$$

where $\mathbf{F}$ is the autoregressive parameter and $\mathbf{Q}$ is the state noise covariance.

Observation equation:

$$\mathbf{x}_{t,r} \sim p(\mathbf{x} \mid \boldsymbol{\theta}_t) = \exp\left(\boldsymbol{\theta}_t^\top \mathbf{F}(\mathbf{x}) - \psi(\boldsymbol{\theta}_t)\right), \qquad r = 1, \dots, R$$

where $\mathbf{x}_{t,r}$ is the spike pattern observed on trial $r$ in time bin $t$.
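Under the default random-walk dynamics the state equation is easy to simulate. A sketch using only numpy (D, T, and the noise level are arbitrary illustration values):

```python
import numpy as np

rng = np.random.default_rng(0)
D, T = 6, 50
F = np.eye(D)            # default dynamics: random walk
Q = 0.01 * np.eye(D)     # isotropic state noise covariance

theta = np.zeros((T, D))
for t in range(1, T):
    xi = rng.multivariate_normal(np.zeros(D), Q)   # xi_t ~ N(0, Q)
    theta[t] = F @ theta[t - 1] + xi
```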
The autoregressive parameter F is a D×D matrix controlling how the parameters at time $t$ depend on those at time $t-1$:

- Default ($F = I$): The identity matrix gives a random-walk model: $\boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} + \boldsymbol{\xi}_t$. Each parameter drifts freely from its previous value.
- General F: An autoregressive matrix that can capture mean-reverting dynamics, coupling between parameters, or other structured temporal dependencies.
- Set the initial value of F via the `state_ar` parameter. When `state_ar` is provided, F is optimised during the M-step. When `state_ar=None` (default), F stays fixed at the identity.
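For example, a mean-reverting initialisation shrinks each parameter toward zero at every step; the factor 0.95 below is an arbitrary illustration:

```python
import numpy as np

D = 6
F0 = 0.95 * np.eye(D)   # eigenvalues with magnitude < 1 give mean-reverting dynamics
# passed as the initial value, e.g. ssll.run(spikes, order=2, state_ar=F0);
# F is then refined during the M-step
```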
The state noise covariance Q controls the expected magnitude of parameter changes between timesteps. The `state_cov` parameter sets the initial value of Q in one of four forms:

- Scalar (e.g. `0.01`): Q = 0.01 × I, isotropic noise where all D parameters share one variance. Updated as a single scalar in the M-step. Simplest and usually sufficient.
- Vector (1D array, length D): Q = diag(`state_cov`), so each parameter has its own variance. Updated element-wise in the M-step.
- Matrix (D×D array): Q = `state_cov`, a full covariance that captures correlations between parameter changes. Updated as a full matrix in the M-step. Expensive (D² parameters).
- List of 2 values (e.g. `[0.01, 0.001]`): Q = diag(λ1,...,λ1, λ2,...,λ2), with separate variances for first-order parameters (rates) and higher-order parameters (interactions). λ1 and λ2 are updated separately. Use this when rates and interactions evolve at different timescales.
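The four accepted forms, side by side (the values are arbitrary; D = 6 matches the N=3, O=2 example):

```python
import numpy as np

D = 6
q_scalar = 0.01                  # Q = 0.01 * I, one shared variance
q_vector = 0.01 * np.ones(D)     # Q = diag(q_vector), per-parameter variances
q_matrix = 0.01 * np.eye(D)      # Q = q_matrix, full D x D covariance
q_split  = [0.01, 0.001]         # diag(l1, l1, l1, l2, l2, l2): rates vs. interactions
# any of these can be passed as state_cov, e.g. ssll.run(spikes, order=2, state_cov=q_split)
```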
The EM algorithm alternates between:

- E-step: Recursive Bayesian filter (forward) and smoother (backward) with a Laplace approximation at each timestep. The MAP estimate is found via Newton-Raphson (`'nr'`), conjugate gradient (`'cg'`, default), or BFGS (`'bf'`).
- M-step: Optimise the noise covariance Q and, optionally, the autoregressive parameter F and the input gain matrix U.
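The alternation can be sketched generically; this skeleton is illustrative only (`e_step`, `m_step`, and the convergence rule are assumptions, not the library's internals):

```python
def em_fit(e_step, m_step, init, max_iter=100, tol=1e-4):
    """Generic EM skeleton: alternate posterior inference (E) and
    hyperparameter updates (M) until the marginal likelihood plateaus."""
    params, prev_mll = init, float('-inf')
    history = []
    for _ in range(max_iter):
        posterior, mll = e_step(params)   # e.g. filter + smoother with Laplace approx.
        params = m_step(posterior)        # e.g. update Q (and optionally F, U)
        history.append(mll)
        if mll - prev_mll < tol:
            break
        prev_mll = mll
    return params, history
```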
The model supports stationary (time-independent) analysis by setting T=1. With a single time bin, the state-space machinery reduces to Bayesian inference of a static parameter:

- The initial distribution $\boldsymbol{\theta}_1 \sim \mathcal{N}(\mu, \Sigma)$ serves as the prior, with $\mu$ = `theta_o` and $\Sigma$ = `sigma_o` $\times I$.
- The E-step computes the MAP estimate $\hat{\boldsymbol{\theta}}$ balancing the prior and the observation likelihood. No forward-backward recursion is needed.
- The M-step updates only the prior mean: $\mu \leftarrow \hat{\boldsymbol{\theta}}$. The state noise Q and autoregressive parameter F are not updated (there are no state transitions to estimate them from). The prior covariance $\Sigma$ remains fixed.
The EM iterations converge when the improvement in the log marginal likelihood between iterations falls below a tolerance, or when `max_iter` is reached.

Use `stationary=True` to fit a stationary model. This automatically pools all T×R observations into a single time step and sets Q=0:

```python
emd = ssll.run(spikes, order=2, param_est='exact', param_est_eta='exact', stationary=True)
```

For large N, where exact $2^N$ computation is infeasible, approximate methods are available:
- Pseudolikelihood (`param_est='pseudo'`): Replaces the full likelihood with a product of conditional likelihoods. Scales to N ~ 60 neurons.
- TAP mean-field (`param_est_eta='mf'`): Second-order mean-field (Thouless-Anderson-Palmer) approximation for the expectation parameters and the log partition function.
- Bethe approximation: Based on the Bethe free energy, solved via:
  - Belief propagation (`'bethe_BP'`): Iterative message passing.
  - CCCP (`'bethe_CCCP'`): Concave-convex procedure, guaranteed convergence.
  - Hybrid (`'bethe_hybrid'`): Tries BP first, falls back to CCCP.
When to use which: Use exact methods for N ≤ 12. For larger networks, use `pseudo` + `mf` for speed, or `pseudo` + `bethe_hybrid` for better accuracy.
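This rule of thumb can be captured in a small dispatch helper (illustrative only, not part of the library API):

```python
def choose_methods(n_neurons, prefer_accuracy=False):
    """Pick estimation settings for ssll.run following the rule of thumb:
    exact for small N, pseudolikelihood plus mean-field or Bethe otherwise."""
    if n_neurons <= 12:
        return {'param_est': 'exact', 'param_est_eta': 'exact'}
    eta = 'bethe_hybrid' if prefer_accuracy else 'mf'
    return {'param_est': 'pseudo', 'param_est_eta': eta}

# e.g. ssll.run(spikes, order=2, **choose_methods(20))
```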
After fitting, the model provides time-resolved thermodynamic quantities that characterise the collective state of the population (Donner et al. 2017). These are computed by `energies.py` and stored in the `EMData` container:

- Log partition function $\psi(\boldsymbol{\theta})$: Normalisation constant of the log-linear model. Computed exactly for small N, via the Ogata-Tanemura estimator for large N, or via TAP/Bethe approximations. Stored in `emd.psi` (shape: T×1).
- Entropy: Measures the variability of population spike patterns. `emd.S1` is the entropy of the independent (O=1) model; `emd.S2` is the entropy of the fitted model; `emd.S_ratio = (S1 - S2) / S1` gives the fractional entropy reduction due to interactions.
- Internal energy: Expected value of the energy function $E(\mathbf{x}) = -\boldsymbol{\theta}^\top \mathbf{F}(\mathbf{x})$. `emd.U1` is the internal energy of the independent model; `emd.U2` is that of the fitted model.
- Population spike rate: The first-order expectation parameters `emd.eta_s[:, :N]` give the marginal firing probability of each neuron at each timestep.
- Silence probability: The probability that no neuron fires: $p(\mathbf{x}=\mathbf{0}) = \exp(-\psi(\boldsymbol{\theta}))$, computable from `emd.psi`.
Note: Heat capacity (Donner et al. 2017, Eq. 33) is not yet implemented: it requires an augmented partition function with a temperature parameter $\beta$.
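Given the stored arrays, the derived quantities above are one-liners; the helper names here are illustrative (only `emd.psi`, `emd.S1`, and `emd.S2` come from the library):

```python
import numpy as np

def silence_probability(psi):
    """p(x = 0) = exp(-psi) at each timestep, from the (T, 1) log partition function."""
    return np.exp(-np.asarray(psi))

def entropy_ratio(S1, S2):
    """Fractional entropy reduction (S1 - S2) / S1 due to interactions."""
    S1, S2 = np.asarray(S1), np.asarray(S2)
    return (S1 - S2) / S1
```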
Main entry point. Returns an EMData container with smoothed posterior estimates.
| Parameter | Type | Default | Description |
|---|---|---|---|
| `spikes` | ndarray (T, R, N) | required | Binary spike data |
| `order` | int | 2 | Interaction order (1=rates, 2=pairwise, 3=triplet) |
| `window` | int | 1 | Bin width in ms |
| `map_function` | str | `'cg'` | MAP optimizer: `'nr'`, `'cg'`, or `'bf'` |
| `param_est` | str | `'exact'` | `'exact'` or `'pseudo'` |
| `param_est_eta` | str | `'exact'` | `'exact'`, `'mf'`, `'bethe_BP'`, `'bethe_CCCP'`, or `'bethe_hybrid'` |
| `state_cov` | float/array/list | 0.01 | Initial noise covariance Q: scalar (isotropic), 1D array (diagonal), D×D array (full), or list of 2 values [λ1, λ2] (separate rates/interactions) |
| `state_ar` | ndarray | None | Autoregressive parameter F (D×D); None = identity |
| `max_iter` | int | 100 | Maximum EM iterations |
| `theta_o` | float/array | 0 | Prior mean |
| `sigma_o` | float | 0.1 | Prior covariance scaling |
| `mstep` | bool | True | Whether to run the M-step |
| `stationary` | bool | False | Pool all T×R observations into one time step (Q=0) for stationary analysis |
| `u` | ndarray (T, n_u) | None | Exogenous input; when provided, adds U·u_t to the state equation with U learned via the M-step |
Returns: a `container.EMData` object with fields:

- `theta_s`: smoothed natural parameters (T×D)
- `sigma_s`: smoothed covariances (T×D×D)
- `eta_s`: expectation parameters (T×D)
- `mll`: log marginal likelihood history (per EM iteration)
- `psi`: log partition function (T×1)
- `S1`, `S2`: entropy of the independent and fitted models (T×1)
- `S_ratio`: fractional entropy reduction (S1−S2)/S1 (T×1)
- `U1`, `U2`: internal energy of the independent and fitted models (T×1)
- `U`: learned input gain matrix (D×n_u), or None if no exogenous input
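For instance, the per-iteration `mll` history can be used to check convergence after the fact (the helper name and tolerance are illustrative):

```python
def converged(mll_history, tol=1e-4):
    """True if the last EM iteration improved the log marginal likelihood
    by less than tol (mll_history holds one value per EM iteration)."""
    if len(mll_history) < 2:
        return False
    return abs(mll_history[-1] - mll_history[-2]) < tol
```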
```shell
python example_exact.py    # Exact inference (3 neurons, 2nd-order)
python example_approx.py   # Approximate inference (20 neurons, pseudo-likelihood + TAP/Bethe)
```

The approximate inference path (`param_est='pseudo'` + `param_est_eta='mf'`) is optimized for large networks. Key optimizations include sparse matrix operations (CSR format), vectorized gradient computation via precomputed stacked matrices, direct Fx_s_t difference computation (iterating only over subsets containing the target neuron), and precomputed subset membership lookups. For N=20 pairwise with T=50, R=100, a full EM run completes in ~1.5 s on a modern CPU.
- Shimazaki H, Amari S, Brown EN, Gruen S (2012). State-space analysis of time-varying higher-order spike correlation for multiple neural spike train data. PLoS Computational Biology, 8(3): e1002385.
- Donner C, Obermayer K, Shimazaki H (2017). Approximate inference for time-varying interactions and macroscopic dynamics of neural populations. PLoS Computational Biology, 13(1): e1005309.
GPL-3.0. See LICENSE.
- Hideaki Shimazaki (h.shimazaki@kyoto-u.ac.jp)
- Christian Donner (christian.donner@bccn-berlin.de)
- Thomas Sharp (original code in Python)