Importing modules

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from statistics import mode
from pathlib import Path
from os.path import join as pjoin

import chiCa
from chiCa.visualization_utils import separate_axes
from spks.sync import load_ni_sync_data,interp1d
from spks.event_aligned import compute_firing_rate
from spks.clusters import Clusters

from utils import *
from viz import *


%matplotlib widget

Loading NI-DAQ sync and behavior data

In [2]:
animal = 'GRB006'
session = '20240724_144439'
sessionpath = Path(f'/home/data/{animal}/{session}/')
sync_port = 0  # NI-DAQ digital channel where the SMA of the probe is connected
ap_sync_line = 6  # index of the sync line in the probe's sync onsets (should be the same for you)

(nionsets, nioffsets), (nisync, nimeta), (apsyncdata) = load_ni_sync_data(sessionpath=sessionpath)
aponsets = apsyncdata[0]['file0_sync_onsets'][ap_sync_line]

# Map NI-DAQ sample indices onto the probe sample clock using the shared sync pulses.
# Built once and reused for every channel instead of being re-fitted on each call.
ni_to_ap = interp1d(nionsets[sync_port], aponsets, fill_value='extrapolate')

# Dictionaries of the digital events connected to the breakout box, keyed by NI channel,
# expressed in probe samples. int64 (not uint64): events before the first sync pulse
# extrapolate to negative sample indices, which an unsigned dtype would wrap to huge values.
corrected_onsets = {k: ni_to_ap(nionsets[k]).astype('int64') for k in nionsets.keys()}
corrected_offsets = {k: ni_to_ap(nioffsets[k]).astype('int64') for k in nionsets.keys()}

# if you need analog channels those are in "nisync"
nitime = ni_to_ap(np.arange(len(nisync)))

# everything is in samples, use this sampling rate
srate = apsyncdata[0]['sampling_rate']
t = nitime / srate
# corrected_onsets[1] are the frame samples, [2] are the trial start samples
frame_rate = mode(1 / (np.diff(corrected_onsets[1]) / srate))
analog_signal = nisync[:, 0]  # analog stim signal

# storing digital events in seconds
trial_starts = corrected_onsets[2] / srate
left_port_entries = corrected_onsets[3] / srate
left_port_exits = corrected_offsets[3] / srate
center_port_entries = corrected_onsets[4] / srate
center_port_exits = corrected_offsets[4] / srate
right_port_entries = corrected_onsets[5] / srate
right_port_exits = corrected_offsets[5] / srate

# Port entry/exit times (s) grouped per port so later helpers can iterate over them.
port_events = dict(
    center_port=dict(entries=center_port_entries, exits=center_port_exits),
    left_port=dict(entries=left_port_entries, exits=left_port_exits),
    right_port=dict(entries=right_port_entries, exits=right_port_exits),
)

behavior_file = pjoin(sessionpath, f'chipmunk/{animal}_{session}_chipmunk_DemonstratorAudiTask.mat')
behavior_data = chiCa.load_trialdata(behavior_file)

# Stim events are detected on the analog stim channel (amp_threshold units — TODO confirm).
stim_events = detect_stim_events(t, srate, analog_signal, amp_threshold=5000)

# get trialized timestamps for task events, then append the response time as the last column.
# NOTE(review): relies on get_response_ts already being in scope here (e.g. via `from utils import *`).
trial_ts = get_trial_ts(trial_starts, stim_events, behavior_data, port_events)
response_ts = trial_ts.apply(get_response_ts, axis=1)
trial_ts.insert(trial_ts.shape[1], 'response', response_ts)



In [8]:
# For each completed trial (outcome 0 or 1): [initiation, stimulus, action, response] times in seconds.
completed_trials = trial_ts[trial_ts.trial_outcome.isin([0, 1])]
timepoints = [
    np.array((row.center_port_entries[-1],  # initiation: last center-port entry
              row.first_stim_ts,            # stimulus: first stim event
              row.center_port_exits[-1],    # action: last center-port exit
              row.response))                # response time column of trial_ts
    for _, row in completed_trials.iterrows()
]

