## Indexing of Steinmetz data

In [1]:
#@title Data retrieval
import os, requests
from pathlib import Path
DATA_DIR = Path.cwd().parent / 'data'

# One OSF download URL per part of the dataset; fname[j] is the local cache
# path for url[j]. Deriving fname from url keeps the two lists in sync.
url = [
  "https://osf.io/agvxh/download",
  "https://osf.io/uv3mw/download",
  "https://osf.io/ehmw2/download",
]
fname = [DATA_DIR / f"steinmetz_part{j}.npz" for j in range(len(url))]

# Seconds to wait on the server before giving up on a part (prevents hangs).
DOWNLOAD_TIMEOUT = 60

# open(..., "wb") below fails if the directory is missing, so create it first.
DATA_DIR.mkdir(parents=True, exist_ok=True)

for part_url, part_path in zip(url, fname):
  if part_path.exists():
    continue  # already cached; skip the download
  try:
    r = requests.get(part_url, timeout=DOWNLOAD_TIMEOUT)
  except requests.RequestException:
    # Best effort: report and move on so the remaining parts are still tried.
    print(f"!!! Failed to download data: {part_url} !!!")
    continue
  if r.status_code != requests.codes.ok:
    print(f"!!! Failed to download data: {part_url} (HTTP {r.status_code}) !!!")
    continue
  # Write to a temporary name and rename afterwards. An interrupted write then
  # never leaves a truncated file that the exists() check above would skip.
  tmp_path = part_path.with_name(part_path.name + ".part")
  with open(tmp_path, "wb") as fid:
    fid.write(r.content)
  tmp_path.replace(part_path)

In [2]:
#@title Data loading
import numpy as np

# Concatenate the 'dat' arrays stored in every downloaded part into one 1-D
# array. Later cells index each element by keys such as 'spks' and 'wheel'.
# NOTE: allow_pickle=True unpickles file contents, which can execute arbitrary
# code -- only load these files from a trusted source.
parts = []
for path in fname:
  # Using np.load as a context manager closes the NpzFile's file handle.
  with np.load(path, allow_pickle=True) as npz:
    parts.append(npz['dat'])
# Stack once instead of re-allocating inside the loop; keep the original
# empty-array result when there are no parts.
alldat = np.hstack(parts) if parts else np.array([])

In [3]:
def get_action_times(wheel, window_size=5, bins_before=50, bins_after=5):
    """
    Finds the action-aligned indices for each trial.

    Movement is determined by taking the absolute difference between wheel
    positions in successive time bins (speed sample i is the change between
    bins i and i+1). A leading window then slides over these samples to find
    the first point where every sample within the window is non-zero.

    Returns an integer array of shape (1, num_trials, bins_before + bins_after).
    The leading singleton axis lets it broadcast against a spike array in
    np.take_along_axis(spikes, indices, 2).
    The returned indices are guaranteed to lie in [0, num_bins), but they
    might not be perfectly aligned with actual movement if it starts too soon
    or too late. Trials in which no window is fully moving are aligned to
    index 0 (argmax of an all-False row is 0), i.e. the start of the trial.

    Use np.take_along_axis with spike data and returned indices

    Arguments:
    wheel -- array of wheel positions indexed as wheel[0, trial, bin]; only
             wheel[0] is used

    Keyword Arguments:
    window_size -- sliding window size (number of consecutive speed samples
                   that must all be non-zero)
    bins_before -- number of bins before first movement point to include
    bins_after -- number of bins after first movement point to include

    Raises:
    ValueError -- if bins_before + bins_after is not in [1, num_bins], or if
                  window_size is not in [1, num_bins - 1].
    """
    num_bins = wheel.shape[2]
    num_out = bins_before + bins_after
    if not 1 <= num_out <= num_bins:
        # Otherwise the clamping below would produce negative indices, which
        # take_along_axis silently wraps around to the end of the trial.
        raise ValueError(
            f"bins_before + bins_after must be in [1, {num_bins}], got {num_out}"
        )

    wheel_speed = np.abs(np.diff(wheel[0], axis=1))
    # A length-L sequence has L - window_size + 1 full windows.
    num_windows = wheel_speed.shape[1] - window_size + 1
    if window_size < 1 or num_windows < 1:
        raise ValueError(
            f"window_size must be in [1, {wheel_speed.shape[1]}], got {window_size}"
        )
    windows = np.arange(num_windows)[:, np.newaxis] + np.arange(window_size)
    # (trials, windows): True where every sample in the window is moving.
    is_moving = (wheel_speed > 0)[:, windows].all(axis=2)
    action_times = is_moving.argmax(axis=1)[:, np.newaxis] + np.arange(-bins_before, bins_after)

    # Fit within available window: shift rows that start before bin 0 forward...
    action_times -= np.minimum(action_times[:, [0]], 0)
    # ...and rows that end past the last bin backward. Both shifts preserve row
    # length, and num_out <= num_bins, so every index ends up in range.
    action_times += np.minimum(num_bins - 1 - action_times[:, [-1]], 0)
    assert np.all(action_times >= 0)
    assert np.all(action_times < num_bins)

    return action_times[np.newaxis, :, :]

