# Auxiliary notebook for a SUA properties calculation script

## Import

In [5]:
import os 
### Switch the working directory to the project code folder before the imports below run.
### NOTE(review): presumably this is what makes `functions_analysis` importable - confirm.
os.chdir('/CSNG/studekat/ripple_band_project/code')

In [6]:
from functions_analysis import *
import pandas as pd
import numpy as np
import yaml
import pickle
import neo

In [7]:
import warnings
### Silences pandas' SettingWithCopyWarning for every cell executed after this one.
warnings.simplefilter(action='ignore', category=pd.errors.SettingWithCopyWarning)

## Parameters

In [10]:
### All tunable values of the analysis are read from one YAML file.
with open("/CSNG/studekat/ripple_band_project/code/params_analysis.yml") as f:
    params = yaml.safe_load(f)
### AUX = params['aux']

DATA_FOLDER = params['data_folder'] ### folder with all the preprocessed data
DATES = params['dates']
### (low, high) borders of the preferred-phase window that counts as "peak" in aux_add_final_classes.
### The name is used further down in this notebook, so it has to exist; it is None when the YAML file
### has no 'phase_peak_borders' entry.
PEAK_BORDERS = params.get('phase_peak_borders')
WIDTH_INTERVALS = params['width_intervals'] #[(0,7),(8,12),(13,90)]
FINAL_CLASSES = ['DOWN_narrow_peak','DOWN_narrow_other','DOWN_medium','DOWN_wide','UP_peak','UP_other']

DF_FOLDER = '/CSNG/studekat/ripple_band_project/dataframes' ### here the resulting dataframes will be saved
MONKEY_LIST = ['L','N','F','A']

## Data handling sandbox

In [5]:
### Sandbox: load the spike block of monkey N, array 1, for the first 'RS' date listed in DATES.
block = load_block('N',1,type_rec='RS',type_sig='spikes',date=DATES['N']['RS'][0],data_folder=DATA_FOLDER)

In [6]:
### Number of spike trains stored in the first segment of the loaded block.
len(block.segments[0].spiketrains)

18

In [15]:
### Session selected for the start-time check in the next cell.
monkey = 'L'
array = 1
date = '20180806'

In [16]:
### Load the single-unit spikes (SUA) and the ripple-band (RB) signals of the selected session.
spike_block = load_block(
    monkey, array,
    type_rec='OG', type_sig='spikes', date=date, data_folder=DATA_FOLDER,
)
RB_block = load_block(
    monkey, array,
    type_rec='OG', type_sig='RB', date=date, data_folder=DATA_FOLDER,
)
num_cells = len(spike_block.segments[0].spiketrains)

### Start of both recordings as floored integers after scaling t_start by 1000
### (presumably seconds -> milliseconds; verify the units of t_start).
start_t_spikes_ms = int(np.floor(np.float64(spike_block.segments[0].spiketrains[0].t_start.magnitude)*1000))
start_t_RB_ms = int(np.floor(np.float64(RB_block.segments[0].analogsignals[0].t_start.magnitude)*1000))

print(f'Start t RB: {start_t_RB_ms}', f'Start t spikes: {start_t_spikes_ms}', sep='\n')
if not start_t_spikes_ms==start_t_RB_ms:
    print('Spikes and ripples do not have the same start time.')

Start t RB: 3
Start t spikes: 0
Spikes and ripples do not have the same start time.


In [None]:
### Display the loaded RB block (sandbox inspection only).
RB_block

## Functions

