In [13]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import glob
import pickle
from scipy.linalg import hankel
from scipy.stats import dirichlet
import statsmodels.api as sm
import matplotx

# Use matplotx's dark "aura" style for every matplotlib figure drawn after this point.
plt.style.use(matplotx.styles.aura["dark"])

In [2]:
def trial_xticks(ax, xlocs, yloc=-0.04, periods=("S", "Cue", "Delay", "Arm", "Reward")):
    """Mark trial-epoch boundaries on the x axis and label the span between them.

    Tick labels are cleared and replaced by long tick marks at ``xlocs``. Each
    interval ``[xlocs[i], xlocs[i + 1]]`` gets the text ``periods[i]`` centred
    under it.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to decorate (modified in place).
    xlocs : sequence of float
        Epoch boundaries in data coordinates. Must contain at least
        ``len(periods) + 1`` entries; extra boundaries get a tick but no label.
    yloc : float, optional
        Vertical position of the labels in axes coordinates (negative places
        them below the axis).
    periods : sequence of str, optional
        Labels for consecutive intervals, in order.

    Raises
    ------
    ValueError
        If ``xlocs`` has fewer than ``len(periods) + 1`` boundaries.
    """
    xlocs = np.array(xlocs)
    if xlocs.size < len(periods) + 1:
        raise ValueError(
            f"need at least {len(periods) + 1} boundaries for {len(periods)} periods, "
            f"got {xlocs.size}"
        )

    ax.set_xticks(xlocs)
    ax.set_xticklabels([])
    # Long, unlabelled ticks act as separators between epochs.
    ax.tick_params("x", length=17, width=1, which="major")
    ax.set_xlabel("Time", labelpad=10)

    # Blended transform: x in data coordinates, y in axes coordinates.
    transform = ax.get_xaxis_transform()
    midpoints = xlocs[:-1] + np.diff(xlocs) / 2
    for label, xloc in zip(periods, midpoints):
        ax.text(
            xloc,
            yloc,
            label,
            fontsize=10,
            horizontalalignment="center",
            verticalalignment="top",
            transform=transform,
            rotation=0,
        )


def nll(w, X, y):
    """L2-penalized negative log-likelihood of a Poisson GLM with exponential link.

    Returns ``-y' (X w) + sum(exp(X w)) + 0.5 * w' w``. This is the Poisson
    negative log-likelihood without the constant ``sum(log y!)`` term, plus a
    unit-variance Gaussian-prior (ridge) penalty on the weights. Lower is
    better, so the value can be handed directly to a minimizer.

    Parameters
    ----------
    w : ndarray
        Weight vector, e.g. shape (n_features,).
    X : ndarray
        Design matrix, shape (n_samples, n_features).
    y : ndarray
        Observed counts with n_samples entries (1-D or a column vector).
    """
    # Linear predictor = log rate. Using it directly instead of log(exp(.))
    # avoids log(0) = -inf when exp underflows.
    eta = X @ w
    # The prior enters the *negative* log-posterior with a plus sign; subtracting
    # it would reward ever-larger weights.
    return -(y.T @ eta) + np.exp(eta).sum() + 0.5 * w.T @ w


# Two matplotlib colours, indexed by position.
# NOTE(review): not used in the visible cells; presumably one per choice/condition — confirm.
colors = list(("tab:red", "tab:blue"))

In [3]:
# Load one recording session: a dict whose keys are printed below; `spikes`
# holds one entry per neuron.
# NOTE: pickle.load can execute arbitrary code — only open trusted files.
with open("test_data_acc_ind_492_0607.pickle", "rb") as fh:
    data = pickle.load(fh)

n_neurons = len(data["spikes"])
print(data.keys())
print(f"n_neurons: {n_neurons}")

dict_keys(['nCues_RminusL', 'currMaze', 'laserON', 'trialStart', 'trialEnd', 'keyFrames', 'time', 'cueOnset_L', 'cueOnset_R', 'choice', 'trialType', 'spikes', 'timeSqueezedFR'])
n_neurons: 324


In [9]:
# constructing design matrix with all trials and all neurons
# Each neuron gets one design matrix stacking every selected trial (currMaze > 7).
# Columns are [right-cue lags | left-cue lags | spike-history lags | bias], each lag
# block `filt_len` wide. y holds the matching binned spike counts.
trial_indices = np.nonzero(data["currMaze"] > 7)[0]
print(f"number of trials: {trial_indices.size}")
filt_len = 30  # number of lagged bins per filter
bin_size = 0.35  # bin width, in the same time units as trialStart/trialEnd


