# Auxiliary notebook for a SUA properties calculation script

## Import

In [1]:
import os 
os.chdir('/CSNG/studekat/ripple_band_project/code')

In [2]:
from functions_analysis import *
import pandas as pd
import numpy as np
import yaml
import pickle
import neo

In [3]:
import warnings
warnings.simplefilter(action='ignore', category=pd.errors.SettingWithCopyWarning)

## Parameters

In [35]:
### Load the analysis parameters from the project's YAML file; the constants below are read from it
### (or hard-coded here) and used as module-level globals by the cells that follow.
with open("/CSNG/studekat/ripple_band_project/code/params_analysis.yml") as f:
    params = yaml.safe_load(f)
### AUX = params['aux']

DATA_FOLDER = params['data_folder'] ### folder with all the preprocessed data
DATES = params['dates'] ### indexed in this notebook as DATES[monkey]['RS'] -> list of dates
PEAK_BORDERS = params['phase_peak_borders'] ### [lower, upper] bounds on the preferred phase, used for the *_peak classes
WIDTH_INTERVALS = params['width_intervals'] ### inclusive [min, max] width intervals for narrow/medium/wide, e.g. [(0,7),(8,12),(13,90)]
FINAL_CLASSES = ['DOWN_narrow_peak','DOWN_narrow_other','DOWN_medium','DOWN_wide','UP_peak','UP_other'] ### labels assigned by aux_add_final_classes

DF_FOLDER = '/CSNG/studekat/ripple_band_project/dataframes' ### here the resulting dataframes will be saved
MONKEY_LIST = ['L','N','F','A'] ### monkey identifiers iterated over by the dataframe calculation

## Data handling sandbox

In [5]:
### example: load the spike block of monkey 'N', array 1, for the first 'RS' date listed in the parameters
block = load_block('N',1,type_rec='RS',type_sig='spikes',date=DATES['N']['RS'][0],data_folder=DATA_FOLDER)

In [6]:
### number of spike trains in the first segment of the block loaded above
len(block.segments[0].spiketrains)

18

## Functions