In [None]:
def spike_train_prop_vec(spike_vector,rb_phase,rb_envelope,rb_env_phase,channel_prop=None,indicator=None,indicator_name='EC'):
    """
    THIS IS NOT THE LATEST VERSION.

    Calculates properties of one binned spike train with respect to the ripple-band (RB)
    signals that are passed alongside it.

    Parameters
    ----------
    spike_vector : 1D array
        Number of spikes in each time bin (a bin can contain more than one spike).
        The firing rates treat one bin as 1 ms.
    rb_phase, rb_envelope, rb_env_phase : 1D arrays
        RB signals addressed with the same bin indices as spike_vector.
    channel_prop : dict, optional
        Additional entries copied into the result as they are (no suffix added).
    indicator : 1D array, optional
        When given, only the bins with indicator > 0 are analysed.
    indicator_name : str, optional
        Suffix ('_<name>') attached to every computed key when an indicator is given.

    Returns
    -------
    dict with the computed properties, or None (after printing a message) when an
    indicator is passed without a name.
    """
    from copy import deepcopy

    ### the key suffix is resolved first, so nothing is computed for a result that cannot be labelled
    if indicator is None:
        ind_string = ''
    elif indicator_name is not None:
        ind_string = f'_{indicator_name}'
    else:
        print('No indicator name given.')
        return

    if indicator is not None:
        mask = indicator>0
        spike_vector = spike_vector[mask]
        rb_phase = rb_phase[mask]
        rb_envelope = rb_envelope[mask]
        rb_env_phase = rb_env_phase[mask]

    ### lists of RB phases and envs. of spikes - CAREFUL WITH MULTIPLE SPIKES IN ONE MS
    ### every pass collects one spike from each bin that still has spikes left, so a bin
    ### holding k spikes contributes its value k times
    phases_list = []
    env_list = []
    phases_env_list = []
    spikes_left = deepcopy(spike_vector)
    while np.sum(spikes_left)>0:
        for idx in np.where(spikes_left>0)[0]:
            phases_list.append(rb_phase[idx])
            env_list.append(rb_envelope[idx])
            phases_env_list.append(rb_env_phase[idx])
            spikes_left[idx]-=1 ### subtracting spikes that we have already used

    ### phase preference
    r, phi = circular_avg(np.array(phases_list),bins=30)
    r_env, phi_env = circular_avg(np.array(phases_env_list),bins=30)

    ### phases of high and low envelope spikes
    ### the masks have one entry per spike (NOT the shape of the input arrays)
    env_median = np.median(rb_envelope)
    env_of_spikes = np.array(env_list)
    high_env_mask = env_of_spikes>=env_median
    low_env_mask = env_of_spikes<env_median

    list_phases_high_env = np.array(phases_list)[high_env_mask]
    list_phases_low_env = np.array(phases_list)[low_env_mask]

    list_env_phases_high_env = np.array(phases_env_list)[high_env_mask]
    list_env_phases_low_env = np.array(phases_env_list)[low_env_mask]

    ### firing rate
    dur_rec_ms = spike_vector.shape[0]
    dur_rec_s = dur_rec_ms/1000
    if dur_rec_s>0:
        fr = np.sum(spike_vector)/dur_rec_s
        ### we normalise by 2, because each of these only considers spikes on one side of the median env.
        fr_high_env = len(list_phases_high_env)/dur_rec_s*2
        fr_low_env = len(list_phases_low_env)/dur_rec_s*2
    else:
        ### no bins at all (e.g. the indicator excluded everything): rates are undefined
        fr = fr_high_env = fr_low_env = np.nan
    ### a unit without low-envelope spikes has no defined ratio; NaN instead of a ZeroDivisionError
    fr_ratio = fr_high_env/fr_low_env if fr_low_env>0 else np.nan

    ### CV ISI
    len_intervals = count_zero_intervals(spike_vector)
    CV_ISI = np.std(np.array(len_intervals))/np.mean(np.array(len_intervals))

    prop_dict = {f'FR{ind_string}':fr,
                f'CV_ISI{ind_string}': CV_ISI,
                f'ISI{ind_string}': len_intervals,

                f'env_th_median{ind_string}':env_median, ### the median value of RB envelope on this channel

                f'list_phases{ind_string}': phases_list,
                f'list_env{ind_string}':env_list,
                f'list_env_phases{ind_string}':phases_env_list,

                f'list_phases_high_env{ind_string}':list_phases_high_env,
                f'list_env_phases_high_env{ind_string}':list_env_phases_high_env,

                f'list_phases_low_env{ind_string}':list_phases_low_env,
                f'list_env_phases_low_env{ind_string}':list_env_phases_low_env,

                f'FR_high_env_median{ind_string}':fr_high_env,
                f'FR_low_env_median{ind_string}':fr_low_env,
                f'FR_high_env_low_env_median_ratio{ind_string}':fr_ratio,

                f'pref_phase_spikes{ind_string}':phi,
                f'norm_RB_phase_selectivity_spikes{ind_string}':r,
                f'pref_env_phase_spikes{ind_string}': phi_env,
                f'norm_RB_env_phase_selectivity_spikes{ind_string}': r_env,
                }
    # adding other percentile TH values, so the spikes can be split into high/low env. in different ways later
    for perc in [10,20,30,40,50,60,70,80,90,95]:
        prop_dict[f'env_th_perc_{perc}{ind_string}'] = np.percentile(rb_envelope,perc)

    if channel_prop is not None:
        for k in channel_prop.keys():
            prop_dict[k] = channel_prop[k]
    return prop_dict

