# NAPE Calcium Imaging Event-Related Analysis

## What does this script do:

Loads activity traces from ROIs/sources across the whole session as well as timing of behavioral/manipulation events, then extracts a window of activity around each event, and finally generates numerous event-related plots resolving and averaging across trials and ROIs.

How to run this code
------------------------------------

In this jupyter notebook, First find the code block with the comment header called USER-DEFINED VARIABLES. Edit the variables according to your data and output preferences. Then just run all cells in order (shift + enter; or in the dropdown menu: Kernel->Resart & Run All). Please make sure you set the correct sampling rate (fs) in the user-defined variables block of code.

### Required Prerequisite/Input Files

All data should reside in a parent folder. This folder's name should be the name of the session and ideally be the same as the base name of the recording file.

1. ROI/source activity signals file ( `fparams['fname_signal']` ): This is either a npy or CSV file where rows are individual ROIs/sources and columns are samples, and the values are activity levels (ie. fluorescence or voltage for ephys). Note the CSV should not 
2. Event occurrence file (`fparams['fname_events']` ): This is either a pickle file or a CSV file. For the pickle file, it should contain a python dictionary where keys are event condition names and associated values are lists that contain event occurrence times (in samples). If using a CSV, the data should be in tidy format formated like what is shown here (note currently the first row, first and second columns should be "event" and "sample" respectively: https://github.com/zhounapeuw/NAPE_imaging_postprocess/raw/main/docs/_images/napeca_post_event_csv_format.png


Required Packages
-----------------
Python 3.7, seaborn, matplotlib, pandas, scikit-learn

Custom code requirements: utils

User-Defined Parameters 
----------

fname_signal : string
    
    Name of file that contains roi activity traces. Must include full file name with extension. Accepted file types: .npy, .csv. IMPORTANT: data dimensions should be rois (y) by samples/time (x)

fname_events : string

    Name of file that contains event occurrences. Must include full file name with extension. Accepted file types: .pkl, .csv. Pickle (pkl) files need to contain a dictionary where keys are the condition names and the values are lists containing samples/frames for each corresponding event. Csv's should have two columns (event condition, sample). The first row are the column names. Subsequent rows contain each trial's event condition and sample in tidy format. See example in sample_data folder for formatting, or this link: https://github.com/zhounapeuw/NAPE_imaging_postprocess/raw/main/docs/_images/napeca_post_event_csv_format.png

fdir : string 

    Root file directory containing the raw tif, tiff, h5 files. IMPORTANT Note: leave off the last backslash, and include the letter r in front of string (treats the contents as a raw string). For example: r'C:\Users\my_user\analyze_sessions'

fname : string

    Session name; by default this is the name of the parent folder that the data resides in, but can be changed by user to be any string. This fname variable is mainly used to name the saved output files.

flag_save_figs : boolean  

    Set as True to save figures as JPG and vectorized formats.  
    
fs : float

    Sampling rate of imaging data. It is imperative that this value is correct; otherwise the incorrect time windows will be pulled out for each event. If you suspect that 
    the sampling rate (fs) was not set correctly in the NAPECA preprocessing pipeline, go into the saved json file and edit the fs value.
    
selected_conditions : list of strings

    Specific conditions that the user wants to analyze; needs to be exactly the name of conditions in the events CSV or pickle file

trial_start_end : list of two entries  

    Entries can be ints or floats. The first entry is the time in seconds relative to the event/ttl onset for the start of the event analysis window (negative if before the event/ttl onset. The second entry is the time in seconds for the end of the event analysis window. For example if the desired analysis window is 5.5 seconds before event onset and 8 seconds after, `trial_start_end` would be [-5.5, 8].  
    
baseline_end : int/float  

    Time in seconds for the end of the baseline epoch. By default, the baseline epoch start time will be the first entry ot `trial_start_end`. This baseline epoch is used for calculating baseline normalization metrics.
    
event_dur : int/float  

    Time in seconds representing how long the behavioral event or stimulation, etc. lasts. A green line with the corresponding length will be plotted indicating stimulus duration

event_sort_analysis_win : list with two float entries

    Time window [a, b] in seconds during which some visualization calculations will apply to. For example, if the user sets flag_sort_rois to be True, ROIs in heatmaps will be sorted based on the mean activity in the time window between a and b. Similar principle holds for the time-averaged barplots.

opto_blank_frame : boolean

    if PMTs were blanked during stim, use detected stim times (from preprocessing) to set those frames to NaN

flag_npil_corr : boolean  

    Set as True if user would like to load in neuropil corrected data from the preprocessing pipeline. Must have a \*\_neuropil\_corrected_signal_* file in the directory. If set as False, just use the extracted_signal file.
    
flag_zscore : boolean  

    Set as True if analyzed data should be baseline z-scored on a trial level.  

flag_sort_rois : boolean

    Set as True to sort ROIs on the y axis of heatmaps. This works with `user_sort_method` and `roi_sort_cond` for specifying details of sorting.

user_sort_method : string
    
    Set to 'peak_time' to sort ROIs by peak time during event_sort_dur; 'max_value' to sort by max value
    
roi_sort_cond : string
    
    Condition to perform sorting on

flag_roi_trial_avg_errbar : boolean  

    Set as True to set standard error of mean shaded portions for line plots of trial-averaged activity.   

flag_trial_avg_errbar : boolean

    toggle standard error bars on trial-averaged data

interesting_rois : list of ints  

    All entries are indices for ROIs that will be marked in heatmaps by arrows.  

Optional Parameters (Only relevant if using batch_process)
-------------------

user_sort_method : string
    
    Takes the strings 'peak_time' or 'max_value'
    
    
roi_sort_cond : string
    for roi-resolved heatmaps, which condition to sort ROIs by
    
    Defaults to first condition available
    
Output
-------

event_rel_analysis : folder containing plots in jpg/png and vectorized formats

event_data_dict.pkl : pickle file containing a dictionary containing event-triggered data from each cell and organized by event condition. Raw and z-scored data are included as well.


In [None]:
import os
import numpy as np
import glob
import pickle
import json
import seaborn as sns
import matplotlib.ticker as ticker
import pandas as pd
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import matplotlib
#important for text to be detected when importing saved figures into illustrator
matplotlib.rcParams['pdf.fonttype']=42
matplotlib.rcParams['ps.fonttype']=42
plt.rcParams["font.family"] = "Arial"

import utils

In [None]:
# simple class to update limits as you go through iterations of data
# first call update_lims(first_lims)
# then update_lims.update(new_lims)
# update_lims.output() outputs lims
class update_lims:
    
    def __init__(self, lims):
        self.lims = lims
        
    
    def update(self, new_lims):
        if self.lims[0] > new_lims[0]:
            self.lims[0] = new_lims[0]
        
        if self.lims[1] < new_lims[1]:
            self.lims[1] = new_lims[1]

    def output(self):
        return self.lims
    
    
# find 2D subplot index based on a numerical incremented index (ie. idx=3 would be (2,1) for a 2x2 subplot figure)     
def subplot_loc(idx, num_rows, num_col):
    if num_rows == 1:
        subplot_index = idx
    else:
        subplot_index = np.unravel_index(idx, (num_rows, int(num_col))) # turn int index to a tuple of array coordinates
    return subplot_index

# calculate all the color limits for heatmaps; useful for locking color limits across different heatmap subplots   
def generate_clims(data_in, norm_type):
    # get min and max for all data across conditions 
    clims_out = [np.nanmin(data_in), np.nanmax(data_in)]
    if 'zscore' in norm_type: # if data are zscored, make limits symmetrical and centered at 0
        clims_max = np.max(np.abs(clims_out)) # then we take the higher of the two magnitudes
        clims_out = [-clims_max*0.5, clims_max*0.5] # and set it as the negative and positive limit for plotting
    return clims_out

def get_cmap(n, name='plasma'):
    '''Returns a function that maps each index in 0, 1, ..., n-1 to a distinct 
    RGB color; the keyword argument name must be a standard mpl colormap name.'''
    return plt.cm.get_cmap(name, n)


def is_all_nans(vector):
    """
    checks if series or vector contains all nans; returns boolean. Used to identify and exclude all-nan rois
    """
    if isinstance(vector, pd.Series):
        vector = vector.values
    return np.isnan(vector).all()

# function for finding the index of the closest entry in an array to a provided value
def find_nearest_idx(array, value):

    if isinstance(array, pd.Series):
        idx = (np.abs(array - value)).idxmin()
        return idx, array.index.get_loc(idx), array[idx] # series index, 0-relative index, entry value
    else:
        array = np.asarray(array)
        idx = (np.abs(array - value)).argmin()
        return idx, array[idx]

    
# function to find closest sample when a time occurs in a time vector
tvec2samp = lambda tvec, time: np.argmin(np.abs(tvec - time))

In [None]:
"""
USER-DEFINED VARIABLES
"""

def define_params(method = 'single'):
    
    fparams = {}
    
    if method == 'single':
        
        fparams['fname_signal'] = 'VJ_OFCVTA_7_260_D6_neuropil_corrected_signals_15_50_beta_0.8.npy'   # 
        fparams['fname_events'] = 'event_times_VJ_OFCVTA_7_260_D6_trained.csv'
        # fdir signifies to the root path of the data. Currently, the abspath phrase points to sample data from the repo.
        # To specify a path that is on your local computer, use this string format: r'your_root_path', where you should copy/paste
        # your path between the single quotes (important to keep the r to render as a complete raw string). See example below:
        # r'C:\Users\stuberadmin\Documents\GitHub\NAPE_imaging_postprocess\napeca_post\sample_data' 
        fparams['fdir'] = os.path.abspath('./sample_data/VJ_OFCVTA_7_260_D6') 
        fparams['fname'] = os.path.split(fparams['fdir'])[1]
        fparams['flag_close_figs_after_save'] = True
        fparams['flag_save_figs'] = True
        
        # set the sampling rate
        fparams['fs'] = 5 # this gets overwritten by json fs variable (if it exists) that is saved in preprocessing
        json_fpath = os.path.join(fparams['fdir'], fparams['fname']+".json")
        if os.path.exists(json_fpath):
            json_data = utils.open_json(json_fpath)
            if 'fs' in json_data:
                fparams['fs'] = json_data['fs']

        # set to None if want to include all conditions from behav data; 
        # otherwise, set to list of conditions, eg. ['plus', 'minus']
        fparams['selected_conditions'] = None 
        
        # trial windowing and normalization
        fparams['trial_start_end'] = [-2, 8] # [start, end] times (in seconds) included in the visualization 
        fparams['flag_normalization'] = 'zscore' # options: 'zscore', None
        fparams['baseline_end'] = -0.2 # baseline epoch end time (in seconds) for performing baseline normalization
        fparams['event_dur'] = 2 # duration of stim/event in seconds; displays a line below main plot indicating event duration
        fparams['event_sort_analysis_win'] = [0, 5] # time window (in seconds)
        
        # session info
        fparams['opto_blank_frame'] = False # if PMTs were blanked during stim, set stim times to nan (instead of 0)

        # ROI sorting; if flag_sort_rois is set to True, ROIs are sorted by activity in the fparams['event_sort_analysis_win'] window
        fparams['flag_sort_rois'] = True
        if fparams['flag_sort_rois']:
            
            fparams['user_sort_method'] = 'max_value' # peak_time or max_value
            fparams['roi_sort_cond'] = 'plus' # for roi-resolved heatmaps, which condition to sort ROIs by
            
        # errorbar and saving figures
        fparams['flag_roi_trial_avg_errbar'] = True # toggle to show error bar on roi- and trial-averaged traces
        fparams['flag_trial_avg_errbar'] = True # toggle to show error bars on the trial-avg traces
        fparams['interesting_rois'] = [] #[ 0, 1, 2, 23, 22, 11, 9, 5, 6, 7, 3, 4, 8, 12, 14, 15, 16, 17] # [35, 30, 20, 4] #
    
    elif method == 'f2a': # if string is empty, load predefined list of files in files_to_analyze_event

        fparams = files_to_analyze_event.define_fparams()

    elif method == 'root_dir':
        
        pass

    with open(os.path.join(fparams['fdir'], 'event_analysis_fparam.json'), 'w') as fp:
        json.dump(fparams, fp)
    
    return fparams

fparams = define_params(method = 'single') # options are 'single', 'f2a', 'root_dir'

if 'zscore' in fparams['flag_normalization']:
    data_trial_resolved_key = 'zdata'
    data_trial_avg_key = 'ztrial_avg_data'
    cmap_ = None
    ylabel = 'Z-score Activity'
else:
    data_trial_resolved_key = 'data'
    data_trial_avg_key = 'trial_avg_data'
    cmap_ = 'inferno'
    ylabel = 'Activity'

In [None]:
# declare paths
def define_paths(dict_vars, fparams):
    dict_vars['signals_fpath'] = os.path.join(fparams['fdir'], fparams['fname_signal'])
    dict_vars['events_file_path'] = os.path.join(fparams['fdir'], fparams['fname_events'])

    dict_vars['save_dir']= os.path.join(fparams['fdir'], 'event_rel_analysis')
    utils.check_exist_dir(dict_vars['save_dir']); # make the save directory

    return dict_vars

def make_timing_info(dict_vars, fparams):
    ### create variables that reference samples and times for slicing and plotting the data
    trial_start_end_sec = np.array(fparams['trial_start_end']) # trial windowing in seconds relative to ttl-onset/trial-onset, in seconds
    baseline_start_end_sec = np.array([trial_start_end_sec[0], fparams['baseline_end']])

    # convert times to samples and get sample vector for the trial 
    dict_analysis_vars['trial_begEnd_samp'] = trial_start_end_sec*fparams['fs'] # turn trial start/end times to samples
    trial_svec = np.arange(dict_vars['trial_begEnd_samp'][0], dict_vars['trial_begEnd_samp'][1])
    # and for baseline period
    dict_vars['baseline_begEnd_samp'] = baseline_start_end_sec*fparams['fs']
    dict_vars['baseline_svec'] = (np.arange(dict_vars['baseline_begEnd_samp'][0], dict_vars['baseline_begEnd_samp'][1]+1, 1) - dict_vars['baseline_begEnd_samp'][0]).astype('int')

    # calculate time vector for plot x axes
    num_samples_trial = len( trial_svec )
    dict_vars['tvec'] = np.round(np.linspace(trial_start_end_sec[0], trial_start_end_sec[1], num_samples_trial+1), 2)

    # find samples and calculations for time 0 for plotting
    dict_vars['t0_sample'] = utils.get_tvec_sample(dict_vars['tvec'], 0) # grabs the sample index of a given time from a vector of times
    dict_vars['event_end_sample'] = int(np.round(dict_vars['t0_sample']+fparams['event_dur']*fparams['fs']))
    # event_bound_ratio : fraction of total samples for event start and end. Used for mapping which samples to plot green line indicating event duration
    dict_vars['event_bound_ratio'] = [(dict_vars['t0_sample'])/num_samples_trial , dict_vars['event_end_sample']/num_samples_trial] # fraction of total samples for event start and end; only used for plotting line indicating event duration

    return dict_vars

def load_signals(dict_vars, fparams):
    # requires define_paths method to be run first 
    signals = utils.load_signals(dict_vars['signals_fpath'])

    dict_vars['num_rois'] = signals.shape[0]
    dict_vars['all_nan_rois'] = np.where(np.apply_along_axis(is_all_nans, 1, signals)) # find rois with activity as all nans

    # if opto stim frames were detected in preprocessing, set these frames to be NaN (b/c of stim artifact)
    if fparams['opto_blank_frame']:
        try:
            glob_stim_files = glob.glob(os.path.join(fparams['fdir'], "{}*_stimmed_frames.pkl".format(fparams['fname'])))
            stim_frames = pickle.load( open( glob_stim_files[0], "rb" ) )
            signals[:,stim_frames['samples']] = None # blank out stimmed frames
            flag_stim = True
            print('Detected stim data; replaced stim samples with NaNs')
        except:
            flag_stim = False
            print('Note: No stim preprocessed meta data detected.')

    dict_vars['signals'] = signals

    return dict_vars

def load_behav(dict_vars, fparams):
    ### load behavioral data and trial info
    # requires define_paths method to be run first 

    glob_event_files = glob.glob(dict_vars['events_file_path']) # look for a file in specified directory
    if not glob_event_files:
        print('{} not detected. Please check if path is correct.'.format(dict_vars['events_file_path']))
    if 'csv' in glob_event_files[0]:
        event_times = utils.df_to_dict(glob_event_files[0])
    elif 'pkl' in glob_event_files[0]:
        event_times = pickle.load( open( glob_event_files[0], "rb" ), fix_imports=True, encoding='latin1' ) # latin1 b/c original pickle made in python 2
    dict_vars['event_frames'] = utils.dict_samples_to_time(event_times, fparams['fs'])

    # identify conditions to analyze
    all_conditions = dict_vars['event_frames'].keys()
    conditions = [ condition for condition in all_conditions if len(dict_vars['event_frames'][condition]) > 0 ] # keep conditions that have events

    conditions.sort()
    if fparams['selected_conditions']:
        conditions = fparams['selected_conditions']
    dict_vars['conditions'] = conditions
    dict_vars['cmap_lines'] = get_cmap(len(conditions)) # colors for plotting lines for each condition

    return dict_vars

def set_fontsizes(dict_vars):
    # declare some fixed constant variables
    dict_vars['axis_label_size'] = 15
    dict_vars['tick_font_size'] = 14 
    return dict_vars

In [None]:
def subplot_trial_heatmap(fig, ax, data_in, tvec, event_bound_ratio, clims, title, 
                          subplot_index, cmap_='inferno', save_fig = False, axis_label_size=15, tick_font_size=14):
    
    """
    
    data_in : np.array with dimensions trials x samples
        
    
    Prep information about specific condition (each loop) and plot heatmap
        1) x/y label tick values
        2) x/y labels
        3) grab specific condition's data
        4) plot data (using utils function)
        5) plot meta data lines (eg. 0-time line, event duration line)

    event_bound_ratio : list of two entries
        where entries are fraction of total samples for event start and end. Used for mapping which samples to plot green line 
        indicating event duration
    """

    # set imshow extent to replace x and y axis ticks/labels
    plot_extent = [tvec[0], tvec[-1], data_in.shape[0], 0] # [x min, x max, y min, ymax]

    # prep labels; plot x and y labels for first subplot
    if subplot_index == (0, 0) or subplot_index == 0 :
        ax[subplot_index].set_ylabel('Trial', fontsize=axis_label_size)
        ax[subplot_index].set_xlabel('Time [s]', fontsize=axis_label_size);
    ax[subplot_index].tick_params(axis = 'both', which = 'major', labelsize = tick_font_size)

    # prep the data
    to_plot = np.squeeze(data_in) 
    if data_in.shape[0] == 1: # accomodates single trial data
        to_plot = to_plot[np.newaxis, :]

    # plot the data
    im = utils.subplot_heatmap(ax[subplot_index], title, to_plot, cmap=cmap_, clims=clims, extent_=plot_extent)

    # add meta data lines
    ax[subplot_index].axvline(0, color='0.5', alpha=1) # plot vertical line for time zero
    # plots green horizontal line indicating event duration
    ax[subplot_index].annotate('', xy=(event_bound_ratio[0], -0.01), xycoords='axes fraction', 
                               xytext=(event_bound_ratio[1], -0.01), 
                               arrowprops=dict(arrowstyle="-", color='g'))

    cbar = fig.colorbar(im, ax = ax[subplot_index], shrink = 0.5)
    cbar.ax.set_ylabel(ylabel, fontsize = axis_label_size)

    
def plot_trial_heatmap_traces(dict_vars, fparams, data_dict):
    
    """
    Generates a figure per ROI/channel
    Each figure consists of n heatmap panels of trial activity centered on an event where n is the number of event types/conditions.
    The last panel of the figure contains trial-averaged traces for all conditions
    """
    
    num_subplots = len(dict_vars['conditions']) + 1 # plus one for trial-avg traces
    n_columns = np.min([num_subplots, 4.0])
    n_rows = int(np.ceil(num_subplots/n_columns))

    for iROI in range(dict_vars['num_rois']):

        # calculate color limits. This is outside of heatmap function b/c want lims across conditions
        # loop through each condition's data and flatten before concatenating values
        roi_clims = generate_clims(np.concatenate([data_dict[cond][data_trial_resolved_key][:, iROI, :].flatten() for cond in dict_vars['conditions']]), 
                                   fparams['flag_normalization'])

        fig, ax = plt.subplots(nrows=n_rows, ncols=int(n_columns), 
                               figsize=(n_columns*4, n_rows*3),
                               constrained_layout=True)

        ### Plot heatmaps for each condition
        for idx_cond, cond in enumerate(dict_vars['conditions']):

            subplot_index = subplot_loc(idx_cond, n_rows, n_columns) # determine subplot location index
            data_to_plot = data_dict[cond][data_trial_resolved_key][:, iROI, :]
            title = 'ROI {}; {}'.format(str(iROI), cond)

            subplot_trial_heatmap(fig, ax, data_to_plot, dict_vars['tvec'], dict_vars['event_bound_ratio'], 
                                  roi_clims, title, subplot_index, cmap_, save_fig=False, 
                                  axis_label_size=dict_vars['axis_label_size'], tick_font_size=dict_vars['tick_font_size'])

        ### plot last subplot of trial-avg traces

        # determine subplot location index
        subplot_index = subplot_loc(num_subplots-1, n_rows, n_columns)

        for cond in dict_analysis_vars['conditions']:

            # prep data to plot
            num_trials = data_dict[cond]['num_trials']
            to_plot = np.nanmean(data_dict[cond][data_trial_resolved_key][:,iROI,:], axis=0)
            to_plot_err = np.nanstd(data_dict[cond][data_trial_resolved_key][:,iROI,:], axis=0)/np.sqrt(num_trials)

            # plot trace
            ax[subplot_index].plot(dict_vars['tvec'], to_plot)
            if fparams['opto_blank_frame']: 
                ax[subplot_index].plot(tvec[dict_vars['t0_sample']:dict_vars['event_end_sample']], 
                                       to_plot[dict_vars['t0_sample']:dict_vars['event_end_sample']], marker='.', color='g')
            # plot shaded error
            if fparams['flag_trial_avg_errbar']:
                ax[subplot_index].fill_between(dict_vars['tvec'], to_plot - to_plot_err, to_plot + to_plot_err,
                             alpha=0.5) # this plots the shaded error bar

        # plot x, y labels, and legend
        ax[subplot_index].set_ylabel(ylabel, fontsize=dict_vars['axis_label_size'])
        ax[subplot_index].set_xlabel('Time [s]', fontsize=dict_vars['axis_label_size'])
        ax[subplot_index].set_title('ROI # {}; Trial-avg'.format(str(iROI)), fontsize=dict_vars['axis_label_size'])
        ax[subplot_index].legend(dict_vars['conditions'])
        ax[subplot_index].autoscale(enable=True, axis='both', tight=True)
        ax[subplot_index].axvline(0, color='0.5', alpha=0.65) # plot vertical line for time zero
        ax[subplot_index].annotate('', xy=(dict_vars['event_bound_ratio'][0], -0.01), xycoords='axes fraction', 
                                       xytext=(dict_vars['event_bound_ratio'][1], -0.01), 
                                       arrowprops=dict(arrowstyle="-", color='g'))
        ax[subplot_index].tick_params(axis = 'both', which = 'major', labelsize = dict_vars['tick_font_size'])

        for a in ax.flat[num_subplots:]:
            a.axis('off')

        if fparams['flag_save_figs']:
            fig.savefig( os.path.join(dict_vars['save_dir'],'roi_{}_activity.png'.format(str(iROI))) ); 
            fig.savefig( os.path.join(dict_vars['save_dir'],'roi_{}_activity.pdf'.format(str(iROI))) );

        if fparams['flag_close_figs_after_save']:
            plt.close(fig)
            
            
def sort_heatmap_peaks(data, tvec, sort_epoch_start_time, sort_epoch_end_time, sort_method = 'peak_time'):
    # function to sort ROIs based on activity in certain epoch
    
    # find start/end samples for epoch
    sort_epoch_start_samp = tvec2samp(tvec, sort_epoch_start_time)
    sort_epoch_end_samp = tvec2samp(tvec, sort_epoch_end_time)
    
    if sort_method == 'peak_time':
        epoch_peak_samp = np.argmax(data[:,sort_epoch_start_samp:sort_epoch_end_samp], axis=1)
        final_sorting = np.argsort(epoch_peak_samp)
    elif sort_method == 'max_value':
 
        time_max = np.nanmax(data[:,sort_epoch_start_samp:sort_epoch_end_samp], axis=1)
        final_sorting = np.argsort(time_max)[::-1]

    return final_sorting



def sort_rois(dict_vars, fparams, data_trial_avg_key):

    if fparams['flag_sort_rois']: # if flag is true, sort rois by average activity in specific window
        if not fparams['roi_sort_cond']: # if no condition to sort by specified, use first condition
            fparams['roi_sort_cond'] = data_dict.keys()[0]
        if not fparams['roi_sort_cond'] in data_dict.keys():
            sorted_roi_order = range(dict_vars['num_rois'])
            interesting_rois = fparams['interesting_rois']
            print('Specified condition to sort by doesn\'t exist! ROIs are in default sorting.')
        else: # if sorting condition is valid, sort based on activity
            # returns new order of rois sorted using the data and method supplied in the specified window
            sorted_roi_order = sort_heatmap_peaks(data_dict[fparams['roi_sort_cond']][data_trial_avg_key], dict_vars['tvec'], 
                               sort_epoch_start_time=0, 
                               sort_epoch_end_time = fparams['trial_start_end'][-1], 
                               sort_method = fparams['user_sort_method'])
            # finds corresponding interesting roi (roi's to mark with an arrow) order after sorting
            interesting_rois = np.in1d(sorted_roi_order, fparams['interesting_rois']).nonzero()[0] 
    else:
        sorted_roi_order = range(dict_vars['num_rois'])
        interesting_rois = fparams['interesting_rois']

    # get rid of ROIs that have all NaN's in data
    if not dict_vars['all_nan_rois'][0].size == 0:
        set_diff_keep_order = lambda main_list, remove_list : [i for i in main_list if i not in remove_list]
        sorted_roi_order = set_diff_keep_order(dict_vars['sorted_roi_order'], dict_vars['all_nan_rois'])
        interesting_rois = [i for i in fparams['interesting_rois'] if i not in dict_vars['all_nan_rois']]
        print(f'Removed {str(len(remove_list))} ROIs with NaN data')
    
    # save sorted roi order
    roi_order_path = os.path.join(fparams['fdir'], fparams['fname'] + '_roi_order.pkl')
    with open(roi_order_path, 'wb') as handle:
         pickle.dump(sorted_roi_order, handle, protocol=pickle.HIGHEST_PROTOCOL)
    
    dict_vars['sorted_roi_order'] = sorted_roi_order
    dict_vars['interesting_rois'] = interesting_rois            
    
    return dict_vars


def plot_trial_avg_heatmap(dict_vars, data_in, cmap, clims, sorted_roi_order=None, 
                           rois_oi=None, save_fig=False, axis_label_size=15, tick_font_size=14):
    
    """
    Technically doesn't need to remove all_nan_rois b/c of nanmean calculations
    """
    
    num_subplots = len(dict_vars['conditions'])
    n_columns = np.min([num_subplots, 3.0])
    n_rows = int(np.ceil(num_subplots/n_columns))

    # set imshow extent to replace x and y axis ticks/labels (replace samples with time)
    plot_extent = [dict_vars['tvec'][0], dict_vars['tvec'][-1], dict_vars['num_rois'], 0 ]

    fig, ax = plt.subplots(nrows=n_rows, ncols=int(n_columns), figsize = (n_columns*5, n_rows*4))
    if not isinstance(ax,np.ndarray): # this is here to make the code below compatible with indexing a single subplot object
        ax = [ax]

    for idx, cond in enumerate(dict_vars['conditions']):

        # determine subplot location index
        if n_rows == 1:
            subplot_index = idx
        else:
            subplot_index = np.unravel_index(idx, (n_rows, int(n_columns))) # turn int index to a tuple of array coordinates

        # prep labels; plot x and y labels for first subplot
        if subplot_index == (0, 0) or subplot_index == 0 :
            ax[subplot_index].set_ylabel('ROI #', fontsize=axis_label_size)
            ax[subplot_index].set_xlabel('Time [s]', fontsize=axis_label_size);
        ax[subplot_index].tick_params(axis = 'both', which = 'major', labelsize = tick_font_size)
        
        # plot the data
        if sorted_roi_order is not None:
            roi_order = sorted_roi_order
        else:
            roi_order = slice(0, dict_vars['num_rois'])
        to_plot = data_in[cond][data_trial_avg_key][roi_order,:] # 

        im = utils.subplot_heatmap(ax[subplot_index], cond, to_plot, cmap=cmap_, clims=clims, extent_=plot_extent)
        ax[subplot_index].axvline(0, color='k', alpha=0.3) # plot vertical line for time zero
        ax[subplot_index].annotate('', xy=(dict_vars['event_bound_ratio'][0], -0.01), xycoords='axes fraction', 
                                       xytext=(dict_vars['event_bound_ratio'][1], -0.01), 
                                       arrowprops=dict(arrowstyle="-", color='g'))
        if rois_oi is not None:
            for ROI_OI in rois_oi:
                ax[subplot_index].annotate('', xy=(1.005, 1-(ROI_OI/dict_vars['num_rois'])-0.015), xycoords='axes fraction', 
                                           xytext=(1.06, 1-(ROI_OI/dict_vars['num_rois'])-0.015), 
                                           arrowprops=dict(arrowstyle="->", color='k'))
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    cbar = fig.colorbar(im, ax = ax, shrink = 0.7)
    cbar.ax.set_ylabel(ylabel, fontsize=13)
    
    # hide empty subplot
    for a in ax.flat[num_subplots:]:
        a.axis('off')
    
    if save_fig:
        fig.savefig(os.path.join(dict_vars['save_dir'],'trial_avg_heatmap.png')); 
        fig.savefig(os.path.join(dict_vars['save_dir'],'trial_avg_heatmap.pdf'));
        
        
def plot_trial_roi_avg_traces(dict_vars, fparams, data_trial_resolved_key, axis_label_size=15, tick_font_size=14):

    line_shades = []
    fig, axs = plt.subplots(1,1, figsize = (10,6))
    for idx, cond in enumerate(dict_vars['conditions']):
        line_color = dict_vars['cmap_lines'](idx)
        # first trial avg the data
        trial_avg = np.nanmean(data_dict[cond][data_trial_resolved_key], axis=0)

        # z-score trial-avg data for each respective ROI
        # apply zscore function to each row of data
        app_axis = 1 
        zscore_trial_avg = np.apply_along_axis(utils.zscore_, app_axis, trial_avg, dict_vars['baseline_svec'])

        # take avg/std across ROIs
        zscore_roi_trial_avg = np.nanmean(zscore_trial_avg, axis=0)
        zscore_roi_trial_std = np.nanstd(zscore_trial_avg, axis=0)

        to_plot = np.squeeze(zscore_roi_trial_avg)
        to_plot_err = np.squeeze(zscore_roi_trial_std)/np.sqrt(dict_vars['num_rois'])

        axs.plot(dict_vars['tvec'], to_plot, color=line_color)
        if fparams['opto_blank_frame']:
            line = axs.plot(dict_vars['tvec'][dict_vars['t0_sample']:dict_vars['event_end_sample']], 
                                 to_plot[dict_vars['t0_sample']:dict_vars['event_end_sample']], marker='.', color=line_color)
        else:
            line = axs.plot(dict_vars['tvec'][dict_vars['t0_sample']:dict_vars['event_end_sample']], 
                                 to_plot[dict_vars['t0_sample']:dict_vars['event_end_sample']], color=line_color)

        if fparams['flag_roi_trial_avg_errbar']:
            shade = axs.fill_between(dict_vars['tvec'], to_plot - to_plot_err, to_plot + to_plot_err, color = line_color,
                         alpha=0.2) # this plots the shaded error bar
            line_shades.append((line[0],shade))

    axs.set_ylabel(ylabel, fontsize=axis_label_size)
    axs.set_xlabel('Time [s]', fontsize=axis_label_size);
    axs.legend(dict_vars['conditions']);
    axs.legend(line_shades, dict_vars['conditions'], fontsize=15)
    axs.axvline(0, color='0.5', alpha=0.65) # plot vertical line for time zero
    axs.annotate('', xy=(dict_vars['event_bound_ratio'][0], -0.01), xycoords='axes fraction', 
                                   xytext=(dict_vars['event_bound_ratio'][1], -0.01), 
                                   arrowprops=dict(arrowstyle="-", color='g'))
    axs.tick_params(axis = 'both', which = 'major', labelsize = tick_font_size+3)
    axs.autoscale(enable=True, axis='both', tight=True)

    #axs.set_ylim([-1.5, 10])

    if fparams['flag_save_figs']:
            fig.savefig(os.path.join(dict_vars['save_dir'],'roi_trial_avg_trace.png')); fig.savefig(os.path.join(dict_vars['save_dir'],'roi_trial_avg_trace.pdf'));

            
def stats_trial_roi_avg_traces(dict_vars, fparams, data_trial_resolved_key, axis_label_size=15, tick_font_size=14):

    analysis_window = fparams['event_sort_analysis_win']
    analysis_win_samps = [ find_nearest_idx(dict_vars['tvec'], time)[0] for time in analysis_window ]

    to_plot = []
    to_plot_err = []

    fig, axs = plt.subplots(1,1, figsize = (5,5))
    for idx, cond in enumerate(dict_vars['conditions']):
        line_color = dict_vars['cmap_lines'](idx)
        # first trial avg the data
        trial_avg = np.nanmean(data_dict[cond][data_trial_resolved_key], axis=0)

        # z-score trial-avg data for each respective ROI
        # apply zscore function to each row of data
        apply_axis = 1 
        zscore_trial_avg = np.apply_along_axis(utils.zscore_, apply_axis, trial_avg, dict_vars['baseline_svec'])

        # take avg across time
        zscore_trial_time_avg = np.nanmean(zscore_trial_avg[:,analysis_win_samps[0]:analysis_win_samps[1],:], axis=1)

        # take avg/std across ROIs
        zscore_roi_trial_time_avg = np.nanmean(zscore_trial_time_avg, axis=0)
        zscore_roi_trial_time_std = np.nanstd(zscore_trial_time_avg, axis=0)

        to_plot.append(zscore_roi_trial_time_avg[0])
        to_plot_err.append(zscore_roi_trial_time_std[0]/np.sqrt(len(zscore_trial_time_avg)))

    barlist = axs.bar(dict_vars['conditions'], to_plot, yerr=to_plot_err, align='center', alpha=0.5, ecolor='black', capsize=10 )
    for idx in range(len(dict_vars['conditions'])):
        barlist[idx].set_color(dict_vars['cmap_lines'](idx))
    axs.set_ylabel('Normalized Fluorescence', fontsize=axis_label_size)
    axs.set_title('ROI-, Trial-, Time-averaged Quant', fontsize=axis_label_size)
    axs.yaxis.grid(True)
    axs.tick_params(axis = 'both', which = 'major', labelsize = tick_font_size)
    axs.tick_params(axis = 'x', which = 'major', rotation = 45)
    # Save the figure and show
    plt.tight_layout()

    if fparams['flag_save_figs']:
        fig.savefig(os.path.join(dict_vars['save_dir'],'roi_trial_time_avg_bar.png')); 

        fig.savefig(os.path.join(dict_vars['save_dir'],'roi_trial_time_avg_bar.pdf'));

In [None]:
if __name__ == "__main__": 
    """
    Declare and compute meta information, raw, data, and plotting info
    """
    print('Prepping meta info')
    dict_analysis_vars = {}
    ### create variables that reference samples and times for slicing and plotting the data
    dict_analysis_vars = define_paths(dict_analysis_vars, fparams)
    dict_analysis_vars = make_timing_info(dict_analysis_vars, fparams)

    ### load behavioral data and trial info
    dict_analysis_vars = load_signals(dict_analysis_vars, fparams)
    dict_analysis_vars = load_behav(dict_analysis_vars, fparams)

    dict_analysis_vars = set_fontsizes(dict_analysis_vars)

    """
    MAIN data processing function to extract event-centered data

    extract and save trial data, 
    saved data are in the event_rel_analysis subfolder, a pickle file that contains the extracted trial data
    """
    print('Extracting event-related data')
    data_dict = utils.extract_trial_data(dict_analysis_vars, dict_analysis_vars['signals'], dict_analysis_vars['event_frames'], 
                                         dict_analysis_vars['conditions'], save_dir=dict_analysis_vars['save_dir'])

    """
    ## Plot trial-resolved heatmap for each ROI
    """
    print('Plotting trial heatmaps for each ROI')
    plot_trial_heatmap_traces(dict_analysis_vars, fparams, data_dict)

    """
    ## Plot trial-averaged heatmap of all ROIs
    """
    print('Plotting trial-averaged heatmaps all ROI')
    dict_analysis_vars = sort_rois(dict_analysis_vars, fparams, data_trial_avg_key) # sort rois based on activity

    plot_trial_avg_heatmap(dict_analysis_vars, data_dict, cmap_,
                           clims = generate_clims(np.concatenate([data_dict[cond][data_trial_avg_key].flatten() for cond in dict_analysis_vars['conditions']]), 
                                                  fparams['flag_normalization']),
                           sorted_roi_order = dict_analysis_vars['sorted_roi_order'], 
                           rois_oi = dict_analysis_vars['interesting_rois'], 
                           save_fig = fparams['flag_save_figs'],
                           axis_label_size=dict_analysis_vars['axis_label_size'], tick_font_size=dict_analysis_vars['tick_font_size'])

    """
    ## Plot trial- and ROI-averaged traces
    """
    print('Plotting trial- and ROI-averaged traces')
    plot_trial_roi_avg_traces(dict_analysis_vars, fparams, data_trial_resolved_key,
                              axis_label_size=dict_analysis_vars['axis_label_size'], 
                              tick_font_size=dict_analysis_vars['tick_font_size'])

    """
    ### Quantification of roi-, trial-, time-averaged data
    """
    print('Plotting quantification of trial- and ROI-averaged data')
    stats_trial_roi_avg_traces(dict_analysis_vars, fparams, data_trial_resolved_key,
                               axis_label_size=dict_analysis_vars['axis_label_size'], 
                               tick_font_size=dict_analysis_vars['tick_font_size'])