In [7]:
def spike_train_prop_vec(spike_vector,rb_phase,rb_envelope,rb_env_phase,channel_prop=None):
    """
    Calculates properties of one binned spike train relative to the ripple-band (RB) signal.

    Parameters
    ----------
    spike_vector : 1D array with the number of spikes in each time bin. The firing rates
        below divide the number of bins by 1000, i.e. the bins are treated as 1 ms wide.
    rb_phase, rb_envelope, rb_env_phase : sequences indexed by the same bins as
        spike_vector (RB phase, RB envelope and phase of the RB envelope).
    channel_prop : optional dict of additional per-channel entries; every key is copied
        into the returned dictionary (overwriting a computed entry of the same name).

    Returns
    -------
    dict with firing rates, ISI statistics, the per-spike phase/envelope lists, their
    split into spikes above/below the median envelope, the circular-average results and
    the envelope percentiles.

    'FR_high_env_low_env_median_ratio' is np.inf when there are high-envelope spikes but
    no low-envelope spikes, and np.nan when there are neither (instead of raising
    ZeroDivisionError, which made the caller drop the whole array).
    """
    from copy import deepcopy 
    
    ### lists of RB phases and envs. of spikes - CAREFUL WITH MULTIPLE SPIKES IN ONE MS
    ### a bin holding k spikes contributes its phase/envelope k times (once per pass of the while loop)
    phases_list = []
    env_list = []
    phases_env_list = []
    aux_sp = deepcopy(spike_vector) ### working copy, the input must stay untouched
    while np.sum(aux_sp)>0:
        non_zero_idx = np.where(aux_sp>0)[0]
        for idx in non_zero_idx:
            phases_list.append(rb_phase[idx]) 
            env_list.append(rb_envelope[idx]) 
            phases_env_list.append(rb_env_phase[idx]) 
            aux_sp[idx]-=1 ### subtracting spikes that we have already used

    ### phase preference
    r, phi = circular_avg(np.array(phases_list),bins=30)
    r_env, phi_env = circular_avg(np.array(phases_env_list),bins=30)

    ### phases of high and low envelope spikes 
    env_median = np.median(rb_envelope) ### computed once, used for both masks and stored in the output
    spike_env = np.array(env_list)
    high_env_mask = spike_env>=env_median #### mask from the envelope values for each spike (NOT IN SHAPE OF INPUT ARRAYS, only spikes)
    low_env_mask = spike_env<env_median
    
    list_phases_high_env = np.array(phases_list)[high_env_mask]
    list_phases_low_env = np.array(phases_list)[low_env_mask]

    list_env_phases_high_env = np.array(phases_env_list)[high_env_mask]
    list_env_phases_low_env = np.array(phases_env_list)[low_env_mask]

    ### firing rate 
    dur_rec_ms = spike_vector.shape[0]
    dur_rec_s = dur_rec_ms/1000
    fr = np.sum(spike_vector)/dur_rec_s
    fr_high_env = len(list_phases_high_env)/dur_rec_s*2 ### we normalise by 2, because this only considers spikes above median env.
    fr_low_env = len(list_phases_low_env)/dur_rec_s*2

    ### ratio of the two rates, guarded against a cell without low-envelope spikes
    if fr_low_env>0:
        fr_ratio = fr_high_env/fr_low_env
    elif fr_high_env>0:
        fr_ratio = np.inf
    else:
        fr_ratio = np.nan
    
    ### CV ISI
    len_intervals = count_zero_intervals(spike_vector) 
    CV_ISI = np.std(np.array(len_intervals))/np.mean(np.array(len_intervals))

    ### average waveform
    #avg_waveform = np.mean(spike_train.waveforms,axis=0)

    prop_dict = {'FR':fr, 
                'CV_ISI': CV_ISI,
                'ISI': len_intervals,

                'env_th_median':env_median, ### the median value of RB envelope on this channel
                 
                'list_phases': phases_list,
                'list_env':env_list,
                'list_env_phases':phases_env_list,

                'list_phases_high_env':list_phases_high_env,
                'list_env_phases_high_env':list_env_phases_high_env,
                 
                'list_phases_low_env':list_phases_low_env,
                'list_env_phases_low_env':list_env_phases_low_env,

                'FR_high_env_median':fr_high_env,
                'FR_low_env_median':fr_low_env, 
                'FR_high_env_low_env_median_ratio':fr_ratio, 
                 
                'pref_phase_all_spikes':phi, 
                'norm_phase_sel_01_all_spikes':r, 
                'pref_env_phase_all_spikes': phi_env,
                'norm_env_phase_sel_01_all_spikes': r_env,
                 
                #'avg_wf': avg_waveform,
    }
    ### adding other percentile TH values, so the spikes can be split into high/low env. in different ways later
    for perc in [10,20,30,40,50,60,70,80,90,95]:
        prop_dict[f'env_th_perc_{perc}'] = np.percentile(rb_envelope,perc)
    
    if channel_prop is not None:
        for k in channel_prop.keys():
            prop_dict[k] = channel_prop[k]
    return prop_dict

In [8]:
def aux_add_up_down_classes(df_sua):
    """
    Adds a 'wf_direction' column telling whether the dominant deflection of each waveform
    in the 'avg_wf' column points UP or DOWN.

    A row is 'UP' when |max(wf)| is strictly larger than |min(wf)|, otherwise 'DOWN'.
    The dataframe is modified in place and the same object is returned.

    NOTE(review): an earlier description of this function mentioned the z-scored waveform,
    but the column read here is 'avg_wf' - confirm which one is intended.
    """
    directions = [
        'UP' if np.abs(np.max(waveform)) > np.abs(np.min(waveform)) else 'DOWN'
        for waveform in df_sua['avg_wf']
    ]
    df_sua['wf_direction'] = directions
    return df_sua