In [8]:
def aux_add_up_down_classes(df_sua):
    """
    Classifies every waveform as 'UP' or 'DOWN' and stores the label in column 'wf_direction'.

    A waveform is 'UP' when the absolute value of its maximum is strictly larger than the
    absolute value of its minimum, otherwise it is 'DOWN' (a tie counts as 'DOWN').

    NOTE: the labels are computed from column 'avg_wf' exactly as it is stored in the
    dataframe; column 'avg_wf_zscored' is not read here.

    The column is added to df_sua in place and the same dataframe object is returned.
    """
    ### iterating over the column (instead of looking every row up by its label) also
    ### works when index labels repeat and avoids building a full row per waveform
    df_sua['wf_direction'] = [
        'UP' if np.abs(np.max(wf))>np.abs(np.min(wf)) else 'DOWN'
        for wf in df_sua['avg_wf']
    ]
    return df_sua

In [9]:
def aux_add_waveform_prop(df_sua):
    """
    Adds two columns derived from the waveforms in column 'avg_wf':

    'amp_wf'   - amplitude, i.e. maximum minus minimum of the waveform
    'width_wf' - number of samples between the global minimum and the largest value
                 found at or after that minimum

    The columns are written into df_sua itself, which is also the returned object.
    """
    amplitudes = []
    widths = []
    for wf in df_sua['avg_wf'].values:
        amplitudes.append(np.max(wf) - np.min(wf))
        trough_idx = np.argmin(wf)
        ### argmax over the tail that starts at the trough is already the distance from the trough
        widths.append(np.abs(np.argmax(wf[trough_idx:])))
    df_sua['amp_wf'] = amplitudes
    df_sua['width_wf'] = widths
    return df_sua

In [10]:
def aux_add_zscored_avg_waveform(df_sua):
    """
    Adds a z-scored copy of every waveform and the amplitude of that z-scored waveform.

    Reads the `.magnitude` of each entry of column 'avg_wf' and writes
    'avg_wf_zscored' (z-scored values) and 'amp_wf_zscored' (max minus min of the
    z-scored values) into df_sua, which is returned.
    """
    from scipy.stats import zscore

    zscored_wfs = []
    zscored_amps = []
    for wf in df_sua['avg_wf'].values:
        wf_z = zscore(wf.magnitude)
        zscored_wfs.append(wf_z)
        zscored_amps.append(np.max(wf_z) - np.min(wf_z))

    df_sua['avg_wf_zscored'] = zscored_wfs
    df_sua['amp_wf_zscored'] = zscored_amps
    return df_sua

In [11]:
def aux_add_width_classes(df_sua,width_intervals = None):
    """
    Adds column 'width_wf_class' with the width class ('narrow', 'medium' or 'wide') of
    every row, based on the value in column 'width_wf'.

    Parameters
    ----------
    df_sua : pd.DataFrame
        Needs column 'width_wf'. The new column is written into this dataframe.
    width_intervals : sequence of (low, high), optional
        Inclusive borders of the narrow / medium / wide classes, in this order.
        When omitted, the notebook-level WIDTH_INTERVALS is used (looked up at call time).

    Returns
    -------
    The same dataframe. A width that falls into none of the intervals gets NaN; when
    intervals overlap, the first matching interval decides. Every row therefore receives
    exactly one entry.

    Raises
    ------
    ValueError when more intervals than class names are given.
    """
    if width_intervals is None:
        width_intervals = WIDTH_INTERVALS
    names_widths = ['narrow','medium','wide']
    if len(width_intervals)>len(names_widths):
        raise ValueError(f'At most {len(names_widths)} width intervals are supported, got {len(width_intervals)}.')

    def _width_class(width):
        # Name of the first interval containing `width`, NaN when there is none.
        for name, interval in zip(names_widths,width_intervals):
            if (width>=interval[0]) and (width<=interval[1]):
                return name
        return np.nan

    ### adding column with spike width classification into narrow, medium, wide
    df_sua['width_wf_class'] = [_width_class(width) for width in df_sua['width_wf']]
    return df_sua