In [4]:
from brain_areas import *

# Regions in Figure 3 with >= ~10% of visually-selective neurons
VISUAL_AREAS = [
    AREA_VISp,
    AREA_VISl,
    AREA_VISpm,
    AREA_VISam,
    AREA_CP,
    AREA_LD,
    AREA_SCs,
]

# Parameters used to align each trial to movement onset (see get_action_times).
ACTION_WINDOW_SIZE = 5
ACTION_BINS_BEFORE = 50
ACTION_BINS_AFTER = 5


def get_selectors(dat):
    """
    Build the named masks and index arrays for one recording session.

    Arguments:
    dat -- one session record from `alldat`; must provide the keys
           'brain_area', 'response', 'contrast_left', 'contrast_right'
           and 'wheel'

    Returns a dict of arrays. NEURON_* masks select neurons (axis 0 of
    'spks'). CHOICE_* / STIM_* masks select trials (axis 1 of 'spks').
    TIMES_ACTION holds time indices for np.take_along_axis(spks, ..., 2).
    """
    response = dat['response']
    right = dat['contrast_right']
    left = dat['contrast_left']

    sel = {
        "NEURON_VISUAL": np.isin(dat['brain_area'], VISUAL_AREAS),
        "CHOICE_RIGHT": response == -1,
        "CHOICE_LEFT": response == 1,
        "CHOICE_NONE": response == 0,
        "STIM_RIGHT": right > left,
        "STIM_LEFT": right < left,
        "STIM_EQUAL": (right == left) & (right > 0),
        "STIM_NONE": (right == left) & (right == 0),
        "STIM_RIGHT_HIGH": right == 1,
        "STIM_RIGHT_MEDIUM": right == 0.5,
        "STIM_RIGHT_LOW": right == 0.25,
        "STIM_RIGHT_NONE": right == 0,
        "TIMES_ACTION": get_action_times(
            dat['wheel'],
            window_size=ACTION_WINDOW_SIZE,
            bins_before=ACTION_BINS_BEFORE,
            bins_after=ACTION_BINS_AFTER,
        ),
    }
    # Derived selectors, built from the primitive masks above. Trials with
    # equal non-zero contrast are never counted as correct.
    sel["CHOICE_CORRECT"] = (
        (sel["STIM_RIGHT"] & sel["CHOICE_RIGHT"])
        | (sel["STIM_LEFT"] & sel["CHOICE_LEFT"])
        | (sel["STIM_NONE"] & sel["CHOICE_NONE"])
    )
    sel["CHOICE_MISS"] = ~sel["STIM_NONE"] & sel["CHOICE_NONE"]
    return sel


# One selector dict per session, in the same order as alldat.
selectors = [get_selectors(dat) for dat in alldat]

In [5]:
# Demonstrate both kinds of indexing on every session.
for i, session in enumerate(alldat):
    sel = selectors[i]
    spikes = session['spks']

    # Neuron mask applies to axis 0; the trial mask to axis 1.
    print("Selecting visual neurons in trials where stimulus right was high and mouse chose left")
    trial_mask = sel["STIM_RIGHT_HIGH"] & sel["CHOICE_LEFT"]
    visual_subset = spikes[sel["NEURON_VISUAL"]][:, trial_mask]
    print(visual_subset.shape)

    # Gather the action-aligned time bins along axis 2.
    print("Selecting action-aligned time bins from all trials")
    aligned = np.take_along_axis(spikes, sel["TIMES_ACTION"], 2)
    print(aligned.shape)

    print()

Selecting visual neurons in trials where stimulus right was high and mouse chose left
(178, 11, 250)
Selecting action-aligned time bins from all trials
(734, 214, 55)

Selecting visual neurons in trials where stimulus right was high and mouse chose left
(533, 8, 250)
Selecting action-aligned time bins from all trials
(1070, 251, 55)

Selecting visual neurons in trials where stimulus right was high and mouse chose left
(228, 7, 250)
Selecting action-aligned time bins from all trials
(619, 228, 55)

Selecting visual neurons in trials where stimulus right was high and mouse chose left
(39, 3, 250)
Selecting action-aligned time bins from all trials
(1769, 249, 55)

Selecting visual neurons in trials where stimulus right was high and mouse chose left
(0, 4, 250)
Selecting action-aligned time bins from all trials
(1077, 254, 55)

Selecting visual neurons in trials where stimulus right was high and mouse chose left
(0, 6, 250)
Selecting action-aligned time bins from all trials
(1169, 290, 55)

In [6]:
# Persist the per-session selector dicts. np.save stores them as a pickled
# object array, so reading them back requires np.load(..., allow_pickle=True).
selectors_file = DATA_DIR / "selectors.npy"
np.save(selectors_file, selectors)