In [9]:
def aux_add_waveform_prop(df_sua):
    """
    Adds the waveform amplitude ('amp_wf') and width ('width_wf') columns, both computed
    from the 'avg_wf' column.

    amp_wf   - max minus min of the waveform
    width_wf - number of samples between the global minimum (trough) and the largest
               value found at or after that minimum

    The dataframe is modified in place and the same object is returned.
    """
    amplitudes = []
    widths = []
    for waveform in df_sua['avg_wf'].values:
        ### amplitude
        amplitudes.append(np.max(waveform) - np.min(waveform))
        ### distance from the trough to the highest point to the right of it
        trough_idx = np.argmin(waveform)
        peak_after_trough_idx = np.argmax(waveform[trough_idx:]) + trough_idx
        widths.append(np.abs(peak_after_trough_idx - trough_idx))

    df_sua['amp_wf'] = amplitudes
    df_sua['width_wf'] = widths
    return df_sua

In [25]:
def aux_add_zscored_avg_waveform(df_sua):
    """
    Adds two columns derived from 'avg_wf':

    avg_wf_zscored - scipy.stats.zscore applied to the .magnitude of each waveform
    amp_wf_zscored - max minus min of the z-scored waveform

    Each entry of 'avg_wf' must expose a .magnitude attribute. The dataframe is modified
    in place and the same object is returned.
    """
    from scipy.stats import zscore

    df_sua['avg_wf_zscored'] = [zscore(waveform.magnitude) for waveform in df_sua['avg_wf'].values]

    zscored_amplitudes = []
    for zscored_wf in df_sua['avg_wf_zscored'].values:
        zscored_amplitudes.append(np.max(zscored_wf) - np.min(zscored_wf))
    df_sua['amp_wf_zscored'] = zscored_amplitudes

    return df_sua

In [11]:
def aux_add_width_classes(df_sua,width_intervals=None):
    """
    Adding width class info, based on the measured width of a waveform ('width_wf' column).

    Parameters
    ----------
    df_sua : dataframe with a 'width_wf' column; modified in place.
    width_intervals : sequence of up to three [min, max] pairs (both bounds inclusive),
        matched in order to the names 'narrow', 'medium', 'wide'. When None, the
        module-level WIDTH_INTERVALS is looked up at call time.

    Returns
    -------
    The same dataframe with a new 'width_wf_class' column. Exactly one label is written
    per row: the first interval containing the width wins, and a width that falls into
    no interval (or is NaN) gets 'NO_CLASS'. This keeps the column the same length as
    the dataframe; previously gaps or overlaps between intervals broke the assignment.

    Raises
    ------
    ValueError if more intervals than class names are given.
    """
    if width_intervals is None:
        width_intervals = WIDTH_INTERVALS
    names_widths = ['narrow','medium','wide']
    if len(width_intervals) > len(names_widths):
        raise ValueError(
            f'Got {len(width_intervals)} width intervals, but only {len(names_widths)} class names are defined.'
        )
    labelled_intervals = list(zip(names_widths, width_intervals))

    def _classify_width(width):
        # first matching interval wins; both bounds are inclusive
        for name, interval in labelled_intervals:
            if interval[0] <= width <= interval[1]:
                return name
        return 'NO_CLASS'

    ### adding column with spike width classification into narrow, medium, wide
    df_sua['width_wf_class'] = [_classify_width(width) for width in df_sua['width_wf']]
    return df_sua