In [32]:
### classification into 6 classes
def aux_add_final_classes(df_sua,peak_borders=None,final_classes=None):
    """
    Adds column 'final_class', combining waveform direction, width class and preferred phase.

    Parameters
    ----------
    df_sua : pd.DataFrame
        Needs columns 'wf_direction', 'width_wf_class' and (for the *_peak / *_other
        classes) 'pref_phase_all_spikes'. The new column is written into this dataframe.
    peak_borders : (low, high), optional
        A preferred phase with low <= phase <= high counts as "peak", anything else
        (including NaN) as "other". When omitted, the notebook-level PEAK_BORDERS is used.
    final_classes : list of str, optional
        Classes to assign. When omitted, the notebook-level FINAL_CLASSES is used.
        A name without a rule below is reported and skipped.

    Both defaults are looked up when the function is called, not when it is defined.

    Returns
    -------
    The same dataframe; rows matching none of the requested classes keep 'NO_CLASS'.
    """
    if peak_borders is None:
        peak_borders = PEAK_BORDERS
    if final_classes is None:
        final_classes = FINAL_CLASSES
    if peak_borders is None:
        raise ValueError('peak_borders is not set; pass a (low, high) pair.')

    ### class name -> (waveform direction, width class or None, inside the peak window or None);
    ### None means that the criterion is not used for the class
    class_rules = {
        'DOWN_narrow_peak': ('DOWN','narrow',True),
        'DOWN_narrow_other': ('DOWN','narrow',False),
        'DOWN_medium': ('DOWN','medium',None),
        'DOWN_wide': ('DOWN','wide',None),
        'UP_peak': ('UP',None,True),
        'UP_other': ('UP',None,False),
    }

    df_sua['final_class'] = 'NO_CLASS'
    for cl in final_classes:
        rule = class_rules.get(cl)
        if rule is None:
            print('Undefined cell type.')
            continue
        direction, width_class, in_peak = rule
        mask = df_sua['wf_direction']==direction
        if width_class is not None:
            mask = mask & (df_sua['width_wf_class']==width_class)
        if in_peak is not None:
            pref_phase = df_sua['pref_phase_all_spikes']
            mask_peak = (pref_phase>=peak_borders[0]) & (pref_phase<=peak_borders[1])
            mask = mask & (mask_peak if in_peak else ~mask_peak)
        df_sua.loc[mask,'final_class'] = cl
    return df_sua

## Dataframe calculation - the first part, computationally expensive

In [13]:
### The extraction below is slow, therefore it is switched off by default. Set the flag to True
### to rebuild the per-date dataframes stored in DF_FOLDER/sua_prop/.
RECOMPUTE_SUA_PROP = False