def _lagged_matrix(binned, filt_len, shift=0):
    """Return a (len(binned), filt_len) matrix of lagged copies of ``binned``.

    Row ``t`` is ``binned[t - shift - filt_len + 1 : t - shift + 1]`` (oldest bin
    first), zero-filled where the window reaches before bin 0. ``shift=0``
    includes the current bin (stimulus filters). ``shift=1`` uses strictly past
    bins only (spike history), so a bin never predicts itself.
    """
    binned = np.asarray(binned)
    n_bins = binned.size
    if n_bins == 0:
        # Trial shorter than one bin: keep the column count consistent for vstack.
        return np.zeros((0, filt_len))
    padded = np.pad(binned[: n_bins - shift], (filt_len - 1 + shift, 0))
    # Slice by n_bins rather than `[: -filt_len + 1]`, which is empty when filt_len == 1.
    return hankel(padded[:n_bins], padded[n_bins - 1 :])


# Cue regressors depend only on the trial, so build them once rather than per neuron.
trial_designs = []  # (trial_start, trial_end, bin edges, [right | left] cue lags)
for trial_idx in trial_indices:
    trial_start = data["trialStart"][trial_idx]
    trial_end = data["trialEnd"][trial_idx]
    bins = np.arange(0, trial_end - trial_start, bin_size)
    # NOTE(review): cue onsets are binned as-is, while spikes are shifted by
    # trial_start below. This assumes cueOnset_L/R are already trial-relative — confirm.
    binned_stimr, _ = np.histogram(data["cueOnset_R"][trial_idx], bins)
    binned_stiml, _ = np.histogram(data["cueOnset_L"][trial_idx], bins)
    # An accumulated-evidence regressor could be added with
    # _lagged_matrix(np.cumsum(binned_stimr) - np.cumsum(binned_stiml), filt_len);
    # it is not currently part of the design.
    X_stim = np.hstack(
        (_lagged_matrix(binned_stimr, filt_len), _lagged_matrix(binned_stiml, filt_len))
    )
    trial_designs.append((trial_start, trial_end, bins, X_stim))

X_all = []
y_all = []
for neuron in range(n_neurons):
    spikes = data["spikes"][neuron]
    X = []
    y = []
    for trial_start, trial_end, bins, X_stim in trial_designs:
        in_trial = spikes[(spikes > trial_start) & (spikes < trial_end)]
        binned_spikes, _ = np.histogram(in_trial - trial_start, bins)
        X_sp = _lagged_matrix(binned_spikes, filt_len, shift=1)
        X.append(np.hstack((X_stim, X_sp, np.ones((X_sp.shape[0], 1)))))
        y.append(binned_spikes[:, np.newaxis])

    X_all.append(np.vstack(X))
    y_all.append(np.vstack(y))

X_all = np.array(X_all)
y_all = np.array(y_all)
assert X_all.shape[:2] == y_all.shape[:2], "design matrix and counts are misaligned"

number of trials: 233


In [12]:
# initial glm with no states to estimate weights
# One Poisson GLM per neuron; its weights seed the state-dependent model below.
w_initial = np.empty((n_neurons, X_all.shape[2]))
not_converged = []
for neuron in range(n_neurons):
    glm = sm.GLM(endog=y_all[neuron], exog=X_all[neuron], family=sm.families.Poisson())

    # statsmodels' GLM.fit names the iteration cap `maxiter`. `max_iter` is not
    # one of its arguments, so the intended 1000-iteration cap was not applied.
    res = glm.fit(maxiter=1000, tol=1e-6, tol_criterion="params")
    if not res.converged:
        not_converged.append(neuron)

    w_initial[neuron] = res.params

if not_converged:
    print(f"GLM did not converge for {len(not_converged)} neuron(s): {not_converged}")

In [20]:
# Initialise a state-dependent GLM: per-state weights, a transition matrix and
# forward/backward message arrays. The EM loop at the bottom is still a stub.
n_states = 3
n_trials = trial_indices.size
seed = 0  # fixes the random initialisation so reruns are reproducible
rng = np.random.default_rng(seed)

# Every state starts from the stateless GLM weights, then gets jittered so states can diverge.
w_states = np.ones((n_states, n_neurons, X_all.shape[2])) * w_initial
w_states += rng.normal(0, 0.2, size=w_states.shape)

# Dirichlet concentration with extra mass on entry 0. Rolling it by the row index
# moves that mass onto the diagonal, biasing each state toward self-transition.
t_init = np.ones(n_states)
t_init[0] = 5
# scipy.stats.dirichlet(a) returns a frozen distribution object, not a sample,
# so draw each row explicitly.
T = np.array([rng.dirichlet(np.roll(t_init, n)) for n in range(n_states)])
assert np.allclose(T.sum(axis=1), 1), "transition rows must sum to 1"

# NOTE(review): alpha/betas are indexed per trial, while the design matrix has one
# row per time bin. Confirm which time axis the HMM is meant to run over.
alpha = np.empty((n_states, n_trials))
alpha[:, 0] = 1 / n_states  # uniform initial state distribution
betas = np.empty(alpha.shape)
betas[:, -1] = 1

for i in range(1):
    # TODO: expectation step (forward-backward), then maximisation step.
    pass