timepoints

[array([49.42583333, 49.48232267, 50.15056667, 50.98206667]),
 array([53.2484    , 53.33515467, 54.0963    , 54.90473333]),
 array([58.0771    , 58.124064  , 58.64396667, 59.18003333]),
 array([64.036     , 64.06207177, 64.5551    , 65.0009    ]),
 array([74.00513333, 74.06644267, 74.5605    , 75.3307    ]),
 array([83.6187    , 83.68405467, 84.29856667, 85.1819    ]),
 array([ 99.36243333,  99.40779223,  99.9704    , 100.55526667]),
 array([108.49016667, 108.58282933, 109.0856    , 109.6203    ]),
 array([119.1728    , 119.24308572, 119.76033333, 120.64553333]),
 array([124.4568    , 124.48260173, 125.1391    , 125.79143333]),
 array([132.9837    , 133.02044214, 133.67873333, 134.19176667]),
 array([147.65436667, 147.72556191, 148.23843333, 148.98543333]),
 array([150.84693333, 150.90383702, 151.53283333, 152.2363    ]),
 array([155.43213333, 155.52212637, 156.1614    , 156.9379    ]),
 array([176.519     , 176.589188  , 177.1185    , 178.12063333]),
 array([187.64646667, 187.681504  

In [None]:
# Indices of completed trials (outcome 0/1) that contain more than one left-port entry.
completed = trial_ts[trial_ts.trial_outcome.isin([0, 1])]
idx = [i for i, row in completed.iterrows() if len(row.left_port_entries) > 1]
c = len(idx)  # number of such trials
df = trial_ts[trial_ts.index.isin(idx)]

def get_response_ts(row):
    """Get the timestamp of when the animal responded, regardless of which side it was.

    The response is the earliest left- or right-port entry that happens after the
    last center-port exit (withdrawal) of the trial.

    Parameters
    ----------
    row : pandas.Series or mapping
        One trial; must provide 'center_port_exits', 'left_port_entries' and
        'right_port_entries' (each a scalar or an array-like of times).

    Returns
    -------
    float or None
        Time of the first side-port entry after withdrawal, or None when the trial
        has no center exit or no side-port entry after it.
    """
    exits = np.atleast_1d(row['center_port_exits'])
    if exits.size == 0:
        return None
    w = exits[-1]  # withdrawal time

    left = np.atleast_1d(row['left_port_entries']).astype(float)
    right = np.atleast_1d(row['right_port_entries']).astype(float)

    # Pool both sides: taking "left if any, else right" would skip a right poke
    # that happened before a later left poke.
    candidates = np.concatenate([left[left > w], right[right > w]])
    if candidates.size == 0:
        return None
    return float(candidates.min())

# Apply get_response_ts to each row and add/overwrite the 'response' column.
# `assign` returns a new frame: `insert` raises here because `df` is a slice of trial_ts,
# which already got a 'response' column when it was built, and it would also mutate a
# slice (SettingWithCopyWarning). result_type='reduce' keeps the result a Series even
# when `df` is empty.
df = df.assign(response=df.apply(get_response_ts, axis=1, result_type='reduce'))
df

In [None]:
len(df.columns)

Loading and filtering Kilosort (KS) results

In [None]:
# Kilosort 2.5 output for probe imec0, stored inside the session folder.
kilosort_path = sessionpath / 'kilosort2.5' / 'imec0'
sc = np.load(kilosort_path / 'spike_clusters.npy')  # KS cluster label of every spike
ss = np.load(kilosort_path / 'spike_times.npy')     # KS spike times (in samples)
st = ss / srate                                     # spike times in seconds

clu = Clusters(folder=kilosort_path, get_waveforms=False, get_metrics=True, load_template_features=True)

good_unit_ids, n_units = get_good_units(clusters_obj=clu, spike_clusters=sc)


Static population PSTH plotting

In [None]:
binwidth_ms = 5
tpre = 0.025
tpost = 0.055

# All spike times (s) that belong to good units. `st` is indexed by spike, not by unit id,
# so spikes are selected through their cluster labels in `sc`; st[good_unit_ids] would just
# pick the spikes whose index happens to equal a unit id.
population_timestamps = st[np.isin(np.ravel(sc), good_unit_ids)]
single_unit_timestamps = get_cluster_spike_times(spike_times=st, spike_clusters=sc, good_unit_ids=good_unit_ids)

psth = get_population_firing_rate(event_times=np.hstack(trial_ts.first_stim_ts),
                                  spike_times=single_unit_timestamps,
                                  tpre=tpre,
                                  tpost=tpost,
                                  binwidth_ms=binwidth_ms)
n_stims = len(psth)

plt.figure(figsize=(4, 4))
plot_psth(compute_mean_sem(psth), tpre, tpost, binwidth_ms, 'time from first stim event (s)', 'population firing rate (Hz)', f"{n_units} units - {n_stims} stims")
separate_axes(plt.gca())

In [None]:
# `dd` is built in the next cell; guard so Restart & Run All does not stop here with a NameError.
dd['initiation'] if 'dd' in globals() else None

In [None]:
binwidth_ms = 50
tpre = 0
tpost = 1

# Key task events (s) of the first completed (outcome 0/1) trial.
data = trial_ts[trial_ts.trial_outcome.isin([0, 1])].iloc[0]
dd = {
    # NOTE(review): first center entry shifted 1 s earlier — presumably to include a
    # pre-initiation baseline since tpre = 0; confirm this is intended.
    'initiation': data.center_port_entries[0] - 1,
    'stimulus': data.first_stim_ts,
    # This trial's own events, not the first center exit / right entry of the whole session.
    'action': data.center_port_exits[-1],
    'outcome': data.response,
}

psth = get_population_firing_rate(event_times=[dd['outcome']],
                                  spike_times=single_unit_timestamps,
                                  tpre=tpre,
                                  tpost=tpost,
                                  binwidth_ms=binwidth_ms)
n_stims = len(psth)

plt.figure(figsize=(4, 4))
plot_psth(compute_mean_sem(psth), tpre, tpost, binwidth_ms, 'time from outcome (s)', 'population firing rate (Hz)', f"{n_units} units - {n_stims} stims")

Stimulus responses by trial outcome

In [None]:
#loop over all outcomes, plot pop psth for those stim events, and save fig
binwidth_ms = 5
tpre = 0.025
tpost = 0.055

plt.figure(figsize=(5, 5))
for outcome, c in zip(np.unique(trial_ts.trial_outcome), ['b', 'k', 'r', 'y']):
    ts = np.hstack(trial_ts[trial_ts.trial_outcome == outcome].stim_ts)
    unit_fr = []
    with suppress_print():
        for i in range(len(single_unit_timestamps)):
            try:
                unit_fr.append(compute_firing_rate(ts, single_unit_timestamps[i], tpre, tpost, binwidth_ms, kernel=None)[0])
            except:
                unit_fr.append(np.nan)
    psth = np.mean(unit_fr, axis = 0)
    # psth, _ = compute_firing_rate(ts, population_timestamps, tpre, tpost, binwidth_ms, kernel=None)
    n_stims = len(psth)

    if outcome == 0:
        txt = 'unrewarded'
    elif outcome == 1:
        txt = 'rewarded'
    elif outcome == -1:
        txt = 'early withdrawal'
    elif outcome == 2:
        txt = 'no choice'

    plot_psth(compute_mean_sem(psth), tpre, tpost, binwidth_ms,
              xlabel = 'time from first stim event (s)',
              ylabel = 'spike rate (Hz)',
              fig_title = f"{n_units} neurons",
              data_label = f'{txt} {n_stims} stims',
              color=c)
    # filename = f"pop_stim_kernel_{rate}_Hz.png"
    # save_dir = Path('/home/gabriel/lib/lab-projects/ephys/figures/stim_kernels_per_rate/')
    # filepath = os.path.join(save_dir, filename)
    # plt.savefig(filepath)
plt.legend()
del outcome, ts, psth, n_stims

In [None]:
# #For plotting response to first four stim events by outcome

# binwidth_ms = 5
# tpre = 0.025
# tpost = 0.055

# #loop over all outcomes, plot pop psth for those stim events, and save fig
# fig, axs = plt.subplots(1,4,figsize=(16,6), sharex = True, sharey = True)
# fig.subplots_adjust(hspace=0)
# for i, ax in enumerate(axs):
#     for outcome, c in zip(np.unique(trial_ts.trial_outcome), ['b', 'k', 'r', 'y']):
#         ts = np.hstack(trial_ts[trial_ts.trial_outcome == outcome].stim_ts.apply(lambda x: get_nth_element(x, i)))
#         # psth, _ = compute_firing_rate(ts, population_timestamps, tpre, tpost, binwidth_ms, kernel=None)
#         unit_fr = []
#         for ii in range(len(single_unit_timestamps)):
#             with suppress_print():
#                 unit_fr.append(compute_firing_rate(ts, single_unit_timestamps[ii], tpre, tpost, binwidth_ms, kernel=None)[0])
#         pop_rate = np.mean(unit_fr, axis = 0)

#         n_stims = pop_rate.shape[0]

#         if outcome == 0:
#             txt = 'unrewarded'
#         elif outcome == 1:
#             txt = 'rewarded'
#         elif outcome == -1:
#             txt = 'early withdrawal'
#         elif outcome == 2:
#             txt = 'no choice'

#         plot_psth(compute_mean_sem(pop_rate), tpre, tpost, binwidth_ms,
#                 xlabel = 'time from each stim event (s)',
#                 ylabel = 'population spike rate (Hz)',
#                 data_label = f'{txt} {n_stims} stims',
#                 fig_title = f'{i}-th stim onset',
#                 color=c,
#                 ax = ax,
#                 tight=False)
#     ax.legend()

# del outcome, ts, n_stims

Single neuron interactive viewer

In [None]:
binwidth_ms = 5
tpre = 0.025
tpost = 0.3
figure_dir = Path('/home/gabriel/lib/lab-projects/ephys/figures/')

# Interactive per-unit PSTH viewer aligned to the first stim onset of each trial.
plt.figure(figsize=(4, 4))
individual_psth_viewer(
    event_times=trial_ts.first_stim_ts,
    single_unit_timestamps=single_unit_timestamps,
    pre_seconds=tpre,
    post_seconds=tpost,
    binwidth_ms=binwidth_ms,
    save_dir=figure_dir,
    fig_title='first stim onsets',
)

In [None]:
np.hstack(list(trial_ts.first_stim_ts))[0:5]

In [None]:
# Distribution of intervals between consecutive stim events, pooled over all trials.
# `stim_ts` was not defined in any earlier cell; build it from the per-trial stim times.
stim_ts = np.sort(np.hstack(trial_ts.stim_ts))
isi = np.diff(stim_ts)
isi = isi[np.isfinite(isi)]
bin_edges = np.arange(0, 0.7, 0.02)

fig, ax = plt.subplots()
counts = ax.hist(isi, bin_edges, color='k', alpha=0.5)[0]
# Dotted bin-edge markers span the data instead of a hard-coded height.
ax.vlines(bin_edges, 0, max(counts.max(), 1), color='k', alpha=1, linestyles='dotted')
ax.set_xlabel('inter stim event interval (s)')
ax.set_ylabel('count')
ax.set_title(f'{animal} - {session}')
# Axes-relative placement so the label stays visible whatever the counts are.
ax.text(0.98, 0.95, '0.02 s bins', transform=ax.transAxes, ha='right', va='top')
fig.tight_layout()

In [None]:
# Preview only; displaying the whole trial table bloats the saved notebook.
trial_ts.head()