if RECOMPUTE_SUA_PROP:
    for monkey in MONKEY_LIST:
        print(monkey)
        ### the spike train annotation holding the area is named differently for N and F
        if monkey in ['N','F']:
            name_area = 'Area'
        else:
            name_area = 'cortical_area'
        for date in params['dates'][monkey]['RS']:
            print(date)
            prop_list = []
            for array in range(1,17):
                print(array)
                try:
                    ### loading SUA spike trains and RB block
                    try:
                        spike_block = load_block(monkey,array,type_rec='RS',type_sig='spikes',date=date,data_folder=DATA_FOLDER) ### SUA
                        RB_block = load_block(monkey,array,type_rec='RS',type_sig='RB',date=date,data_folder=DATA_FOLDER)
                        num_cells = len(spike_block.segments[0].spiketrains)
                        start_t_spikes_ms = int(np.floor(np.float64(spike_block.segments[0].spiketrains[0].t_start.magnitude)*1000))
                        start_t_RB_ms = int(np.floor(np.float64(RB_block.segments[0].analogsignals[0].t_start.magnitude)*1000))
                        print(f'Start t RB: {start_t_RB_ms}')
                        print(f'Start t spikes: {start_t_spikes_ms}')
                        if start_t_spikes_ms!=start_t_RB_ms:
                            print('Spikes and ripples do not have the same start time.')
                    except Exception:
                        print(f'Cannot read the spike file for date {date}, monkey {monkey}, array {array}.')
                        ### re-raised so that the array is skipped; going on would silently reuse
                        ### the blocks loaded for the previous array
                        raise
                    try:
                        df_OP = pd.read_csv(f'{DATA_FOLDER}/metadata/OP_maps_dataframes/{monkey}/OP_prop_OG_array{array}.csv')
                    except Exception:
                        print(f'Cannot read OP maps for date {date}, monkey {monkey}, array {array}.')
                        ### without a map the OP properties below become NaN (instead of being
                        ### taken from the map of a previous array)
                        df_OP = None

                    ### the signal matrices depend on the array only, so they are built once per
                    ### array and not once per cell
                    ### NOTE(review): assumes the conversion helpers return the same result on
                    ### every call for the same block - confirm in functions_analysis
                    rb_phase_arr = sig_block_to_arr(RB_block,'RB_phase')
                    rb_envelope_arr = sig_block_to_arr(RB_block,'RB_envelope_norm')
                    rb_env_phase_arr = sig_block_to_arr(RB_block,'RB_envelope_phase')

                    spike_arr = spike_block_to_arr(spike_block)

                    ### cutting out common times only for N and F
                    if monkey in ['N','F']:
                        rb_phase_arr = cut_abs_times(rb_phase_arr,start_t_RB_ms,monkey,rec_type='RS',date=date,params=params)
                        rb_envelope_arr = cut_abs_times(rb_envelope_arr,start_t_RB_ms,monkey,rec_type='RS',date=date,params=params)
                        rb_env_phase_arr = cut_abs_times(rb_env_phase_arr,start_t_RB_ms,monkey,rec_type='RS',date=date,params=params)
                        spike_arr = cut_abs_times(spike_arr,start_t_spikes_ms,monkey,rec_type='RS',date=date,params=params)

                    for cell in range(num_cells):
                        spike_train = spike_block.segments[0].spiketrains[cell]
                        cell_name = spike_train.annotations['nix_name']
                        electrode_ID = spike_train.annotations['Electrode_ID']

                        ### channel prop - additional info for a channel, such as OP, bad channel ID, array and area
                        channel_prop = {}
                        channel_prop['cell_name'] = cell_name
                        ### OP - the preferred OP is kept only for channels passing both thresholds
                        try:
                            ch_OP = df_OP[df_OP['Electrode_ID']==electrode_ID]
                            if ch_OP['selectivity_01'].values[0]>0.2 and ch_OP['num_f0_high_jump'].values[0]<3:
                                channel_prop['pref_OP'] = ch_OP['pref_OP'].values[0]
                            else:
                                channel_prop['pref_OP'] = np.nan
                            channel_prop['selectivity_OP_01'] = ch_OP['selectivity_01'].values[0]
                        except Exception:
                            ### no OP map, or no row for this electrode
                            channel_prop['pref_OP'] = np.nan
                            channel_prop['selectivity_OP_01'] = np.nan
                        ### channel order
                        ch = aux_electrodeID_to_ch_order(monkey,date,electrode_ID,array,data_folder=DATA_FOLDER,type_rec='RS')
                        channel_prop['channel_order'] = ch
                        ### array
                        channel_prop['array'] = array
                        ### area
                        channel_prop['area'] = spike_train.annotations[name_area]
                        ### order in the spike train
                        channel_prop['train_order'] = cell

                        rb_phase = rb_phase_arr[ch,:]
                        rb_envelope = rb_envelope_arr[ch,:]
                        rb_env_phase = rb_env_phase_arr[ch,:]
                        spike_vector = spike_arr[cell,:]

                        prop_dict = spike_train_prop_vec(spike_vector,rb_phase,rb_envelope,rb_env_phase,channel_prop=channel_prop) ### input already binned spikes
                        prop_list.append(prop_dict)
                except Exception as err:
                    ### cells finished before the failure stay in prop_list; the reason is printed
                    ### so that a broken array can be told apart from a missing one
                    print(f'For array {array}, the SUA properties were not calculated.')
                    print(f'Reason: {err!r}')
            df_prop = pd.DataFrame(prop_list)
            ensure_dir_exists(f'{DF_FOLDER}/sua_prop/')
            df_prop.to_pickle(f'{DF_FOLDER}/sua_prop/monkey{monkey}_all_arrays_date_{date}.pkl')


