In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import glob
import pickle
from scipy.linalg import hankel
from scipy.stats import dirichlet, poisson
from scipy.special import factorial, logsumexp
import statsmodels.api as sm
import matplotx
from multiprocessing import Pool
from time import time
from tqdm import tqdm
import os
import torch
from torch import nn
from torch.utils.data import DataLoader

plt.style.use(matplotx.styles.aura["dark"])

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
with open("test_data_acc_ind_492_0607.pickle", "rb") as handle:
    data = pickle.load(handle)
print(data.keys())
total_neurons = len(data["spikes"])
print(f"n_neurons: {total_neurons}")
print("n_trials: {}".format(data["choice"].size))

dict_keys(['nCues_RminusL', 'currMaze', 'laserON', 'trialStart', 'trialEnd', 'keyFrames', 'time', 'cueOnset_L', 'cueOnset_R', 'choice', 'trialType', 'spikes', 'timeSqueezedFR'])
n_neurons: 324
n_trials: 290


In [3]:
# constructing design matrix with all trials and all neurons
np.random.seed(0)
X_all = []
y_all = []
trial_indices = np.nonzero(data["currMaze"] > 0)[0]
print(f"number of trials: {trial_indices.size}")
filt_len = 30
sp_filt_len = 20
bin_size = 0.35
n_neurons = 20
neuron_idx = np.random.randint(0, total_neurons, size=n_neurons)
# neuron_idx = np.array([11, 12])
trial_id = (
    []
)  # this will keep track of each trial in the design matrix (since each trial spans multiple rows)
for i, neuron in enumerate(neuron_idx):
    X = []
    y = []

    for trial_idx in trial_indices:
        trial_start = data["trialStart"][trial_idx]
        trial_end = data["trialEnd"][trial_idx]
        trial_length = trial_end - trial_start

        keyframes = data["keyFrames"][trial_idx]
        keyframe_times = data["time"][trial_idx][keyframes.astype(int)].tolist()
        lcue_times = data["cueOnset_L"][trial_idx]
        rcue_times = data["cueOnset_R"][trial_idx]

        spikes = data["spikes"][neuron]
        spikes = spikes[(spikes > trial_start) * (spikes < trial_end)]

        bins = np.arange(0, trial_length, bin_size)
        bin_centers = np.convolve(bins, [0.5, 0.5], mode="valid")

        binned_stimr, _ = np.histogram(rcue_times, bins)
        binned_stiml, _ = np.histogram(lcue_times, bins)
        binned_ev = np.cumsum(binned_stimr) - np.cumsum(binned_stiml)
        binned_spikes, _ = np.histogram(spikes - trial_start, bins)

        padded_stimr = np.pad(binned_stimr, (filt_len - 1, 0), constant_values=(0, 0))
        X_sr = hankel(padded_stimr[: -filt_len + 1], padded_stimr[-filt_len:])
        padded_stiml = np.pad(binned_stiml, (filt_len - 1, 0), constant_values=(0, 0))
        X_sl = hankel(padded_stiml[: -filt_len + 1], padded_stiml[-filt_len:])
        padded_ev = np.pad(binned_ev, (filt_len - 1, 0), constant_values=(0, 0))
        X_ev = hankel(padded_ev[: -filt_len + 1], padded_stiml[-filt_len:])

        padded_spikes = np.pad(
            binned_spikes[:-1], (sp_filt_len, 0), constant_values=(0, 0)
        )
        X_sp = hankel(padded_spikes[: -sp_filt_len + 1], padded_spikes[-sp_filt_len:])

        X.append(np.hstack((X_sr, X_sl, np.ones((X_sp.shape[0], 1)))))
        y.append(binned_spikes[:, np.newaxis])
        if i == 0:
            trial_id.append(trial_idx * np.ones(X_sr.shape[0]))

    X_all.append(np.vstack(X))
    y_all.append(np.vstack(y))

trial_id = np.concatenate(trial_id)
X_all = np.array(X_all)
y_all = np.array(y_all)
print(y_all.shape, X_all.shape)
neuron_idx

number of trials: 290
(20, 13238, 1) (20, 13238, 61)


array([172,  47, 117, 192, 323, 251, 195,   9, 211, 277, 242, 292,  87,
        70,  88, 314, 193,  39,  87, 174])

In [20]:
class NeuralNetwork(nn.Module):
    def __init__(self, n_features, n_neurons):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(n_features, 50),
            nn.ReLU(),
            nn.Linear(50, 50),
            nn.ReLU(),
            nn.Linear(50, n_neurons),
            nn.ReLU(),
        )

    def forward(self, x):
        # print(x)
        # x = self.flatten(x)
        out = self.linear_relu_stack(x)
        return out

    def train(self, x, y, epochs=1000):
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-3)
        loss_fn = torch.nn.MSELoss()

        for i in range(epochs):
            pred = self.forward(x)
            loss = loss_fn(pred, y)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if i % 10 == 0:
                print(f"epoch {i} loss = {loss.item()}")


X = np.vstack(X)
model = NeuralNetwork(X.shape[1], y_all.shape[0])
model.train(
    torch.tensor(X, dtype=torch.float32),
    torch.tensor(y_all.squeeze().T, dtype=torch.float32),
)

epoch 0 loss = 10.522619247436523
epoch 10 loss = 10.513094902038574
epoch 20 loss = 10.503520011901855
epoch 30 loss = 10.493901252746582
epoch 40 loss = 10.484238624572754
epoch 50 loss = 10.47453784942627
epoch 60 loss = 10.464800834655762
epoch 70 loss = 10.455025672912598
epoch 80 loss = 10.445210456848145
epoch 90 loss = 10.435358047485352
epoch 100 loss = 10.425472259521484
epoch 110 loss = 10.415550231933594
epoch 120 loss = 10.40558910369873
epoch 130 loss = 10.395593643188477
epoch 140 loss = 10.385565757751465
epoch 150 loss = 10.375508308410645
epoch 160 loss = 10.3654146194458
epoch 170 loss = 10.35529613494873
epoch 180 loss = 10.345147132873535
epoch 190 loss = 10.334970474243164
epoch 200 loss = 10.324767112731934
epoch 210 loss = 10.31454086303711
epoch 220 loss = 10.304288864135742
epoch 230 loss = 10.294007301330566
epoch 240 loss = 10.28369426727295
epoch 250 loss = 10.27335262298584
epoch 260 loss = 10.262979507446289
epoch 270 loss = 10.252564430236816
epoch 280 l

In [16]:
torch.tensor(X)

tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.]], dtype=torch.float64)