In [12]:
### classification into 6 classes
def aux_add_final_classes(df_sua,peak_borders=None,final_classes=None):
    """
    Classification into the final cell classes, written to a new 'final_class' column.

    Parameters
    ----------
    df_sua : dataframe with the columns 'wf_direction' ('UP'/'DOWN'), 'width_wf_class'
        ('narrow'/'medium'/'wide') and 'pref_phase_all_spikes'; modified in place.
    peak_borders : [lower, upper] inclusive bounds on 'pref_phase_all_spikes' that define
        the "peak" classes. When None, the module-level PEAK_BORDERS is looked up at call time.
    final_classes : names of the classes to assign. When None, the module-level
        FINAL_CLASSES is looked up at call time. A name without a definition below
        prints 'Undefined cell type.' and is skipped.

    Returns
    -------
    The same dataframe. Rows matching none of the requested classes keep 'NO_CLASS'.

    Class definitions
    -----------------
    DOWN_narrow_peak / DOWN_narrow_other : DOWN, narrow, preferred phase inside / outside the borders
    DOWN_medium, DOWN_wide               : DOWN with the given width class
    UP_peak / UP_other                   : UP, preferred phase inside / outside the borders
    """
    if peak_borders is None:
        peak_borders = PEAK_BORDERS
    if final_classes is None:
        final_classes = FINAL_CLASSES

    ### row masks are built from the dataframe passed in (not from a notebook global)
    is_down = df_sua['wf_direction']=='DOWN'
    is_up = df_sua['wf_direction']=='UP'
    width_class = df_sua['width_wf_class'] ### the column written by aux_add_width_classes
    pref_phase = df_sua['pref_phase_all_spikes']
    in_peak = (pref_phase>=peak_borders[0]) & (pref_phase<=peak_borders[1])

    class_masks = {
        'DOWN_narrow_peak': is_down & (width_class=='narrow') & in_peak,
        'DOWN_narrow_other': is_down & (width_class=='narrow') & ~in_peak,
        'DOWN_medium': is_down & (width_class=='medium'),
        'DOWN_wide': is_down & (width_class=='wide'),
        'UP_peak': is_up & in_peak,
        'UP_other': is_up & ~in_peak,
    }

    df_sua['final_class'] = 'NO_CLASS'
    for cl in final_classes:
        if cl not in class_masks:
            print('Undefined cell type.')
            continue
        ### single .loc[rows, column] assignment, so the value really lands in df_sua
        df_sua.loc[class_masks[cl],'final_class'] = cl
    return df_sua

## Dataframe calculation - the first part, computationally expensive

