## Indexing of Steinmetz data

In [1]:
#@title Data retrieval
import os, requests
from pathlib import Path
DATA_DIR = Path.cwd().parent / 'data'
# Create the target directory up front; otherwise open(..., "wb") below fails on a fresh checkout.
DATA_DIR.mkdir(parents=True, exist_ok=True)

# Seconds to wait for the OSF server before giving up on a file.
DOWNLOAD_TIMEOUT_S = 60

url = [
    "https://osf.io/agvxh/download",
    "https://osf.io/uv3mw/download",
    "https://osf.io/ehmw2/download",
]
# One local file per URL, so the two lists can never get out of sync.
fname = [DATA_DIR / f"steinmetz_part{j}.npz" for j in range(len(url))]

for file_path, file_url in zip(fname, url):
    if file_path.exists():
        continue  # already downloaded
    try:
        r = requests.get(file_url, timeout=DOWNLOAD_TIMEOUT_S)
    except requests.RequestException:
        # Best effort: report and move on (covers connection errors *and* timeouts).
        print("!!! Failed to download data !!!")
        continue
    if r.status_code != requests.codes.ok:
        print("!!! Failed to download data !!!")
        continue
    # Write to a temporary name first, then rename. A partially written file
    # therefore never appears under the final name, which exists() would treat as complete.
    tmp_path = file_path.with_name(file_path.name + ".part")
    with open(tmp_path, "wb") as fid:
        fid.write(r.content)
    tmp_path.replace(file_path)

In [2]:
#@title Data loading
import numpy as np

# Concatenate the per-file session arrays into a single array of sessions.
# NOTE(security): allow_pickle=True un-pickles objects stored in the file; only load trusted data.
alldat_parts = []
for path in fname:
    # NpzFile holds an open file handle; the context manager closes it after reading 'dat'.
    with np.load(path, allow_pickle=True) as npz:
        alldat_parts.append(npz['dat'])
alldat = np.hstack(alldat_parts) if alldat_parts else np.array([])

In [3]:
def get_action_times(
    wheel,
    response_time,
    bins_before=20,
    bins_after=5,
    move_min=5,
    move_window=5,
    stim_time=50,
    stim_buffer=20,
    bin_size=0.01,
):
    """
    Find the action-aligned time-bin indices for each trial.

    Movement onset is the first bin at or after ``stim_time`` that starts a
    window of ``move_window`` bins whose peak-to-peak wheel position exceeds
    ``move_min``. Trials with no such window fall back to ``stim_time``.

    The onset is then clamped so that the returned indices
      * start no earlier than ``stim_time - stim_buffer``,
      * end before the response bin derived from ``response_time``,
      * always lie in ``[0, num_bins)``.
    When these constraints conflict, the later ones win. The indices are then
    always valid, but they may not be perfectly aligned with the actual
    movement if it starts too soon or too late.

    Use ``np.take_along_axis(spikes, result, 2)`` to slice spike data.

    Arguments:
    wheel -- wheel positions, indexed as wheel[0][trial, bin]. A leading axis
             is expected, and time bins are on axis 2.
    response_time -- per-trial response time array, accepted as (trials,) or
             (trials, 1). It is divided by ``bin_size`` and offset by
             ``stim_time``, so it is presumably seconds after stimulus onset.
             TODO confirm.

    Keyword Arguments:
    bins_before -- number of bins before first movement point to include
    bins_after -- number of bins after first movement point to include
    move_min -- peak-to-peak wheel movement that must be exceeded in a window
    move_window -- sliding window size in which movement will be searched
    stim_time -- bin number of stimulus presentation
    stim_buffer -- number of mandatory "no movement" bins before stim_time
    bin_size -- width of one time bin, in the units of ``response_time``

    Returns:
    Integer array of shape (1, num_trials, bins_before + bins_after).

    Raises:
    ValueError -- if bins_before + bins_after exceeds the number of time bins.
    """
    num_bins = wheel.shape[2]
    if bins_before + bins_after > num_bins:
        raise ValueError(
            f"bins_before + bins_after ({bins_before + bins_after}) "
            f"exceeds the number of time bins ({num_bins})"
        )

    # Candidate window starts. The last one is num_bins - move_window, so the
    # final window ends exactly on the last bin.
    starts = np.arange(stim_time, num_bins - move_window + 1)
    windows = starts[:, np.newaxis] + np.arange(move_window)
    # (trials, n_windows): does the wheel move more than move_min inside each window?
    is_moving = np.ptp(wheel[0][:, windows], axis=2) > move_min
    # argmax returns the first True; all-False trials give 0, i.e. stim_time.
    action_time = is_moving.argmax(axis=1) + stim_time

    # Keep the window from starting before the mandatory no-movement buffer.
    min_time = stim_time - stim_buffer
    action_time = np.maximum(action_time, min_time + bins_before)

    # Keep the window from extending past the response bin.
    # np.int was removed in NumPy 1.24; the builtin int is the equivalent dtype.
    max_time = (response_time / bin_size).astype(int).flatten() + stim_time
    action_time = np.minimum(action_time, max_time - bins_after)

    # Hard clamp so every index is inside [0, num_bins). Without this,
    # take_along_axis silently wraps negative indices or raises on ones past the end.
    action_time = np.maximum(np.minimum(action_time, num_bins - bins_after), bins_before)

    return action_time[np.newaxis, :, np.newaxis] + np.arange(-bins_before, bins_after)

