This script analyzes optostim neuropixel data and looks for significant responses during stimulation compared to baseline.

1. Loads relevant data
    - Preprocessed data from server
    - ECoG latency by stim site
2. Calculates significant responses
    - two-tailed t-test between FR during stim vs. before
3. Plots PSTH and rasters for significantly modulated units
4. Finds and plots waveforms of responsive units
5. Plots the LFP response.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import aopy
import os
import pandas as pd
from db import dbfunctions as db
from ipywidgets import interactive, widgets
import scipy
import h5py
from tqdm.auto import tqdm 
import seaborn as sn
import sklearn
from sklearn.decomposition import PCA, FactorAnalysis
from itertools import compress
import multiprocessing as mp
import time
import math
from scipy.fft import fft
import glob
from datetime import date

In [2]:
# Paths
ecog_signal_path = "/home/aolab/gdrive/Lab equipment/electrophysiology/210910_ecog_signal_path.xlsx"
elec_to_pos = "/home/aolab/gdrive/Lab equipment/electrophysiology/our signal path definitions/244ch_viventi_ecog_elec_to_pos.xlsx"
data_path_preproc = '/media/moor-data/preprocessed.new/'
data_path_raw = '/media/moor-data/raw/'
# save_dir = "/home/aolab/gdrive/People/RyanCanfield/Results/beignet_analysis/optostim_population_dynamics"

# General data parameters
subject = 'beignet'
task_coords = 'yzx'
task_perturb = None
task_rotation = 0

# Neuropixel data parameters
implant_name = ['NP_Insert72', 'NP_Insert137']
start_date = '2023-07-13'
end_date = date.today()
elec_config = 'bottom'
spike_bin_width_mc = 0.005 #[s]

# Alignment parameters
tbefore = 0.25
tafter = 0.25
min_trial_prop = 0.75
min_fr = 2

# Visualization parameters
colors = sn.color_palette(n_colors=9)
recording_brain_areas={'M1': [30, 55, 47, 40], 'PM':[11, 9, 18]}
day_colors=['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8']

In [3]:
def calc_autocorr_func(data, lags=None):
    '''
    Args:
        data (ntime, ...):
        lags (nlag): Indicies of lags to analyze. If None, compute autocorrelation at each possible lag.
    '''
    
    ntime = data.shape[0]
    if lags is None:
        lags = np.arange(-ntime+1, ntime-1)
    nlags = len(lags)
    
    autocorr_func = np.zeros(nlags)*np.nan
    for ilag, lag in enumerate(lags):
        # print(data.shape, np.roll(data,lag,axis=0).shape)
        # autocorr_func.append(np.apply_along_axis(np.correlate, 0, data, np.roll(data,lag,axis=0)))
        # print(data.shape, np.roll(data,lag,axis=0).shape)
        autocorr_func[ilag] = np.correlate(data, np.roll(data,lag,axis=0))
    return autocorr_func, lags

def load_parsed_ksdata(kilosort_dir, data_dir):
    '''
    load kilosort data (spike indices, clusters, and label) parsed into the task entries
    This data is not still synchronized
    
    Args:
        kilosort_dir (str): kilosort directory (ex. '/data/preprocessed/kilosort')
        data_dir (str): data directory that contains parsed data (ex. '2023-06-30_Neuropixel_ks_affi_bottom_port1_9847')
        
    Returns:
        spike_indices (nspikes): spike indices detected by kilosort (not spike times)
        spike_clusters (nspikes): unit label detected  by kilsort
    '''
    
    # Path for loading spikes and clusters
    data_path = os.path.join(kilosort_dir, data_dir)
    spike_path = os.path.join(data_path,'spike_indices_entry.npy')
    cluster_path = os.path.join(data_path,'spike_clusters_entry.npy')
    label_path = os.path.join(data_path,'ks_label.npy')

    # Load spikes and clusters
    spike_indices = np.load(spike_path)
    spike_clusters = np.load(cluster_path)
    ks_label = np.load(label_path)
    
    return spike_indices, spike_clusters, ks_label