In [13]:
if False:
    ### Disabled by default (expensive): switch the condition to True to recompute the
    ### per-date dataframes of SUA properties and pickle them into DF_FOLDER/sua_prop/.
    for monkey in MONKEY_LIST:
        print(monkey)
        for date in params['dates'][monkey]['RS']:
            print(date)
            prop_list = []
            for array in range(1,17): 
                print(array)
                try:
                    ### loading SUA spike trains and RB block
                    try:
                        spike_block = load_block(monkey,array,type_rec='RS',type_sig='spikes',date=date,data_folder=DATA_FOLDER) ### SUA
                        RB_block = load_block(monkey,array,type_rec='RS',type_sig='RB',date=date,data_folder=DATA_FOLDER)
                        num_cells = len(spike_block.segments[0].spiketrains)
                        start_t_spikes_ms = int(np.floor(np.float64(spike_block.segments[0].spiketrains[0].t_start.magnitude)*1000))
                        start_t_RB_ms = int(np.floor(np.float64(RB_block.segments[0].analogsignals[0].t_start.magnitude)*1000))
                        ### each label now prints its own value (they were swapped before)
                        print(f'Start t RB: {start_t_RB_ms}')
                        print(f'Start t spikes: {start_t_spikes_ms}')
                        if start_t_spikes_ms!=start_t_RB_ms:
                            print('Spikes and ripples do not have the same start time.')
                    except Exception as err:
                        print(f'Cannot read the spike file for date {date}, monkey {monkey}, array {array}.')
                        print(f'Reason: {err!r}')
                        ### skip this array - otherwise the blocks loaded for the previous array
                        ### would be silently reused and saved under this array's number
                        continue
                    try:
                        df_OP = pd.read_csv(f'{DATA_FOLDER}/metadata/OP_maps_dataframes/{monkey}/OP_prop_OG_array{array}.csv')
                    except Exception as err:
                        ### reset, so that the OP map of the previous array is not used for this one;
                        ### the OP lookup below then fails and falls back to NaN
                        df_OP = None
                        print(f'Cannot read OP maps for date {date}, monkey {monkey}, array {array}.')
                        print(f'Reason: {err!r}')

                    ### per-array signals: they do not depend on the cell, so they are converted once per array
                    rb_phase_arr = sig_block_to_arr(RB_block,'RB_phase')
                    rb_envelope_arr = sig_block_to_arr(RB_block,'RB_envelope_norm')
                    rb_env_phase_arr = sig_block_to_arr(RB_block,'RB_envelope_phase')

                    spike_arr = spike_block_to_arr(spike_block)

                    ### cutting out common times only for N and F
                    if monkey in ['N','F']:
                        rb_phase_arr = cut_abs_times(rb_phase_arr,start_t_RB_ms,monkey,rec_type='RS',date=date,params=params)
                        rb_envelope_arr = cut_abs_times(rb_envelope_arr,start_t_RB_ms,monkey,rec_type='RS',date=date,params=params)
                        rb_env_phase_arr = cut_abs_times(rb_env_phase_arr,start_t_RB_ms,monkey,rec_type='RS',date=date,params=params)
                        spike_arr = cut_abs_times(spike_arr,start_t_RB_ms,monkey,rec_type='RS',date=date,params=params)

                    ### the annotation key holding the area differs between monkeys
                    if monkey in ['N','F']:
                        name_area = 'Area'
                    else:
                        name_area = 'cortical_area'

                    for cell in range(num_cells):
                        spike_train = spike_block.segments[0].spiketrains[cell]
                        cell_name = spike_train.annotations['nix_name']
                        electrode_ID = spike_train.annotations['Electrode_ID']
                        
                        ### channel prop - additional info for a channel, such as OP, bad channel ID, array and area
                        channel_prop = {}
                        channel_prop['cell_name'] = cell_name
                        ### OP - the preferred OP is kept only for channels passing the selectivity and jump thresholds
                        try:
                            ch_OP = df_OP[df_OP['Electrode_ID']==electrode_ID]
                            if ch_OP['selectivity_01'].values[0]>0.2 and ch_OP['num_f0_high_jump'].values[0]<3:
                                channel_prop['pref_OP'] = ch_OP['pref_OP'].values[0]
                                channel_prop['selectivity_OP_01'] = ch_OP['selectivity_01'].values[0]
                            else:
                                channel_prop['pref_OP'] = np.nan
                                channel_prop['selectivity_OP_01'] = ch_OP['selectivity_01'].values[0]
                        except Exception:
                            ### no OP map, or no row for this electrode
                            channel_prop['pref_OP'] = np.nan
                            channel_prop['selectivity_OP_01'] = np.nan
                        ### channel order
                        ch = aux_electrodeID_to_ch_order(monkey,date,electrode_ID,array,data_folder=DATA_FOLDER,type_rec='RS')
                        channel_prop['channel_order'] = ch
                        ### array
                        channel_prop['array'] = array
                        ### area
                        ch_area = spike_train.annotations[name_area]
                        channel_prop['area'] = ch_area
                        ### order in the spike train
                        channel_prop['train_order'] = cell

                        rb_phase = rb_phase_arr[ch,:]
                        rb_envelope = rb_envelope_arr[ch,:]
                        rb_env_phase = rb_env_phase_arr[ch,:]
                        spike_vector = spike_arr[cell,:]
                        
                        prop_dict = spike_train_prop_vec(spike_vector,rb_phase,rb_envelope,rb_env_phase,channel_prop=channel_prop) ### input already binned spikes
                        prop_list.append(prop_dict)
                except Exception as err:
                    ### cells appended before the failure stay in prop_list; the rest of the array is skipped
                    print(f'For array {array}, the SUA properties were not calculated.')
                    print(f'Reason: {err!r}')
            df_prop = pd.DataFrame(prop_list)
            ensure_dir_exists(f'{DF_FOLDER}/sua_prop/')
            df_prop.to_pickle(f'{DF_FOLDER}/sua_prop/monkey{monkey}_all_arrays_date_{date}.pkl')


## Adding other properties and formatting to the DF (computationally easier)