## Adding other properties and formatting to the DF (computationally easier)

In [38]:
### For every monkey and RS date: load the dataframe of the first part, add the waveform based
### columns and the final classes, and store the result in DF_FOLDER/sua_prop_all/.
### PEAK_BORDERS has to be defined in the Parameters section before this cell is run.
for monkey in MONKEY_LIST:
    print(monkey)
    all_RS_dates = params['dates'][monkey]['RS']
    for date in all_RS_dates:
        print(date)
        ### NOTE: unpickling can execute arbitrary code - only load files written by this project
        with open(f'{DF_FOLDER}/sua_prop/monkey{monkey}_all_arrays_date_{date}.pkl', "rb") as file:
            df_sua = pickle.load(file)
        df_added = aux_add_waveform_prop(df_sua)
        df_added = aux_add_zscored_avg_waveform(df_added)
        ### erasing not working arrays; the explicit copy gives the helpers below an independent
        ### dataframe to add columns to, so pandas has no reason for a SettingWithCopyWarning
        df_added = df_added[df_added['channel_order']>-1].copy()
        df_added = aux_add_width_classes(df_added,width_intervals=WIDTH_INTERVALS)
        df_added = aux_add_up_down_classes(df_added)
        df_added = aux_add_final_classes(df_added,peak_borders=PEAK_BORDERS,final_classes=FINAL_CLASSES)

        #### saving new dataframes with properties as pickle
        ensure_dir_exists(f'{DF_FOLDER}/sua_prop_all/')
        df_added.to_pickle(f'{DF_FOLDER}/sua_prop_all/monkey{monkey}_all_arrays_date_{date}.pkl')

L
20170725
20170809
20170810
N
20240719_B1
20240719_B2
F
20240122_B1
20241216_B1
A
20190815
20190816


In [34]:
### Final classes of the dataframe processed last in the loop above.
df_added['final_class']

0              DOWN_wide
1       DOWN_narrow_peak
2      DOWN_narrow_other
3              DOWN_wide
4       DOWN_narrow_peak
             ...        
499          DOWN_medium
500    DOWN_narrow_other
501          DOWN_medium
502          DOWN_medium
503          DOWN_medium
Name: final_class, Length: 504, dtype: object

In [35]:
### Column names of the dataframe processed last in the loop above.
df_added.keys()

Index(['FR', 'CV_ISI', 'ISI', 'env_th_median', 'list_phases', 'list_env',
       'list_env_phases', 'list_phases_high_env', 'list_env_phases_high_env',
       'list_phases_low_env', 'list_env_phases_low_env', 'FR_high_env_median',
       'FR_low_env_median', 'FR_high_env_low_env_median_ratio',
       'pref_phase_all_spikes', 'norm_phase_sel_01_all_spikes',
       'pref_env_phase_all_spikes', 'norm_env_phase_sel_01_all_spikes',
       'env_th_perc_10', 'env_th_perc_20', 'env_th_perc_30', 'env_th_perc_40',
       'env_th_perc_50', 'env_th_perc_60', 'env_th_perc_70', 'env_th_perc_80',
       'env_th_perc_90', 'env_th_perc_95', 'cell_name', 'pref_OP',
       'selectivity_OP_01', 'channel_order', 'array', 'area', 'train_order',
       'avg_wf', 'amp_wf', 'width_wf', 'avg_wf_zscored', 'amp_wf_zscored',
       'width_wf_class', 'wf_direction', 'final_class'],
      dtype='object')

In [36]:
### First rows of the dataframe processed last in the loop above.
df_added.head()