In [4]:
from brain_areas import *

# Regions in Figure 3 with >= ~10% of visually-selective neurons.
# The AREA_* names come from the `from brain_areas import *` above.
VISUAL_AREAS = [
    AREA_VISp,
    AREA_VISl,
    AREA_VISpm,
    AREA_VISam,
    AREA_CP,
    AREA_LD,
    AREA_SCs,
]

# Trials whose peak-to-peak wheel displacement is below this count as "no movement".
NO_MOVEMENT_PTP = 3

# Parameters forwarded to get_action_times for the TIMES_ACTION selector.
ACTION_TIME_KWARGS = dict(
    bins_before=20,
    bins_after=5,
    move_window=4,
    move_min=2,
    stim_time=50,
    stim_buffer=20,
)

# One dict of boolean masks (plus action-aligned time indices) per session.
selectors = []
for dat in alldat:
    contrast_r = dat['contrast_right']
    contrast_l = dat['contrast_left']
    response = dat['response']

    no_movement = np.ptp(dat['wheel'][0], 1) < NO_MOVEMENT_PTP
    moved = ~no_movement

    choice_right = (response == -1) & moved
    choice_left = (response == 1) & moved
    choice_none = (response == 0) | no_movement

    stim_right = contrast_r > contrast_l
    stim_left = contrast_r < contrast_l
    same_contrast = contrast_r == contrast_l
    stim_none = same_contrast & (contrast_r == 0)

    sel = {
        "NEURON_VISUAL": np.isin(dat['brain_area'], VISUAL_AREAS),
        "CHOICE_RIGHT": choice_right,
        "CHOICE_LEFT": choice_left,
        "CHOICE_NONE": choice_none,
        "STIM_RIGHT": stim_right,
        "STIM_LEFT": stim_left,
        "STIM_EQUAL": same_contrast & (contrast_r > 0),
        "STIM_NONE": stim_none,
        "STIM_RIGHT_HIGH": contrast_r == 1,
        "STIM_RIGHT_MEDIUM": contrast_r == 0.5,
        "STIM_RIGHT_LOW": contrast_r == 0.25,
        "STIM_RIGHT_NONE": contrast_r == 0,
        "TIMES_ACTION": get_action_times(
            dat['wheel'],
            dat['response_time'],
            **ACTION_TIME_KWARGS,
        ),
        # NOTE: equal non-zero contrast trials (STIM_EQUAL) are never counted as correct.
        "CHOICE_CORRECT": (
            (stim_right & choice_right)
            | (stim_left & choice_left)
            | (stim_none & choice_none)
        ),
        "CHOICE_MISS": ~stim_none & choice_none,
    }
    selectors.append(sel)

In [5]:
# Demonstrate both selector styles on every session: boolean masks on the
# neuron/trial axes, and index arrays for the time axis.
for dat, sel in zip(alldat, selectors):
    spikes = dat['spks']

    print("Selecting visual neurons in trials where stimulus right was high and mouse chose left")
    trial_mask = sel["STIM_RIGHT_HIGH"] & sel["CHOICE_LEFT"]
    print(spikes[sel["NEURON_VISUAL"]][:, trial_mask].shape)

    print("Selecting action-aligned time bins from all trials")
    aligned = np.take_along_axis(spikes, sel["TIMES_ACTION"], 2)
    print(aligned.shape)

    print()

Selecting visual neurons in trials where stimulus right was high and mouse chose left
(178, 11, 250)
Selecting action-aligned time bins from all trials
(734, 214, 25)

Selecting visual neurons in trials where stimulus right was high and mouse chose left
(533, 8, 250)
Selecting action-aligned time bins from all trials
(1070, 251, 25)

Selecting visual neurons in trials where stimulus right was high and mouse chose left
(228, 7, 250)
Selecting action-aligned time bins from all trials
(619, 228, 25)

Selecting visual neurons in trials where stimulus right was high and mouse chose left
(39, 3, 250)
Selecting action-aligned time bins from all trials
(1769, 249, 25)

Selecting visual neurons in trials where stimulus right was high and mouse chose left
(0, 4, 250)
Selecting action-aligned time bins from all trials
(1077, 254, 25)

Selecting visual neurons in trials where stimulus right was high and mouse chose left
(0, 6, 250)
Selecting action-aligned time bins from all trials
(1169, 290, 25)

In [6]:
# np.save stores this list of dicts as a pickled object array, so reading it
# back requires np.load(..., allow_pickle=True).
selectors_path = DATA_DIR / "selectors.npy"
with selectors_path.open("wb") as selectors_file:
    np.save(selectors_file, selectors)