In [34]:
### For every selected monkey/date: load the pickled SUA properties, add the waveform-derived
### columns and the class labels, and save the extended dataframe to DF_FOLDER/sua_prop_all/.
for monkey in ['L']: #MONKEY_LIST: 
    print(monkey)
    all_RS_dates = params['dates'][monkey]['RS']
    for date in [all_RS_dates[0]]: ### only the first RS date is processed here
        print(date)
        with open(f'{DF_FOLDER}/sua_prop/monkey{monkey}_all_arrays_date_{date}.pkl', "rb") as file:
            df_sua = pickle.load(file)
        ### these two helpers add their columns to df_sua in place
        df_added = aux_add_waveform_prop(df_sua)
        df_added = aux_add_zscored_avg_waveform(df_added)
        ### erasing not working arrays; .copy() makes the filtered frame independent of df_sua,
        ### so the columns added below are written to a real dataframe and not to a slice
        df_added = df_added[df_added['channel_order']>-1].copy()
        df_added = aux_add_width_classes(df_added,width_intervals=WIDTH_INTERVALS)
        df_added = aux_add_up_down_classes(df_added)
        df_added = aux_add_final_classes(df_added,peak_borders=PEAK_BORDERS,final_classes=FINAL_CLASSES)

        #### saving new dataframes with properties as pickle
        ensure_dir_exists(f'{DF_FOLDER}/sua_prop_all/')
        df_added.to_pickle(f'{DF_FOLDER}/sua_prop_all/monkey{monkey}_all_arrays_date_{date}.pkl')
        ### the pandas copy warning used to come from adding columns to the filtered slice; the .copy() above avoids it

L
20170725


NameError: name 'FINAL_CLASSES' is not defined

In [18]:
### column names of df_sua
df_sua.keys()

Index(['FR', 'CV_ISI', 'ISI', 'env_th_median', 'list_phases', 'list_env',
       'list_env_phases', 'list_phases_high_env', 'list_env_phases_high_env',
       'list_phases_low_env', 'list_env_phases_low_env', 'FR_high_env_median',
       'FR_low_env_median', 'FR_high_env_low_env_median_ratio',
       'pref_phase_all_spikes', 'norm_phase_sel_01_all_spikes',
       'pref_env_phase_all_spikes', 'norm_env_phase_sel_01_all_spikes',
       'env_th_perc_10', 'env_th_perc_20', 'env_th_perc_30', 'env_th_perc_40',
       'env_th_perc_50', 'env_th_perc_60', 'env_th_perc_70', 'env_th_perc_80',
       'env_th_perc_90', 'env_th_perc_95', 'cell_name', 'pref_OP',
       'selectivity_OP_01', 'channel_order', 'array', 'area', 'train_order'],
      dtype='object')

In [23]:
### z-score of the first entry of the 'avg_wf' column
### NOTE(review): zscore is imported only inside aux_add_zscored_avg_waveform in this notebook -
### presumably it is also provided by the star import from functions_analysis; confirm
zscore(df_sua['avg_wf'][0].magnitude)

array([-0.05014339, -0.01445192,  0.02386775,  0.06520633,  0.10895354,
        0.15288113,  0.19597583,  0.24087043,  0.29160523,  0.34895566,
        0.40976766,  0.47132504,  0.5341738 ,  0.5987663 ,  0.6606145 ,
        0.7127482 ,  0.75613725,  0.8050764 ,  0.87291396,  0.9466244 ,
        0.9874277 ,  0.979754  ,  0.9792771 ,  1.0659401 ,  1.1943612 ,
        1.1032836 ,  0.46750894, -0.7732454 , -2.2416277 , -3.3451507 ,
       -3.7161076 , -3.4417903 , -2.890186  , -2.3654623 , -1.9425576 ,
       -1.5552638 , -1.1426339 , -0.69371617, -0.22113544,  0.2478717 ,
        0.66151655,  0.9617702 ,  1.1253848 ,  1.177942  ,  1.1656601 ,
        1.1214588 ,  1.0578392 ,  0.98036045,  0.89750934,  0.8179574 ,
        0.7446112 ,  0.6747624 ,  0.604675  ,  0.53231966,  0.457132  ,
        0.37971666,  0.30235577,  0.2283751 ,  0.16010645,  0.09762747,
        0.03965934, -0.01448944, -0.06378369, -0.10601135, -0.13959011,
       -0.16521944, -0.18592702, -0.20517863, -0.22424632, -0.24

In [30]:
### waveform widths, i.e. the column added by aux_add_waveform_prop
df_sua['width_wf']

0      13
1       4
2       5
3      20
4       5
       ..
499     8
500     6
501    12
502     8
503     8
Name: width_wf, Length: 504, dtype: int64

In [33]:
### width intervals as loaded from the parameter file
WIDTH_INTERVALS

[[0, 7], [8, 12], [13, 90]]