Example script for loading the sample data in `Data/example.mat` and fitting a conductance-based encoding model (CBEM).

K. W. Latimer, F. Rieke, & J. W. Pillow (2019)
[Inferring synaptic inputs from spikes with a conductance-based neural encoding model](https://elifesciences.org/articles/47012) eLife 8 (2019): e47012.



In [None]:
import pyCBEM.RGC_CBEM as cbem
import numpy as np
import h5py
import matplotlib.pyplot as plt
import jax.scipy.optimize as jso

In [None]:
# Read the stimulus and the binned spike train out of the MATLAB file.
# The file is opened with h5py, i.e. it is an HDF5-based (v7.3) .mat file.
# NOTE(review): assumes "X" and "Y" are stored as row vectors (shape (1, T));
# row 0 of each dataset is taken — confirm against the data file.
filename = "Data/example.mat"
with h5py.File(filename, "r") as f:
    stimulus = f["X"][0]   # first row of the stimulus dataset
    spkVector = f["Y"][0]  # first row of the spike-train dataset

# Indices of the bins that contain at least one spike.
spkTimes_bins = np.nonzero(spkVector > 0)[0]

In [None]:
# Only bins 12000 through 69999 (range end is exclusive) are used for fitting.
window = range(12000, 70000)

# Construct the model; its argument is the bin width in milliseconds (0.1 ms),
# then attach the stimulus, spike times and fitting window to it.
myCBEM = cbem.CBEM_basic(0.1)
myCBEM.setObservations(stimulus, spkTimes_bins, window)

In [None]:
# Fit the model by minimizing its penalized negative log-likelihood with BFGS.

# Start from random parameters, flattened into a single vector.
myCBEM.randomizeParameters()
B_init = myCBEM.vectorizeParameters()

# The objective is the model's own vectorized penalized negative
# log-likelihood; jax.scipy.optimize.minimize differentiates it with JAX.
ff = myCBEM.vectorizedPenalizedNegLogLike
results = jso.minimize(ff, B_init, method="BFGS")

# results.fun already holds the objective evaluated at results.x, so only the
# starting value needs an extra evaluation.
fun_init = ff(B_init)
fun_final = results.fun
print("initial penalized negative log likelihood: " + str(fun_init))
print("final penalized negative log likelihood: " + str(fun_final))

# jax's minimize does not raise when BFGS fails; surface it instead of
# silently continuing with unconverged parameters.
if not results.success:
    print("warning: BFGS did not converge (status " + str(results.status) + ")")

# Store the fitted parameters back in the CBEM object.
myCBEM.setParametersFromVector(results.x)

In [None]:
# Conductance filters: one curve per conductance index (0 -> "k_e",
# 1 -> "k_i"), drawn against the lag of each basis row, i.e. row j sits at
# (j + 1) * bin size in ms.
tts = np.arange(1, myCBEM.basis_conductance.shape[0] + 1) * myCBEM.binSize_ms
plt.plot(tts, np.zeros(tts.size), "k:")  # dotted zero reference line
for condIdx, condName in enumerate(("k_e", "k_i")):
    plt.plot(tts, myCBEM.getConductanceFilter(condIdx), label=condName)
plt.xlabel("time (ms)")
plt.ylabel("weight")
plt.title("conductance filters")
plt.legend()
plt.show()

# Spike-history filter, on the time axis of its own basis.
tts = np.arange(1, myCBEM.basis_hspk.shape[0] + 1) * myCBEM.binSize_ms
plt.plot(tts, np.zeros(tts.size), "k:")  # dotted zero reference line
plt.plot(tts, myCBEM.getSpikeHistoryFilter(), label="h_spk")
plt.xlabel("time post-spike (ms)")
plt.ylabel("weight")
plt.legend()
plt.title("spike history filter")
plt.show()

In [None]:
# Overlay the fitted spike rate on the observed spike train for the first
# stretch of the fitting window.
spikeRate = myCBEM.getSpikeRate()

# Clamp to the data actually available so a shorter window cannot cause an
# x/y length mismatch in plt.plot.
pltLength_bins = min(10000, len(myCBEM.Y), len(spikeRate))
t_ms = np.arange(pltLength_bins) * myCBEM.binSize_ms

# Labels are plain strings (one line per call) and a legend is drawn so they
# are actually shown.
plt.plot(t_ms, myCBEM.Y[0:pltLength_bins], "k", label="sps", linewidth=0.1)
# Divided by 1e3 to match the sp/ms axis label.
# NOTE(review): assumes getSpikeRate returns spikes per second — confirm.
plt.plot(t_ms, spikeRate[0:pltLength_bins] / 1e3, label="sr")
plt.xlabel("time (ms)")
plt.ylabel("spike rate (sp/ms)")
plt.legend()
plt.show()

In [None]:
# Plot the fitted conductances with the observed spike train overlaid.
gs = myCBEM.getConductances()
pltLength_bins = 10000
timeAxis_ms = np.arange(pltLength_bins) * myCBEM.binSize_ms

# The 0/1 spike train is scaled by 50 so it is visible next to the conductances.
spikesScaled = myCBEM.Y[0:pltLength_bins] * 50
plt.plot(timeAxis_ms, spikesScaled, "k", linewidth=0.1)

# gs is 2-D: rows are time bins (sliced like Y), columns are the two
# conductances, each drawn as its own labelled line.
plt.plot(timeAxis_ms, gs[0:pltLength_bins, :], label=["g_e", "g_i"])
plt.xlabel('time (ms)')
plt.ylabel('conductance')
plt.legend()
plt.show()

In [None]:
# Simulate spike trains from the fitted model using the current observations
# (call myCBEM.setObservations again first to simulate from a test stimulus).
# Warning: simulateSpikeTrains is very slow; it has not been profiled or
# optimized (e.g. by trying NumPy instead of JAX).
#
# Y_init is the initial segment of each simulated train, one column per train.
# For demo purposes it contains no spikes, and it is deliberately long so that
# only a short segment after it has to be simulated.
nInitBins, nTrains = 55000, 2
Y_init = np.zeros((nInitBins, nTrains))

# Y_all holds the simulated trains as a T x N array of ones and zeros.
Y_all = myCBEM.simulateSpikeTrains(Y_init)
plt.plot(Y_all, linewidth=0.1)
plt.show()