def classify_ks_unit(spike_times, spike_label):
    '''
    Classify unit activity identified by kilosort into each single unit
    
    Args:
        spike_times (nspikes): spike times generated by kilosort
        spike_label (nspikes): cluster labels of each spike generated by kilosort
        
    Returns:
        spike_times_unit (dict): spike times for each unit
    '''
    
    spike_times_unit = {}
    
    for unit_label in np.unique(spike_label):
        spike_times_unit[f'{unit_label}'] = spike_times[spike_label==unit_label.astype(int)]
    
    return spike_times_unit

def calc_ks_waveforms(raw_data, sample_rate, spike_times_unit, templates, channel_pos, waveforms_nch=10, time_before=1000., time_after=1000.):
    '''
    Calculate waveforms, waveform channels, and positions of units, using templates from kilosort
    
    args:
        raw_data (nt,nch): time series neural data to detect spikes and extract waveforms from.
        sample_rate (float): sampling rate (Hz)
        spike_times_unit (dict): spike times for each unit identified by kilosort
        templates (n_unit, n_points, nch): templates that kilosort used to detect spikes
        channel_pos (nch, 2): channel positions
        waveforms_nch (int, optional): the number of channels with large amplitude of templates
        time_before (float, optional): time [us] to include before the start of each trial
        time_after (float, optional): time [us] to include after the start of each trial
    
    returns
        tuple: tuple containing:
            | **unit_waveforms (dict):** waveforms for each unit. The shape is (nspikes,  m_points, waveforms_nch)
            | **unit_waveforms_ch (n_unit, waveforms_nch):** large amplitude channels in templates
            | **unit_pos (dict):** channel positions of each unit
    '''
    
    time_before *= 1e-6
    time_after *= 1e-6
    nch = channel_pos.shape[0]
    duration = int((time_before + time_after)*sample_rate)
    
    unit_waveforms_ch = {}
    unit_waveforms = {}
    unit_pos = {}

    for iunit, unit in enumerate(spike_times_unit.keys()):
        # Look at high amplitude channels in templates
        amp_template_ch = np.zeros(nch)
        for ich in range(nch):
            amp_template_ch[ich] = np.max(templates[int(unit),:,ich])-np.min(templates[int(unit),:,ich]) # don't use iunit instead of int(unit)

        # Sort high amplitude channels and save channels and their positions
        large_amp_ch = np.argsort(amp_template_ch)[::-1][:waveforms_nch]
        unit_waveforms_ch[f'{unit}'] = large_amp_ch
        unit_pos[f'{unit}'] = channel_pos[large_amp_ch[0],:]

        # Get waveforms in high amplitude channels for each spike
        unit_times = spike_times_unit[f'{unit}']
        waveforms = np.zeros((unit_times.shape[0],duration,waveforms_nch))*np.nan
        for ispike, unit_time in enumerate(unit_times):
            start = int((unit_time - time_before)*sample_rate)
            end = start + duration

            if np.logical_and(end < raw_data.shape[0], start >= 0): # Ensure there are enough data points to grab the waveform
                for ich, ch in enumerate(large_amp_ch):
                    waveforms[ispike,:,ich] = raw_data[start:end,ch]
                    
        unit_waveforms[f'{unit}'] = waveforms
    
    return unit_waveforms, unit_waveforms_ch, unit_pos

In [4]:
# Load neuropixel center-out task data
opto_entries =  db.get_task_entries(subject__name=subject, task__name='laser only', date=(start_date, end_date))
opto_entries = [me for me in opto_entries if 'neuropixel_port1_drive_type' in me.task_params and me.task_params['neuropixel_port1_drive_type'] in implant_name
                and me.entry_name != 'flash']
print(opto_entries)
print(len(opto_entries))