Unnamed: 0,FR,CV_ISI,ISI,env_th_median,list_phases,list_env,list_env_phases,list_phases_high_env,list_env_phases_high_env,list_phases_low_env,...,area,train_order,avg_wf,amp_wf,width_wf,avg_wf_zscored,amp_wf_zscored,width_wf_class,wf_direction,final_class
0,7.107476,1.731616,"[37, 29, 27, 33, 19, 13, 48, 34, 264, 6, 22, 1...",0.898995,"[0.8949215, -1.6516726, 1.9185883, 0.4335102, ...","[3.580378, 0.61513877, 3.6505191, 3.8867853, 1...","[-0.3918044, -2.3370738, 0.32232544, 0.4845442...","[0.8949215, 1.9185883, 0.4335102, -2.2834458, ...","[-0.3918044, 0.32232544, 0.48454422, 1.9539795...","[-1.6516726, 2.6324809, -2.312513, 2.8026843, ...",...,V1,0,"[-0.07893936 uV, 0.009604962 uV, 0.104669385 u...",12.182018 uV,13,"[-0.050143387, -0.014451916, 0.02386775, 0.065...",4.910469,wide,DOWN,DOWN_wide
1,7.566047,1.56521,"[120, 27, 173, 44, 90, 40, 84, 20, 40, 86, 17,...",1.314136,"[-0.52945757, -2.989535, 1.7835573, 2.7233076,...","[5.884698, 4.13634, 0.63234925, 1.6020149, 1.8...","[-0.46864057, 0.5003421, 2.345138, 1.2908846, ...","[-0.52945757, -2.989535, 2.7233076, -1.6106553...","[-0.46864057, 0.5003421, 1.2908846, -1.0384673...","[1.7835573, -1.2493649, 0.7245242, 2.5869122, ...",...,V1,1,"[0.044280175 uV, 0.04797673 uV, 0.054886095 uV...",5.9538403 uV,4,"[0.058323715, 0.06307975, 0.07196944, 0.077603...",7.660294,narrow,DOWN,DOWN_narrow_peak
2,4.527008,1.010187,"[27, 1, 99, 43, 145, 118, 90, 33, 10, 57, 243,...",1.088546,"[1.5148422, 2.7526646, 0.07995539, 0.6984406, ...","[3.41308, 3.7781377, 5.009789, 1.7155644, 1.41...","[-1.0516787, -0.9338419, 0.3463526, 1.1438483,...","[1.5148422, 2.7526646, 0.07995539, 0.6984406, ...","[-1.0516787, -0.9338419, 0.3463526, 1.1438483,...","[-0.18523695, -1.6737292, -0.3542702, 0.176791...",...,V1,2,"[-0.00018575166 uV, 0.016559677 uV, 0.04155632...",5.0304017 uV,5,"[-0.0076714493, 0.0167015, 0.05308408, 0.08355...",7.321743,narrow,DOWN,DOWN_narrow_other
3,7.880076,1.005134,"[2, 60, 168, 115, 12, 301, 206, 50, 183, 22, 1...",1.088546,"[1.2034215, -2.660083, -1.9840807, -0.50284547...","[2.0073838, 1.7620276, 0.60708755, 0.54813105,...","[-1.3341041, 0.9382869, 1.8672955, -3.1077988,...","[1.2034215, -2.660083, 2.875721, 0.20242125, 0...","[-1.3341041, 0.9382869, -1.4294983, -0.6522088...","[-1.9840807, -0.50284547, 2.3088422, -1.892213...",...,V1,3,"[-0.13725783 uV, -0.14345522 uV, -0.1325626 uV...",4.0109205 uV,20,"[-0.20722252, -0.2161767, -0.2004387, -0.15798...",5.795106,wide,DOWN,DOWN_wide
4,5.458824,1.158895,"[9, 63, 9, 268, 151, 188, 202, 141, 103, 170, ...",1.041889,"[-0.59491074, 0.16845854, 2.704239, 0.37656453...","[2.0509684, 1.8092152, 0.5843474, 1.5266342, 1...","[-1.17618, -1.2795454, -2.8362763, -1.4785236,...","[-0.59491074, 0.16845854, 0.37656453, -2.05416...","[-1.17618, -1.2795454, -1.4785236, -2.0827224,...","[2.704239, 3.1325738, -0.1967004, -1.6087348, ...",...,V1,4,"[0.025834981 uV, 0.027140008 uV, 0.025227027 u...",5.086688 uV,5,"[0.031549763, 0.033435132, 0.030671455, 0.0297...",7.348716,narrow,DOWN,DOWN_narrow_peak
