In [5]:
import numpy as np
from pathlib import Path

DATA_DIR = Path.cwd().parent / 'data'

sessions = np.array([])
for file in sorted(DATA_DIR.glob("steinmetz_part*.npz")):
    print(f"Loading {file}...")
    session = np.load(DATA_DIR / file, allow_pickle=True)
    sessions = np.hstack((sessions, session['dat']))

with open(DATA_DIR / "selectors.npy", "rb") as f:
    selectors = np.load(f, allow_pickle=True)

Loading /home/jovyan/work/data/steinmetz_part0.npz...
Loading /home/jovyan/work/data/steinmetz_part1.npz...
Loading /home/jovyan/work/data/steinmetz_part2.npz...


In [25]:
for i in range(len(sessions)):
    session = sessions[i]
    sel = selectors[i]
    neuron_count = sum(sel["NEURON_VISUAL"])
    if neuron_count == 0:
        continue
    stim_count = sum(~sel["STIM_RIGHT_NONE"])
    ratio = stim_count / neuron_count
    print(f"Session {i:2d} - {neuron_count:3d} neurons - {stim_count:3d} trials w/ stimulus - {ratio:2.2f} ratio")

Session  0 - 178 neurons - 128 trials w/ stimulus - 0.72 ratio
Session  1 - 533 neurons - 136 trials w/ stimulus - 0.26 ratio
Session  2 - 228 neurons - 119 trials w/ stimulus - 0.52 ratio
Session  3 -  39 neurons - 142 trials w/ stimulus - 3.64 ratio
Session  6 - 101 neurons - 133 trials w/ stimulus - 1.32 ratio
Session  7 -  89 neurons - 148 trials w/ stimulus - 1.66 ratio
Session  8 - 221 neurons - 179 trials w/ stimulus - 0.81 ratio
Session  9 - 204 neurons - 202 trials w/ stimulus - 0.99 ratio
Session 10 - 275 neurons - 169 trials w/ stimulus - 0.61 ratio
Session 11 - 145 neurons - 173 trials w/ stimulus - 1.19 ratio
Session 12 -  66 neurons - 162 trials w/ stimulus - 2.45 ratio
Session 13 -  79 neurons - 140 trials w/ stimulus - 1.77 ratio
Session 16 -  12 neurons - 134 trials w/ stimulus - 11.17 ratio
Session 17 - 158 neurons - 153 trials w/ stimulus - 0.97 ratio
Session 18 - 179 neurons - 130 trials w/ stimulus - 0.73 ratio
Session 19 - 122 neurons - 149 trials w/ stimulus - 1.

In [41]:
DECODER_SESSION = 21
TIMES_VISION = np.arange(45, 100)

sel = selectors[DECODER_SESSION]
spikes = sessions[DECODER_SESSION]['spks'][sel["NEURON_VISUAL"]][:, :, TIMES_VISION]
contrast = sessions[DECODER_SESSION]['contrast_right']

In [55]:
from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
X = spikes.transpose((1, 2, 0)).reshape(spikes.shape[1] * spikes.shape[2], spikes.shape[0])
Y = np.repeat(label_encoder.fit_transform(contrast), len(TIMES_VISION))

In [67]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

clf = LogisticRegression(penalty='l1', solver='saga', max_iter=5000)
scores = cross_val_score(clf, X, Y, cv=5)
print(scores)
scores.mean()

[0.52927928 0.54381654 0.55569206 0.5452498  0.53173628]


0.5411547911547911

In [68]:
clf = LogisticRegression(penalty='l1', solver='saga', max_iter=5000)
clf.fit(X, Y)

LogisticRegression(max_iter=5000, penalty='l1', solver='saga')

In [83]:
for i in range(len(X)):
    trial_num, r = divmod(i, len(TIMES_VISION))
    if trial_num == 10:
        break
        
    correct_class = label_encoder.transform(contrast[[trial_num]])[0]
    if r == 0:
        print("")
        print(f"Trial {trial_num:3d} - Correct Class {correct_class}")

    population_activity = X[[i], :]
    print(f"Class confidence at {(i-1)*10:4d}ms: {clf.predict_proba(population_activity)[0]}")


Trial   0 - Correct Class 3
Class confidence at  -10ms: [0.56098166 0.11313523 0.17236025 0.15352286]
Class confidence at    0ms: [0.56098166 0.11313523 0.17236025 0.15352286]
Class confidence at   10ms: [0.56158368 0.06136979 0.14622264 0.2308239 ]
Class confidence at   20ms: [0.57994057 0.09050675 0.1896775  0.13987518]
Class confidence at   30ms: [0.59899124 0.05948765 0.17216769 0.16935342]
Class confidence at   40ms: [0.56098166 0.11313523 0.17236025 0.15352286]
Class confidence at   50ms: [0.44616754 0.16050548 0.18824382 0.20508316]
Class confidence at   60ms: [0.45878312 0.07686821 0.18680208 0.27754659]
Class confidence at   70ms: [0.39305304 0.1298251  0.23095112 0.24617074]
Class confidence at   80ms: [0.56098166 0.11313523 0.17236025 0.15352286]
Class confidence at   90ms: [0.54986321 0.09442478 0.16291307 0.19279894]
Class confidence at  100ms: [0.39083037 0.08411063 0.18328393 0.34177508]
Class confidence at  110ms: [0.0242295  0.07877394 0.43711058 0.45988598]
Class con