In [2]:
#@title Data retrieval
import os, requests, pdb

# Local file names and their OSF download URLs; entries are paired by position.
fname = ['steinmetz_part%d.npz' % j for j in range(3)]
url = [
  "https://osf.io/agvxh/download",
  "https://osf.io/uv3mw/download",
  "https://osf.io/ehmw2/download",
]

# Seconds to wait for the connection / between received bytes. Without a
# timeout, requests.get can block indefinitely on a stalled server.
DOWNLOAD_TIMEOUT = 60

for path, link in zip(fname, url):
  if os.path.isfile(path):
    continue  # already downloaded on a previous run
  try:
    r = requests.get(link, timeout=DOWNLOAD_TIMEOUT)
  except requests.RequestException:
    # RequestException covers ConnectionError *and* Timeout, so a slow
    # server reports a failure instead of crashing the cell.
    print("!!! Failed to download data !!!")
    continue
  if r.status_code != requests.codes.ok:
    print("!!! Failed to download data !!!")
    continue
  with open(path, "wb") as fid:
    fid.write(r.content)

In [25]:
#@title Import matplotlib and set defaults, import scikit learn
from matplotlib import rcParams 
from matplotlib import pyplot as plt

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

# Notebook-wide Matplotlib defaults: wide/short figures, 15pt text,
# no top/right axis spines, and automatic layout of each figure.
rcParams.update({
    'figure.figsize': [20, 4],
    'font.size': 15,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'figure.autolayout': True,
})

In [4]:
#@title Data loading
import numpy as np

# Concatenate the 'dat' arrays of all downloaded files into one flat array.
# Each NpzFile is opened in a `with` block so its file handle is closed.
# NOTE(review): allow_pickle=True unpickles file contents (arbitrary code
# execution risk); only load these files from the trusted source above.
# Presumably each element of 'dat' is one session's dict — confirm.
_dat_parts = []
for path in fname:
    with np.load(path, allow_pickle=True) as npz:
        _dat_parts.append(npz['dat'])
alldat = np.hstack(_dat_parts) if _dat_parts else np.array([])

In [2]:
def getChoiceCells(dat, window=50):
    """Find choice-predictive neurons in choice areas and bin their spike counts.

    An unpenalized logistic regression predicts the binary choice on trials
    with a left/right response (``dat['response'] != 0``) from every neuron's
    normalized trial firing rate. Neurons located in ``choice_areas`` whose
    weight is above the 70th percentile of all weights are labelled "ipsi".
    Those below the 30th percentile are labelled "contra".

    Parameters
    ----------
    dat : dict-like
        One session. Must provide:
        - 'brain_area': one label per neuron.
        - 'response': one value per trial, in {-1, 0, 1}.
        - 'spks': spike counts shaped (n_neurons, n_trials, n_time_bins).
    window : int, optional
        Number of consecutive time bins summed into one output row
        (default 50).

    Returns
    -------
    ipsi_spikes, contra_spikes : ndarray of float
        Each array has shape ``(n_trials * n_windows, n_selected_cells)``,
        where ``n_windows = n_time_bins // window``. Row
        ``t * n_windows + w`` holds trial ``t``'s spike count in window
        ``w``, computed over all trials. Trailing bins that do not fill a
        whole window are dropped.
    """
    choice_areas = ['PL', 'MOs', 'MOp', 'CP', 'SNr', 'ZI', 'SCp']
    in_choice_area = np.isin(dat['brain_area'], choice_areas)

    spks = np.asarray(dat['spks'])
    _, n_trials, n_bins = spks.shape

    # Keep only left/right-response trials. Recode -1 -> 0 so the labels are
    # {0, 1}; the author's convention is left = 1, right = 0.
    response = np.asarray(dat['response'])
    LorR = response != 0
    choices = response[LorR]  # boolean indexing copies, so dat is not modified
    choices[choices == -1] = 0

    # Firing rate (spikes/s) of each neuron on each trial. The original also
    # divided by 1000; that cancelled in the normalization below, but the
    # units were wrong.
    bin_size = 0.01  # 10 ms bins, in seconds
    trial_FR = spks.sum(axis=2) / (n_bins * bin_size)
    # Normalize by the grand mean over all neurons and all trials.
    norm_FR = trial_FR[:, LorR] / np.mean(trial_FR)

    # penalty=None (scikit-learn >= 1.2) replaces the string "none", which was
    # deprecated in 1.2 and removed in 1.4.
    log_reg = LogisticRegression(penalty=None).fit(norm_FR.T, choices)
    weights = log_reg.coef_.ravel()  # one weight per neuron

    # Choice-area neurons whose weights fall in the top/bottom 30% of all weights.
    # flatnonzero always returns a 1-D index array (even for 0 or 1 matches).
    low_cut, high_cut = np.quantile(weights, [0.30, 0.70])
    ipsi_cells = np.flatnonzero((weights > high_cut) & in_choice_area)
    contra_cells = np.flatnonzero((weights < low_cut) & in_choice_area)

    n_windows = n_bins // window

    def _windowed_counts(cells):
        # Sum spikes over equal-width windows of `window` bins. The original
        # cumsum[::50] trick made the first window a single bin and dropped
        # the last 49 bins. Lay the result out trial-major as
        # (n_trials * n_windows, n_cells).
        counts = spks[cells, :, :n_windows * window]
        counts = counts.reshape(len(cells), n_trials, n_windows, window).sum(axis=3)
        return counts.reshape(len(cells), n_trials * n_windows).T.astype(float)

    # Two independent arrays, each sized by its own cell count. The original
    # aliased contra_spikes to ipsi_spikes, so both return values were the
    # same (overwritten) array.
    return _windowed_counts(ipsi_cells), _windowed_counts(contra_cells)

def plot_weights(models, sharey=True):
    """Draw a stem plot of the fitted weights of each model in ``models``.

    Parameters
    ----------
    models : dict
        Maps a subplot title to a fitted estimator exposing ``coef_``.
        One row of axes is drawn per entry, in dict order.
    sharey : bool, optional
        Whether all rows share the y-axis (default True).
    """
    n = len(models)
    f = plt.figure(figsize=(10, 2.5 * n))
    # atleast_1d so a single model (one Axes, not an array) can still be zipped.
    axs = np.atleast_1d(f.subplots(n, sharex=True, sharey=sharey))

    for ax, (title, model) in zip(axs, models.items()):
        ax.margins(x=.02)
        # `use_line_collection` was removed in Matplotlib 3.8. Since 3.3 the
        # stems are a LineCollection by default, so set_linewidths still works.
        markerline, stemlines, baseline = ax.stem(model.coef_.squeeze())
        markerline.set_marker(".")
        markerline.set_color(".2")
        stemlines.set_linewidths(.5)
        stemlines.set_color(".2")
        baseline.set_visible(False)
        ax.axhline(0, color="C3", lw=3)
        ax.set(ylabel="Weight", title=title)
    # Label only the bottom axes (x is shared) without relying on the loop variable.
    axs[-1].set(xlabel="Neuron (a.k.a. feature)")
    f.tight_layout()

In [None]:
# `dat` was never assigned anywhere in this notebook, so on a fresh kernel this
# cell raised NameError; it only worked with leftover kernel state.
# NOTE(review): session index 11 is a placeholder. TODO confirm which session
# this analysis intended.
SESSION_IDX = 11
dat = alldat[SESSION_IDX]

ipsi_spikes, contra_spikes = getChoiceCells(dat)
# Display the shapes rather than printing the full arrays.
ipsi_spikes.shape, contra_spikes.shape