[2023-08-28 13:49:56.350383: beignet on laser only task, id=10798, 2023-08-28 13:51:04.406791: beignet on laser only task, id=10799, 2023-08-28 14:56:47.537136: beignet on laser only task, id=10803, 2023-08-28 15:02:15.970548: beignet on laser only task, id=10804, 2023-08-28 15:09:03.965473: beignet on laser only task, id=10805, 2023-08-28 15:14:22.097289: beignet on laser only task, id=10806, 2023-08-29 14:56:56.049109: beignet on laser only task, id=10810, 2023-08-29 16:25:27.846484: beignet on laser only task, id=10813, 2023-08-29 16:43:03.703551: beignet on laser only task, id=10814, 2023-08-30 11:35:11.143187: beignet on laser only task, id=10818, 2023-08-30 13:08:16.278506: beignet on laser only task, id=10821, 2023-08-30 13:26:00.981570: beignet on laser only task, id=10822, 2023-08-30 13:44:10.289583: beignet on laser only task, id=10823, 2023-08-31 12:55:53.587739: beignet on laser only task, id=10826, 2023-08-31 12:57:02.026928: beignet on laser only task, id=10827, 2023-08-3

In [5]:
time_axis = np.linspace(-tbefore, tafter, int((tafter+tbefore)/spike_bin_width_mc))
ntime = len(time_axis)
# ntime = int((tafter + tbefore)//spike_bin_width_mc)+1
trial_time_axis = np.arange(-tbefore, tafter, spike_bin_width_mc)

In [11]:
start = time.time()
spike_times = []
unit_labels = []
laser_times = []
laser_widths = []
spike_segs = []
spike_align = []
spike_labels = []
recording_site = []
stim_site = []
good_opto_entries = []
autocorr_func = []
for ioe, oe in enumerate(tqdm(opto_entries)):

    ########################################################
    ### Load spike times across days
    ########################################################
    session_number = 0
 
    # Load data
    exp_data, exp_metadata = aopy.data.load_preproc_exp_data(data_path_preproc, subject, oe.id, oe.date.date())
    filename_opto = aopy.data.get_preprocessed_filename(subject, oe.id, oe.date.date(), 'ap')
    try:
        ap_data = aopy.data.load_hdf_group(os.path.join(data_path_preproc, subject), filename_opto, 'ap')
        ap_metadata = aopy.data.load_hdf_group(os.path.join(data_path_preproc, subject), filename_opto, 'metadata')
        laser_info = aopy.preproc.bmi3d.get_laser_trial_times(data_path_preproc, subject, oe.id, oe.date.date())
    except:
        continue

    ########################################################
    unit_labels.append(ap_data['unique_label'])
    spike_times.append(ap_data['unit']) # Assumes spike labels are consistent across recording sessions (works if recorded on the same day, but otherwise it does not)
    laser_times.append(laser_info[0])
    laser_widths.append(laser_info[1])
    stim_site.append(exp_metadata['stimulation_site'])
    recording_site.append(exp_metadata['neuropixel_port1_site'])
    
    ########################################################
    
    spike_segs_day = {}
    spike_align_day = np.zeros((ntime, len(laser_info[0]), len(ap_data['unique_label']))) # (ntime, ntrials, nunits)
    autocorr_func_day = []
    for iunit, unitid in enumerate(ap_data['unique_label']):
        binned_spikes, time_bins = aopy.precondition.bin_spike_times(ap_data['unit'][str(unitid)], 0, laser_info[0][-1]+10, spike_bin_width_mc)
        autocorr_func_day.append(calc_autocorr_func(ap_data['unit'][str(unitid)], lags=None)[0])
        
        # Trial align 
        spike_align_day[:,:,iunit] = aopy.preproc.trial_align_data(binned_spikes, laser_info[0], tbefore, tafter, 1/spike_bin_width_mc)[:,0,:]


    spike_segs.append(spike_segs_day)
    spike_align.append(spike_align_day)
    good_opto_entries.append(oe)
    autocorr_func.append(autocorr_func_day)
    # print(100*ioe/len(opto_entries), '%')

  0%|          | 0/19 [00:00<?, ?it/s]

  autocorr_func[ilag] = np.correlate(data, np.roll(data,lag,axis=0))


In [None]:
# Identify which units are stable by being active on a high proportion of trials and above a low minimum FR

stable_unit_lbl = []
stable_unit_mask = []
stable_unit_idx = []
target_idx_good = []
target_idx_ordered_mask = []
# Get stable unit idx and labels
for ie, entry in enumerate(good_opto_entries):
    min_trials = len(laser_times[ie])*min_trial_prop
    stable_unit_lbl.append(unit_labels[ie][np.where(np.logical_and(np.sum(np.max(spike_align[ie]>0, axis=0), axis=0)>min_trials, np.mean(spike_align[ie], axis=(0,1))>min_fr))[0]])
    stable_unit_mask.append(np.logical_and(np.sum(np.max(spike_align[ie]>0, axis=0), axis=0)>min_trials, np.mean(spike_align[ie], axis=(0,1))>min_fr))
    stable_unit_idx.append(np.where(stable_unit_mask[ie])[0])
nstable_unit_idx = [len(stable_unit_idx[id]) for id in range(len(good_opto_entries))]

In [None]:
offset = 2
# Calculate significance of response
laser_start_idx = [(np.zeros(len(laser_times[ioe]))+np.where(trial_time_axis>=0)[0][0]).astype(int)+offset for ioe in range(len(good_opto_entries))]
laser_stop_idx = [laser_start_idx[ioe]+np.round(laser_widths[ioe]/spike_bin_width_mc).astype(int)-(2*offset) for ioe in range(len(good_opto_entries))]
baseline_start_idx = [(np.zeros(len(laser_times[ioe]))+np.where(trial_time_axis>=0)[0][0]).astype(int)-(laser_stop_idx[ioe]-laser_start_idx[ioe])-offset for ioe in range(len(good_opto_entries))]
baseline_stop_idx = [(np.zeros(len(laser_times[ioe]))+np.where(trial_time_axis>=0)[0][0]).astype(int)-offset for ioe in range(len(good_opto_entries))]

unit_resp_sig = []
for ioe in range(len(good_opto_entries)):
    print(ioe)
    temp_unit_sig = np.zeros(nstable_unit_idx[ioe])*np.nan
    for iunit, unit_lbl in enumerate(stable_unit_idx[ioe]):
        null_points = []
        alt_points = []
        [null_points.extend(spike_align[ioe][baseline_start_idx[ioe][itrial]:baseline_stop_idx[ioe][itrial],itrial,unit_lbl]) for itrial in range(len(laser_times[ioe]))]
        [alt_points.extend(spike_align[ioe][laser_start_idx[ioe][itrial]:laser_stop_idx[ioe][itrial],itrial,unit_lbl]) for itrial in range(len(laser_times[ioe]))]
        _, temp_unit_sig[iunit] = scipy.stats.ttest_ind(np.array(null_points), np.array(alt_points))
    
    unit_resp_sig.append(temp_unit_sig)

In [None]:
# Control for multiple comparisons using false discovery rate
from statsmodels.stats.multitest import fdrcorrection
resp_unit_idx = [stable_unit_idx[ioe][fdrcorrection(unit_resp_sig[ioe])[0]] for ioe in range(len(good_opto_entries))]
nresp_unit = [len(resp_unit_idx[ioe]) for ioe in range(len(good_opto_entries))]

In [None]:
for ioe in range(len(good_opto_entries)):
# Plot rasters aligned to go cue organized by trial
    ncol = 7
    nrow = (len(resp_unit_idx[ioe])//ncol)+1
    fig, ax = plt.subplots(nrow, ncol, figsize=(27, nrow*3))
    for iplot, iunit in enumerate(resp_unit_idx[ioe]):
        irow = iplot//ncol
        icol = iplot % ncol
        
        try:
            # unit2plot = np.where(unit_labels[idate]==[iunit])[0][0]
            ax[irow, icol].plot(time_axis, np.mean(spike_align[ioe][:,:,iunit], axis=1))
            ax[irow, icol].plot([time_axis[int(np.median(laser_start_idx[ioe]))], time_axis[int(np.median(laser_start_idx[ioe]))]], [0,np.max(np.mean(spike_align[ioe][:,:,iunit], axis=1))], 'r--', linewidth=2)
            ax[irow, icol].plot([time_axis[int(np.median(laser_stop_idx[ioe]))], time_axis[int(np.median(laser_stop_idx[ioe]))]], [0,np.max(np.mean(spike_align[ioe][:,:,iunit], axis=1))], 'r--', linewidth=2)
            ax[irow, icol].plot([time_axis[int(np.median(baseline_start_idx[ioe]))], time_axis[int(np.median(baseline_start_idx[ioe]))]], [0,np.max(np.mean(spike_align[ioe][:,:,iunit], axis=1))], 'g--', linewidth=2)
            ax[irow, icol].plot([time_axis[int(np.median(baseline_stop_idx[ioe]))], time_axis[int(np.median(baseline_stop_idx[ioe]))]], [0,np.max(np.mean(spike_align[ioe][:,:,iunit], axis=1))], 'g--', linewidth=2)
            # ax[irow, icol].plot([0,0], [0,np.max(np.mean(spike_align[ioe][:,:,iunit], axis=1))], 'r--', linewidth=2)
            # ax[irow, icol].plot([np.median(laser_widths[ioe]),np.median(laser_widths[ioe])], [0,np.max(np.mean(spike_align[ioe][:,:,iunit], axis=1))], 'r--', linewidth=2)
            ax[irow, icol].set(xlabel='Time [s]', ylabel='Firing Rate', title=f'Rec {recording_site[ioe]} - Unit {iunit}')
        except:
            ax[icol].plot(time_axis, np.mean(spike_align[ioe][:,:,iunit], axis=1))
            ax[icol].plot([time_axis[int(np.median(laser_start_idx[ioe]))], time_axis[int(np.median(laser_start_idx[ioe]))]], [0,np.max(np.mean(spike_align[ioe][:,:,iunit], axis=1))], 'r--', linewidth=2)
            ax[icol].plot([time_axis[int(np.median(laser_stop_idx[ioe]))], time_axis[int(np.median(laser_stop_idx[ioe]))]], [0,np.max(np.mean(spike_align[ioe][:,:,iunit], axis=1))], 'r--', linewidth=2)
            ax[icol].plot([time_axis[int(np.median(baseline_start_idx[ioe]))], time_axis[int(np.median(baseline_start_idx[ioe]))]], [0,np.max(np.mean(spike_align[ioe][:,:,iunit], axis=1))], 'g--', linewidth=2)
            ax[icol].plot([time_axis[int(np.median(baseline_stop_idx[ioe]))], time_axis[int(np.median(baseline_stop_idx[ioe]))]], [0,np.max(np.mean(spike_align[ioe][:,:,iunit], axis=1))], 'g--', linewidth=2)
            # ax[icol].plot([0,0], [0,np.max(np.mean(spike_align[ioe][:,:,iunit], axis=1))], 'r--', linewidth=2)
            # ax[icol].plot([np.median(laser_widths[ioe]),np.median(laser_widths[ioe])], [0,np.max(np.mean(spike_align[ioe][:,:,iunit], axis=1))], 'r--', linewidth=2)
            ax[icol].set(xlabel='Time [s]', ylabel='Firing Rate', title=f'Rec {recording_site[ioe]} - Unit {iunit}')
            
    plt.suptitle(f'Recording: {ioe} - Site: {recording_site[ioe]}')
    fig.tight_layout()
    plt.show()