# NAPE Calcium Imaging Event-Related Analysis

## What does this script do:

Loads activity traces from ROIs/sources across the whole session, along with the timing of behavioral/manipulation events. It then extracts a window of activity around each event and generates event-related plots that both resolve and average activity across trials and ROIs.

How to run this code
------------------------------------

In this Jupyter notebook, first find the code block with the comment header USER-DEFINED VARIABLES. Edit the variables to match your data and output preferences. Then run all cells in order (Shift + Enter, or from the dropdown menu: Kernel -> Restart & Run All). Make sure you set the correct sampling rate (fs) in the user-defined variables block of code.
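The notebook overwrites the `fs` set in the user-defined block with an `fs` entry from the preprocessing JSON (`<fname>.json` in `fdir`), if that entry exists. The sketch below inspects or corrects that value using only the standard library. It assumes the JSON is a flat dictionary, and `set_session_fs` is a hypothetical helper, not part of `utils`.

```python
import json
import os


def set_session_fs(fdir, fname, new_fs=None):
    """Return the fs stored in <fdir>/<fname>.json, optionally overwriting it first.

    Hypothetical helper: assumes the preprocessing JSON is a flat dict that may hold an 'fs' key.
    """
    json_fpath = os.path.join(fdir, fname + ".json")
    with open(json_fpath, "r") as fp:
        meta = json.load(fp)
    if new_fs is not None:
        meta["fs"] = float(new_fs)  # store the corrected sampling rate (Hz)
        with open(json_fpath, "w") as fp:
            json.dump(meta, fp, indent=2)
    return meta.get("fs")  # None if the JSON has no fs entry
```

For example, `set_session_fs(fparams['fdir'], fparams['fname'])` reports the stored rate. Passing `new_fs=5` rewrites it, leaving all other keys untouched.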

### Required Prerequisite/Input Files

All data should reside in a parent folder. This folder's name should be the name of the session and ideally be the same as the base name of the recording file.

1. ROI/source activity signals file (`fparams['fname_signal']`): either an npy or a CSV file. Rows are individual ROIs/sources, columns are samples, and values are activity levels (i.e., fluorescence, or voltage for ephys). Note the CSV should not 
2. Event occurrence file (`fparams['fname_events']`): either a pickle file or a CSV file. A pickle file should contain a Python dictionary whose keys are event condition names and whose values are lists of event occurrence times (in samples). A CSV should be in tidy format, as shown here (note that currently the first and second columns of the first row must be "event" and "sample", respectively): https://github.com/zhounapeuw/NAPE_imaging_postprocess/raw/main/docs/_images/napeca_post_event_csv_format.png
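Both event-file formats describe the same structure: a mapping from condition name to a list of event sample indices. The sketch below assumes pandas is installed. It shows how a tidy CSV with `event` and `sample` columns maps onto that dictionary, and how the same dictionary could be written as a pickle. `events_csv_to_dict` and `save_events_pickle` are hypothetical helpers for illustration. The notebook itself uses `utils.df_to_dict`, whose internals may differ.

```python
import pickle

import pandas as pd


def events_csv_to_dict(csv_path):
    """Convert a tidy events CSV (columns 'event', 'sample') to {condition: [samples]}.

    Sketch only: utils.df_to_dict may differ in details such as dtype handling.
    """
    df = pd.read_csv(csv_path)
    # groupby yields conditions in sorted order; row order is preserved within each condition
    return {cond: grp["sample"].tolist() for cond, grp in df.groupby("event")}


def save_events_pickle(event_dict, pkl_path):
    """Write the same dictionary in the pickle format the notebook accepts."""
    with open(pkl_path, "wb") as fp:
        pickle.dump(event_dict, fp)
```

A CSV with the rows `plus,120`, `minus,310`, `plus,500` under the header `event,sample` becomes `{'minus': [310], 'plus': [120, 500]}`.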


Required Packages
-----------------
Python 3.7, numpy, seaborn, matplotlib, pandas, scikit-learn

Custom code requirements: utils

User-Defined Parameters 
----------

fname_signal : string
    
    Name of file that contains roi activity traces. Must include full file name with extension. Accepted file types: .npy, .csv. IMPORTANT: data dimensions should be rois (y) by samples/time (x)

fname_events : string

    Name of file that contains event occurrences. Must include full file name with extension. Accepted file types: .pkl, .csv. Pickle (pkl) files need to contain a dictionary where keys are the condition names and the values are lists containing samples/frames for each corresponding event. CSVs should have two columns (event condition, sample). The first row contains the column names, and each subsequent row contains one trial's event condition and sample in tidy format. See the example in the sample_data folder for formatting, or this link: https://github.com/zhounapeuw/NAPE_imaging_postprocess/raw/main/docs/_images/napeca_post_event_csv_format.png

fdir : string 

    Directory containing the session's data files (the parent folder described above). The signal and event files are loaded from here, and outputs are saved here. IMPORTANT Note: leave off the last backslash, and put the letter r in front of the string (this treats the contents as a raw string). For example: r'C:\Users\my_user\analyze_sessions'

fname : string

    Session name; by default this is the name of the parent folder that the data resides in, but can be changed by user to be any string. This fname variable is mainly used to name the saved output files.

flag_save_figs : boolean  

    Set as True to save figures as JPG and vectorized formats.  
    
fs : float

    Sampling rate of the imaging data (Hz). This value must be correct; otherwise, the wrong time windows will be extracted around each event. If the JSON file saved during preprocessing (`<fname>.json` in `fdir`) contains an `fs` entry, that value overrides the one set here. If you suspect that fs was not set correctly in the NAPECA preprocessing pipeline, open that JSON file and edit the fs value.
    
selected_conditions : list of strings

    Specific conditions to analyze. Names must exactly match the condition names in the events CSV or pickle file. Set to None to include all conditions.

trial_start_end : list of two entries  

    Entries can be ints or floats. The first entry is the start of the event analysis window, in seconds relative to the event/ttl onset (negative if before the event/ttl onset). The second entry is the end of the event analysis window, in seconds. For example, if the desired analysis window runs from 5.5 seconds before event onset to 8 seconds after, `trial_start_end` would be [-5.5, 8].  
    
baseline_end : int/float  

    Time in seconds for the end of the baseline epoch. By default, the baseline epoch starts at the first entry of `trial_start_end`. This baseline epoch is used to calculate baseline normalization metrics.
    
event_dur : int/float  

    Duration in seconds of the behavioral event, stimulation, etc. A green line of the corresponding length is plotted below the main plots to indicate the event duration.

event_sort_analysis_win : list with two float entries

    Time window [a, b] in seconds used by some visualization calculations. For example, if flag_sort_rois is True, ROIs in heatmaps are sorted (see `user_sort_method`) based on their activity between a and b. The same window applies to the time-averaged barplots.

opto_blank_frame : boolean

    If PMTs were blanked during stimulation, use the stim times detected in preprocessing to set those frames to NaN.

flag_npil_corr : boolean  

    Set as True to load neuropil-corrected data from the preprocessing pipeline; a `*_neuropil_corrected_signal_*` file must be present in the directory. If set as False, the extracted_signal file is used instead.
    
flag_zscore : boolean  

    Set as True if analyzed data should be baseline z-scored on a trial level.  

flag_sort_rois : boolean

    Set as True to sort ROIs on the y axis of heatmaps. This works with `user_sort_method` and `roi_sort_cond` for specifying details of sorting.

user_sort_method : string
    
    Set to 'peak_time' to sort ROIs by peak time within `event_sort_analysis_win`; set to 'max_value' to sort by maximum value.
    
roi_sort_cond : string
    
    Condition to perform sorting on

flag_roi_trial_avg_errbar : boolean  

    Set as True to show the standard error of the mean (SEM) as shaded regions on line plots of ROI- and trial-averaged activity.   

flag_trial_avg_errbar : boolean

    toggle standard error bars on trial-averaged data

interesting_rois : list of ints  

    All entries are indices for ROIs that will be marked in heatmaps by arrows.  
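The time-window parameters above are converted to samples using `fs`. The sketch below reproduces the arithmetic from the notebook's timing cell using the default values in `define_params` (fs = 5 Hz, `trial_start_end` = [-2, 8], `baseline_end` = -0.2, `event_dur` = 2). The closest-sample lookup uses `np.argmin`, standing in for `utils.get_tvec_sample`; treating them as equivalent is an assumption.

```python
import numpy as np

fs = 5.0                      # sampling rate (Hz); example value from define_params
trial_start_end = [-2, 8]     # analysis window (s) relative to event onset
baseline_end = -0.2           # baseline epoch end (s)
event_dur = 2                 # event duration (s)

# trial window in samples, relative to the event sample
trial_begEnd_samp = np.array(trial_start_end) * fs           # [-10., 40.]
trial_svec = np.arange(trial_begEnd_samp[0], trial_begEnd_samp[1])
num_samples_trial = len(trial_svec)                           # 50

# baseline indices, expressed relative to the first sample of the trial window
baseline_begEnd_samp = np.array([trial_start_end[0], baseline_end]) * fs   # [-10., -1.]
baseline_svec = (np.arange(baseline_begEnd_samp[0], baseline_begEnd_samp[1] + 1)
                 - baseline_begEnd_samp[0]).astype(int)       # indices 0..9

# time vector for plot x axes, and the fractions used to draw the event-duration line
tvec = np.round(np.linspace(trial_start_end[0], trial_start_end[1], num_samples_trial + 1), 2)
t0_sample = int(np.argmin(np.abs(tvec - 0)))                  # 10
event_end_sample = int(np.round(t0_sample + event_dur * fs))  # 20
event_bound_ratio = [t0_sample / num_samples_trial,
                     event_end_sample / num_samples_trial]    # [0.2, 0.4]
```

The baseline here covers the first 10 samples of each extracted trial (-2 s to -0.2 s). The green event line spans 20% to 40% of the x axis.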

Optional Parameters (Only relevant if using batch_process)
-------------------

user_sort_method : string
    
    Takes the strings 'peak_time' or 'max_value'
    
    
roi_sort_cond : string

    For ROI-resolved heatmaps, which condition to sort ROIs by. Defaults to the first condition available.
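The sketch below illustrates the two sorting modes. It assumes `trial_avg` is a trial-averaged array shaped ROIs x samples (like `ztrial_avg_data`) and `win` is a window such as `event_sort_analysis_win`. `sort_rois_sketch` is a hypothetical stand-in for the notebook's `sort_heatmap_peaks`. Unlike the original, it uses `np.nanargmax`, which raises an error on all-NaN rows, so all-NaN ROIs should be excluded first.

```python
import numpy as np


def sort_rois_sketch(trial_avg, tvec, win, method="peak_time"):
    """Return ROI indices ordered for a heatmap (rows = ROIs, columns = samples).

    'peak_time': order ROIs by when they peak inside `win` (earliest first).
    'max_value': order ROIs by peak amplitude inside `win` (largest first).
    """
    start = int(np.argmin(np.abs(tvec - win[0])))  # sample closest to window start
    end = int(np.argmin(np.abs(tvec - win[1])))    # sample closest to window end
    epoch = trial_avg[:, start:end]
    if method == "peak_time":
        return np.argsort(np.nanargmax(epoch, axis=1))
    if method == "max_value":
        return np.argsort(np.nanmax(epoch, axis=1))[::-1]
    raise ValueError("method must be 'peak_time' or 'max_value'")
```

Consider three ROIs that peak at samples 3, 1 and 2 with amplitudes 5, 9 and 3. 'peak_time' orders them [1, 2, 0], while 'max_value' orders them [1, 0, 2].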
    
Output
-------

event_rel_analysis : folder containing plots in jpg/png and vectorized formats

events_dict.pkl : pickle file containing a dictionary of event-triggered data from each cell, organized by event condition. Both raw and z-scored data are included.
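The sketch below inspects the saved dictionary. It assumes each condition maps to a dictionary whose array entries (e.g., `data`, `zdata`) are shaped trials x ROIs x samples, as they are indexed elsewhere in this notebook. `summarize_event_dict` is hypothetical; check the actual file name and location in your `event_rel_analysis` folder.

```python
import pickle

import numpy as np


def summarize_event_dict(pkl_path):
    """Load the saved event-triggered dictionary and report array shapes per condition.

    Non-array entries (e.g. a trial count) are skipped.
    """
    with open(pkl_path, "rb") as fp:
        events_dict = pickle.load(fp)
    summary = {}
    for cond, cond_data in events_dict.items():
        summary[cond] = {key: np.shape(val) for key, val in cond_data.items()
                         if isinstance(val, np.ndarray)}
    return summary
```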


In [None]:
import os
import numpy as np
import glob
import pickle
import json
import seaborn as sns
import matplotlib.ticker as ticker
import pandas as pd
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import matplotlib
#important for text to be detected when importing saved figures into illustrator
matplotlib.rcParams['pdf.fonttype']=42
matplotlib.rcParams['ps.fonttype']=42
plt.rcParams["font.family"] = "Arial"

import utils

In [None]:
# simple class to update limits as you go through iterations of data
# first call update_lims(first_lims)
# then update_lims.update(new_lims)
# update_lims.output() outputs lims
class update_lims:
    
    def __init__(self, lims):
        self.lims = lims
        
    
    def update(self, new_lims):
        if self.lims[0] > new_lims[0]:
            self.lims[0] = new_lims[0]
        
        if self.lims[1] < new_lims[1]:
            self.lims[1] = new_lims[1]

    def output(self):
        return self.lims
    
    
# find 2D subplot index based on a flat incremented index (i.e., idx=3 would be (1, 1) for a 2x2 subplot figure)
def subplot_loc(idx, num_rows, num_col):
    if num_rows == 1:
        subplot_index = idx # with a single row, plt.subplots returns a 1D array of axes
    else:
        subplot_index = np.unravel_index(idx, (num_rows, int(num_col))) # turn int index to a tuple of array coordinates
    return subplot_index


def get_cmap(n, name='plasma'):
    '''Returns a function that maps each index in 0, 1, ..., n-1 to a distinct 
    RGB color; the keyword argument name must be a standard mpl colormap name.'''
    return plt.cm.get_cmap(name, n)


def is_all_nans(vector):
    """
    checks if series or vector contains all nans; returns boolean. Used to identify and exclude all-nan rois
    """
    if isinstance(vector, pd.Series):
        vector = vector.values
    return np.isnan(vector).all()

# declare some fixed constant variables
axis_label_size = 15
tick_font_size = 14 


In [None]:
"""
USER-DEFINED VARIABLES
"""

def define_params(method = 'single'):
    
    fparams = {}
    
    if method == 'single':
        
        fparams['fname_signal'] = 'VJ_OFCVTA_7_260_D6_neuropil_corrected_signals_15_50_beta_0.8.npy'
        fparams['fname_events'] = 'event_times_VJ_OFCVTA_7_260_D6_trained.csv'
        # for your own data, modify and use this phrase below for fparams['fdir']: 
        # r'C:\Users\stuberadmin\Documents\GitHub\NAPE_imaging_postprocess\napeca_post\sample_data' 
        # replace the contents inside the apostrophes with the path to your data; make sure the r comes before the apostrophe
        fparams['fdir'] = os.path.abspath('./sample_data/VJ_OFCVTA_7_260_D6') 
        fparams['fname'] = os.path.split(fparams['fdir'])[1]
        fparams['flag_save_figs'] = False
        
        # set the sampling rate
        fparams['fs'] = 5 # this gets overwritten by json fs variable (if it exists) that is saved in preprocessing
        json_fpath = os.path.join(fparams['fdir'], fparams['fname']+".json")
        if os.path.exists(json_fpath):
            json_data = utils.open_json(json_fpath)
            if 'fs' in json_data:
                fparams['fs'] = json_data['fs']

        fparams['selected_conditions'] = None # set to None if want to include all conditions from behav data
        
        # trial windowing and normalization
        fparams['trial_start_end'] = [-2, 8] # [start, end] times (in seconds) included in the visualization 
        fparams['flag_zscore'] = True # whether or not to z-score data for plots
        fparams['baseline_end'] = -0.2 # baseline epoch end time (in seconds) for performing baseline normalization
        fparams['event_dur'] = 2 # duration of stim/event in seconds; displays a line below main plot indicating event duration
        fparams['event_sort_analysis_win'] = [0, 5] # time window (in seconds)
        
        # session info
        fparams['opto_blank_frame'] = False # if PMTs were blanked during stim, set stim times to nan (instead of 0)

        # ROI sorting; if flag_sort_rois is set to True, ROIs are sorted by activity in the fparams['event_sort_analysis_win'] window
        fparams['flag_sort_rois'] = True
        if fparams['flag_sort_rois']:
            
            fparams['user_sort_method'] = 'max_value' # peak_time or max_value
            fparams['roi_sort_cond'] = 'plus' # for roi-resolved heatmaps, which condition to sort ROIs by
            
        # errorbar and saving figures
        fparams['flag_roi_trial_avg_errbar'] = True # toggle to show error bar on roi- and trial-averaged traces
        fparams['flag_trial_avg_errbar'] = True # toggle to show error bars on the trial-avg traces
        fparams['interesting_rois'] = [] #[ 0, 1, 2, 23, 22, 11, 9, 5, 6, 7, 3, 4, 8, 12, 14, 15, 16, 17] # [35, 30, 20, 4] #
    
    elif method == 'f2a': # load a predefined list of files from files_to_analyze_event
        
        # note: files_to_analyze_event is not imported in this notebook and may not exist
        fparams = files_to_analyze_event.define_fparams()

    elif method == 'root_dir':
        
        raise NotImplementedError("method='root_dir' is not implemented yet")

    with open(os.path.join(fparams['fdir'], 'event_analysis_fparam.json'), 'w') as fp:
        json.dump(fparams, fp)
    
    return fparams

fparams = define_params(method = 'single') # options are 'single', 'f2a', 'root_dir'

In [None]:
# declare paths
signals_fpath = os.path.join(fparams['fdir'], fparams['fname_signal'])
events_file_path = os.path.join(fparams['fdir'], fparams['fname_events'])

save_dir = os.path.join(fparams['fdir'], 'event_rel_analysis')

utils.check_exist_dir(save_dir); # make the save directory

In [None]:
def declare_paths(fparams):
    paths = {}
    paths['signals_fpath'] = os.path.join(fparams['fdir'], fparams['fname_signal'])
    paths['events_file_path'] = os.path.join(fparams['fdir'], fparams['fname_events'])
    paths['save_dir'] = os.path.join(fparams['fdir'], 'event_rel_analysis')
    
    utils.check_exist_dir(paths['save_dir'])

    return paths

paths = declare_paths(fparams)

In [None]:
### create variables that reference samples and times for slicing and plotting the data

trial_start_end_sec = np.array(fparams['trial_start_end']) # trial window in seconds, relative to ttl-onset/trial-onset
baseline_start_end_sec = np.array([trial_start_end_sec[0], fparams['baseline_end']])

# convert times to samples and get sample vector for the trial 
trial_begEnd_samp = trial_start_end_sec*fparams['fs'] # turn trial start/end times to samples
trial_svec = np.arange(trial_begEnd_samp[0], trial_begEnd_samp[1])
# and for baseline period
baseline_begEnd_samp = baseline_start_end_sec*fparams['fs']
baseline_svec = (np.arange(baseline_begEnd_samp[0], baseline_begEnd_samp[1]+1, 1) - baseline_begEnd_samp[0]).astype('int')

# calculate time vector for plot x axes
num_samples_trial = len( trial_svec )
tvec = np.round(np.linspace(trial_start_end_sec[0], trial_start_end_sec[1], num_samples_trial+1), 2)

# find samples and calculations for time 0 for plotting
t0_sample = utils.get_tvec_sample(tvec, 0) # grabs the sample index of a given time from a vector of times
event_end_sample = int(np.round(t0_sample+fparams['event_dur']*fparams['fs']))
event_bound_ratio = [(t0_sample)/num_samples_trial , event_end_sample/num_samples_trial] # fraction of total samples for event start and end; only used for plotting line indicating event duration

In [None]:
def event_timing(fparams):
    event_data = {}

    event_data['trial_start_end_sec'] = np.array(fparams['trial_start_end'])
    event_data['baseline_start_end_sec'] = np.array([event_data['trial_start_end_sec'][0], fparams['baseline_end']])

    event_data['trial_begEnd_samp'] = event_data['trial_start_end_sec']*fparams['fs']
    event_data['trial_svec'] = np.arange(event_data['trial_begEnd_samp'][0], event_data['trial_begEnd_samp'][1])

    event_data['baseline_begEnd_samp'] = event_data['baseline_start_end_sec']*fparams['fs']
    event_data['baseline_svec'] = (np.arange(event_data['baseline_begEnd_samp'][0], event_data['baseline_begEnd_samp'][1]+1, 1) - event_data['baseline_begEnd_samp'][0]).astype('int') 

    event_data['num_samples_trial'] = len( event_data['trial_svec'] )
    event_data['tvec'] = np.round(np.linspace(event_data['trial_start_end_sec'][0], event_data['trial_start_end_sec'][1], event_data['num_samples_trial']+1), 2)

    event_data['t0_sample'] = utils.get_tvec_sample(event_data['tvec'], 0)
    event_data['event_end_sample'] = int(np.round(event_data['t0_sample']+fparams['event_dur']*fparams['fs']))
    event_data['event_bound_ratio'] = [(event_data['t0_sample'])/event_data['num_samples_trial'] , event_data['event_end_sample']/event_data['num_samples_trial']]

    return event_data

event_data = event_timing(fparams)

In [None]:
signals = utils.load_signals(signals_fpath)

num_rois = signals.shape[0]
all_nan_rois = np.where(np.apply_along_axis(is_all_nans, 1, signals)) # find rois with activity as all nans

# if opto stim frames were detected in preprocessing, set these frames to be NaN (b/c of stim artifact)
if fparams['opto_blank_frame']:
    try:
        glob_stim_files = glob.glob(os.path.join(fparams['fdir'], "{}*_stimmed_frames.pkl".format(fparams['fname'])))
        with open(glob_stim_files[0], "rb") as f:
            stim_frames = pickle.load(f)
        signals[:,stim_frames['samples']] = np.nan # blank out stimmed frames
        flag_stim = True
        print('Detected stim data; replaced stim samples with NaNs')
    except:
        flag_stim = False
        print('Note: No stim preprocessed meta data detected.')
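The blanking step sets stimulated samples to NaN so that stim artifacts do not enter trial averages. The sketch below implements the same idea as standalone functions. It assumes `stim_samples` holds integer sample indices, like the `'samples'` entry of the `*_stimmed_frames.pkl` file. `blank_stim_samples` and `all_nan_roi_indices` are hypothetical names. Casting to float and assigning `np.nan` explicitly avoids relying on `None` being converted to NaN.

```python
import numpy as np


def blank_stim_samples(signals, stim_samples):
    """Return a copy of `signals` (ROIs x samples) with stimulated samples set to NaN.

    The array is cast to float first because integer arrays cannot hold NaN.
    """
    blanked = np.array(signals, dtype=float, copy=True)
    blanked[:, stim_samples] = np.nan
    return blanked


def all_nan_roi_indices(signals):
    """Indices of ROIs whose entire trace is NaN (same check as is_all_nans, per row)."""
    return np.flatnonzero(np.isnan(signals).all(axis=1))
```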

In [None]:
def roi_differentiation(fparams, paths):
    roi_info = {}

    roi_info['signals'] = utils.load_signals(paths['signals_fpath'])

    roi_info['num_rois'] = roi_info['signals'].shape[0]
    roi_info['all_nan_rois'] = np.where(np.apply_along_axis(is_all_nans, 1, roi_info['signals'])) # find rois with activity as all nans

    # if opto stim frames were detected in preprocessing, set these frames to be NaN (b/c of stim artifact)
    if fparams['opto_blank_frame']:
        try:
            glob_stim_files = glob.glob(os.path.join(fparams['fdir'], "{}*_stimmed_frames.pkl".format(fparams['fname'])))
            with open(glob_stim_files[0], "rb") as f:
                stim_frames = pickle.load(f)
            roi_info['signals'][:,stim_frames['samples']] = np.nan # blank out stimmed frames
            flag_stim = True
            print('Detected stim data; replaced stim samples with NaNs')
        except:
            flag_stim = False
            print('Note: No stim preprocessed meta data detected.')

    return roi_info
    
roi_info = roi_differentiation(fparams, paths)

In [None]:
### load behavioral data and trial info
try:
    glob_event_files = glob.glob(events_file_path) # look for a file in specified directory
    if 'csv' in glob_event_files[0]:
        event_times = utils.df_to_dict(glob_event_files[0])
    elif 'pkl' in glob_event_files[0]:
        event_times = pickle.load( open( glob_event_files[0], "rb" ), fix_imports=True, encoding='latin1' ) # latin1 b/c original pickle made in python 2
    event_frames = utils.dict_samples_to_time(event_times, fparams['fs'])
except:
    print('Cannot find behavioral data file or file path is incorrect; utils.extract_trial_data will throw error.')

# identify conditions to analyze
all_conditions = event_frames.keys()
conditions = [ condition for condition in all_conditions if len(event_frames[condition]) > 0 ] # keep conditions that have events

conditions.sort()
if fparams['selected_conditions']:
    conditions = fparams['selected_conditions']

cmap_lines = get_cmap(len(conditions))

In [None]:
def load_behavioral_data(fparams, paths):
    behav_data = {}

    try:
        glob_event_files = glob.glob(paths['events_file_path']) # look for a file in specified directory
        if 'csv' in glob_event_files[0]:
            event_times = utils.df_to_dict(glob_event_files[0])
        elif 'pkl' in glob_event_files[0]:
            event_times = pickle.load( open( glob_event_files[0], "rb" ), fix_imports=True, encoding='latin1' ) # latin1 b/c original pickle made in python 2
        behav_data['event_frames'] = utils.dict_samples_to_time(event_times, fparams['fs'])
    except:
        print('Cannot find behavioral data file or file path is incorrect; utils.extract_trial_data will throw error.')

    # identify conditions to analyze
    all_conditions = behav_data['event_frames'].keys()
    behav_data['conditions'] = [ condition for condition in all_conditions if len(behav_data['event_frames'][condition]) > 0 ] # keep conditions that have events

    behav_data['conditions'].sort()
    if fparams['selected_conditions']:
        behav_data['conditions'] = fparams['selected_conditions']

    behav_data['cmap_lines'] = get_cmap(len(behav_data['conditions']))

    return behav_data

behav_data = load_behavioral_data(fparams, paths)

## Start trial-based preprocessing

In [None]:
"""
MAIN data processing function to extract event-centered data

extract and save trial data, 
saved data are in the event_rel_analysis subfolder, a pickle file that contains the extracted trial data
"""
data_dict = utils.extract_trial_data(roi_info['signals'],  event_data['tvec'],  event_data['trial_begEnd_samp'],  behav_data['event_frames'], 
                                     behav_data['conditions'], baseline_start_end_samp =  event_data['baseline_begEnd_samp'], save_dir=save_dir)
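The extraction step also produces z-scored data (`zdata`, `ztrial_avg_data`). The internals of `utils.extract_trial_data` are not shown here, so the sketch below is a common per-trial baseline z-score, not necessarily the exact computation it performs. Each trial and ROI is normalized by the mean and standard deviation of its own baseline samples (`baseline_svec`). `baseline_zscore` is a hypothetical name.

```python
import numpy as np


def baseline_zscore(trial_data, baseline_svec):
    """Z-score each trial and ROI against its own baseline samples.

    trial_data: array shaped (trials, ROIs, samples)
    baseline_svec: integer sample indices of the baseline epoch within each trial
    """
    baseline = trial_data[..., baseline_svec]
    mu = np.nanmean(baseline, axis=-1, keepdims=True)
    sd = np.nanstd(baseline, axis=-1, keepdims=True)
    sd[sd == 0] = np.nan  # flat baselines give NaN instead of dividing by zero
    return (trial_data - mu) / sd
```

For a single trace [1, 3, 5, 7] with baseline samples [0, 1], the baseline mean is 2 and the SD is 1, giving [-1, 1, 3, 5].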

In [None]:
### calculate all the color limits for heatmaps; useful for locking color limits across different heatmap subplots

# for trial_avg data, get min/max across conditions
clims_data = [ np.nanmin( [np.mean(data_dict[key]['data'], axis = 0) for key in data_dict] ), 
        np.nanmax( [np.mean(data_dict[key]['data'], axis = 0) for key in data_dict] ) ]

## for z-scored data, we'd like for the color scale to be centered at 0; first we get color limits
tmp_clim = [ np.nanmin( [data_dict[key]['ztrial_avg_data'] for key in data_dict] ), 
        np.nanmax( [data_dict[key]['ztrial_avg_data'] for key in data_dict] ) ]
# then we take the higher of the two magnitudes
clims_max = np.max(np.abs(tmp_clim))
# and use half of it as the symmetric negative and positive limits for plotting
clims_z = [-clims_max*0.5, clims_max*0.5]

In [None]:
def color_init(data_dict):
    color_limits = {}

    color_limits['clims_data'] = [ np.nanmin( [np.mean(data_dict[key]['data'], axis = 0) for key in data_dict] ), 
            np.nanmax( [np.mean(data_dict[key]['data'], axis = 0) for key in data_dict] ) ]

    tmp_clim = [ np.nanmin( [data_dict[key]['ztrial_avg_data'] for key in data_dict] ), 
            np.nanmax( [data_dict[key]['ztrial_avg_data'] for key in data_dict] ) ]

    clims_max = np.max(np.abs(tmp_clim))

    color_limits['clims_z'] = [-clims_max*0.5, clims_max*0.5]    

    return color_limits   

color_limits = color_init(data_dict)
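For z-scored heatmaps, the color scale is centered on zero so that increases and decreases are shown symmetrically. The sketch below generalizes that calculation. It takes the largest absolute value across several arrays, ignoring NaNs, and scales it; `scale=0.5` matches the halving used above. `symmetric_clims` is a hypothetical helper. Because it reduces each array separately, it also handles arrays of different shapes, which passing a list of arrays directly to `np.nanmin` does not.

```python
import numpy as np


def symmetric_clims(arrays, scale=0.5):
    """Zero-centred color limits [-scale*m, +scale*m], where m is the max |value| across arrays."""
    lo = np.nanmin([np.nanmin(a) for a in arrays])
    hi = np.nanmax([np.nanmax(a) for a in arrays])
    clim_max = max(abs(lo), abs(hi))
    return [-clim_max * scale, clim_max * scale]
```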

## Plot trial-resolved heatmap for each ROI

In [None]:
def subplot_trial_heatmap(data_in, conditions, tvec, event_bound_ratio, clims, n_rows, n_columns, 
                           save_fig = False, axis_label_size=15):
    
    """
    Prep information about specific condition (each loop) and plot heatmap
        1) x/y label tick values
        2) x/y labels
        3) grab specific condition's data
        4) plot data (using utils function)
        5) plot meta data lines (eg. 0-time line, event duration line)

    event_bound_ratio
    """
    
    for idx_cond, cond in enumerate(conditions):
        
        # set imshow extent to replace x and y axis ticks/labels
        plot_extent = [tvec[0], tvec[-1], data_in[cond]['num_trials'], 0] # [x min, x max, y min, ymax]
        
        # determine subplot location index
        subplot_index = subplot_loc(idx_cond, n_rows, n_columns)
        
        # prep labels; plot x and y labels for first subplot
        if subplot_index == (0, 0) or subplot_index == 0 :
            ax[subplot_index].set_ylabel('Trial', fontsize=axis_label_size)
            ax[subplot_index].set_xlabel('Time [s]', fontsize=axis_label_size);
        ax[subplot_index].tick_params(axis = 'both', which = 'major', labelsize = tick_font_size)
        
        # prep the data
        to_plot = np.squeeze(data_in[cond]['data'][...,iROI,:]) 
        if data_in[cond]['num_trials'] == 1: # accommodates single-trial data
            to_plot = to_plot[np.newaxis, :]
        
        # plot the data
        title = 'ROI {}; {}'.format(str(iROI), cond)
        im = utils.subplot_heatmap(ax[subplot_index], title, to_plot, cmap='inferno', clims=clims, extent_=plot_extent)
        
        # add meta data lines
        ax[subplot_index].axvline(0, color='0.5', alpha=1) # plot vertical line for time zero
        # plots green horizontal line indicating event duration
        ax[subplot_index].annotate('', xy=(event_bound_ratio[0], -0.01), xycoords='axes fraction', 
                                   xytext=(event_bound_ratio[1], -0.01), 
                                   arrowprops=dict(arrowstyle="-", color='g'))

    cbar = fig.colorbar(im, ax = ax[subplot_index], shrink = 0.5)
    cbar.ax.set_ylabel('Activity', fontsize = axis_label_size)
    

In [None]:
num_subplots = len(conditions) + 1 # plus one for trial-avg traces
n_columns = np.min([num_subplots, 4.0])
n_rows = int(np.ceil(num_subplots/n_columns))

for iROI in range(num_rois):
     
    ## Plot heatmaps for each condition
    roi_clims = [ np.nanmin( [np.nanmin(data_dict[cond]['data'][...,iROI,:]) for cond in conditions] ), 
        np.nanmax( [np.nanmax(data_dict[cond]['data'][...,iROI,:]) for cond in conditions] ) ]
    
    fig, ax = plt.subplots(nrows=n_rows, ncols=int(n_columns), 
                           figsize=(n_columns*4, n_rows*3),
                           constrained_layout=True)
    
    subplot_trial_heatmap(data_dict, conditions, tvec, event_bound_ratio, roi_clims, n_rows, n_columns, 
                           save_fig=False)
    
    ## plot last subplot of trial-avg traces
    
    # determine subplot location index
    subplot_index = subplot_loc(num_subplots-1, n_rows, n_columns)

    for cond in conditions:
        
        # prep data to plot
        num_trials = data_dict[cond]['num_trials']
        to_plot = np.nanmean(data_dict[cond]['zdata'][:,iROI,:], axis=0)
        to_plot_err = np.nanstd(data_dict[cond]['zdata'][:,iROI,:], axis=0)/np.sqrt(num_trials)
        
        # plot trace
        ax[subplot_index].plot(tvec, to_plot)
        if fparams['opto_blank_frame']: 
            ax[subplot_index].plot(tvec[t0_sample:event_end_sample], to_plot[t0_sample:event_end_sample], marker='.', color='g')
        # plot shaded error
        if fparams['flag_trial_avg_errbar']:
            ax[subplot_index].fill_between(tvec, to_plot - to_plot_err, to_plot + to_plot_err,
                         alpha=0.5) # this plots the shaded error bar
        
    # plot x, y labels, and legend
    ax[subplot_index].set_ylabel('Z-Score Activity', fontsize=axis_label_size)
    ax[subplot_index].set_xlabel('Time [s]', fontsize=axis_label_size)
    ax[subplot_index].set_title('ROI # {}; Trial-avg'.format(str(iROI)), fontsize=axis_label_size)
    ax[subplot_index].legend(conditions)
    ax[subplot_index].autoscale(enable=True, axis='both', tight=True)
    ax[subplot_index].axvline(0, color='0.5', alpha=0.65) # plot vertical line for time zero
    ax[subplot_index].annotate('', xy=(event_bound_ratio[0], -0.01), xycoords='axes fraction', 
                                   xytext=(event_bound_ratio[1], -0.01), 
                                   arrowprops=dict(arrowstyle="-", color='g'))
    ax[subplot_index].tick_params(axis = 'both', which = 'major', labelsize = tick_font_size)
    
    for a in ax.flat[num_subplots:]:
        a.axis('off')
    
    if fparams['flag_save_figs']:
        fig.savefig( os.path.join(save_dir,'roi_{}_activity.png'.format(str(iROI))) ); 
        fig.savefig( os.path.join(save_dir,'roi_{}_activity.pdf'.format(str(iROI))) );
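The shaded error in the trial-averaged subplot is the standard error of the mean across trials. The loop above divides `np.nanstd` by the square root of the total number of trials. The sketch below is a variant that divides by the number of non-NaN trials at each sample instead, which may matter when frames are blanked. `trial_mean_sem` is a hypothetical helper, and choosing this denominator is a design choice, not the notebook's behavior.

```python
import numpy as np


def trial_mean_sem(zdata_roi):
    """Trial-averaged trace and SEM for one ROI; zdata_roi is shaped (trials, samples)."""
    mean = np.nanmean(zdata_roi, axis=0)
    n_valid = np.sum(~np.isnan(zdata_roi), axis=0)          # non-NaN trials per sample
    sem = np.nanstd(zdata_roi, axis=0) / np.sqrt(n_valid)   # ddof=0, as in the notebook
    return mean, sem
```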

In [None]:
def condition_heatmap_plot(data_dict, fparams, event_data, roi_info, behav_data):
    num_subplots = len(behav_data['conditions']) + 1 # plus one for trial-avg traces
    n_columns = np.min([num_subplots, 4.0])
    n_rows = int(np.ceil(num_subplots/n_columns))

    for iROI in range(roi_info['num_rois']):
        
        ## Plot heatmaps for each condition
        roi_clims = [ np.nanmin( [np.nanmin(data_dict[cond]['data'][...,iROI,:]) for cond in behav_data['conditions']] ), 
            np.nanmax( [np.nanmax(data_dict[cond]['data'][...,iROI,:]) for cond in behav_data['conditions']] ) ]
        
        fig, ax = plt.subplots(nrows=n_rows, ncols=int(n_columns), 
                            figsize=(n_columns*4, n_rows*3),
                            constrained_layout=True)
        
        subplot_trial_heatmap(data_dict, behav_data['conditions'], event_data['tvec'], event_data['event_bound_ratio'], roi_clims, n_rows, n_columns, 
                            save_fig=False)
        
        ## plot last subplot of trial-avg traces
        
        # determine subplot location index
        subplot_index = subplot_loc(num_subplots-1, n_rows, n_columns)

        for cond in behav_data['conditions']:
            
            # prep data to plot
            num_trials = data_dict[cond]['num_trials']
            to_plot = np.nanmean(data_dict[cond]['zdata'][:,iROI,:], axis=0)
            to_plot_err = np.nanstd(data_dict[cond]['zdata'][:,iROI,:], axis=0)/np.sqrt(num_trials)
            
            # plot trace
            ax[subplot_index].plot(event_data['tvec'], to_plot)
            if fparams['opto_blank_frame']: 
                ax[subplot_index].plot(event_data['tvec'][event_data['t0_sample']:event_data['event_end_sample']], to_plot[event_data['t0_sample']:event_data['event_end_sample']], marker='.', color='g')
            # plot shaded error
            if fparams['flag_trial_avg_errbar']:
                ax[subplot_index].fill_between(event_data['tvec'], to_plot - to_plot_err, to_plot + to_plot_err,
                            alpha=0.5) # this plots the shaded error bar
            
        # plot x, y labels, and legend
        ax[subplot_index].set_ylabel('Z-Score Activity', fontsize=axis_label_size)
        ax[subplot_index].set_xlabel('Time [s]', fontsize=axis_label_size)
        ax[subplot_index].set_title('ROI # {}; Trial-avg'.format(str(iROI)), fontsize=axis_label_size)
        ax[subplot_index].legend(behav_data['conditions'])
        ax[subplot_index].autoscale(enable=True, axis='both', tight=True)
        ax[subplot_index].axvline(0, color='0.5', alpha=0.65) # plot vertical line for time zero
        ax[subplot_index].annotate('', xy=(event_data['event_bound_ratio'][0], -0.01), xycoords='axes fraction', 
                                    xytext=(event_data['event_bound_ratio'][1], -0.01), 
                                    arrowprops=dict(arrowstyle="-", color='g'))
        ax[subplot_index].tick_params(axis = 'both', which = 'major', labelsize = tick_font_size)
        
        for a in ax.flat[num_subplots:]:
            a.axis('off')
        
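        if fparams['flag_save_figs']:
            save_dir_roi = os.path.join(fparams['fdir'], 'event_rel_analysis') # same folder as paths['save_dir']; fparams has no 'save_dir' key
            fig.savefig( os.path.join(save_dir_roi,'roi_{}_activity.png'.format(str(iROI))) ); 
            fig.savefig( os.path.join(save_dir_roi,'roi_{}_activity.pdf'.format(str(iROI))) );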

In [None]:
# lambda that returns the index of the sample in tvec closest to a given time
tvec = event_data['tvec']
tvec2samp = lambda tvec, time: np.argmin(np.abs(tvec - time))

# function to sort ROIs based on activity in certain epoch
def sort_heatmap_peaks(data, tvec, sort_epoch_start_time, sort_epoch_end_time, sort_method = 'peak_time'):
    
    # find start/end samples for epoch
    sort_epoch_start_samp = tvec2samp(tvec, sort_epoch_start_time)
    sort_epoch_end_samp = tvec2samp(tvec, sort_epoch_end_time)
    
    if sort_method == 'peak_time':
        epoch_peak_samp = np.argmax(data[:,sort_epoch_start_samp:sort_epoch_end_samp], axis=1)
        final_sorting = np.argsort(epoch_peak_samp)
    elif sort_method == 'max_value':
        time_max = np.nanmax(data[:,sort_epoch_start_samp:sort_epoch_end_samp], axis=1)
        final_sorting = np.argsort(time_max)[::-1]
    else:
        raise ValueError("sort_method must be 'peak_time' or 'max_value'")

    return final_sorting

In [None]:
# if flag is true, sort ROIs (usually by average fluorescence within analysis window)
if fparams['flag_sort_rois']:
    if not fparams['roi_sort_cond']: # if no condition to sort by specified, use first condition
        fparams['roi_sort_cond'] = list(data_dict.keys())[0] # dict keys are not indexable in Python 3
    if not fparams['roi_sort_cond'] in data_dict.keys():
        sorted_roi_order = range(num_rois)
        interesting_rois = fparams['interesting_rois']
        print('Specified condition to sort by doesn\'t exist! ROIs are in default sorting.')
    else:
        # returns new order of rois sorted using the data and method supplied in the specified window
        sorted_roi_order = sort_heatmap_peaks(data_dict[fparams['roi_sort_cond']]['ztrial_avg_data'], tvec, 
                           sort_epoch_start_time = fparams['event_sort_analysis_win'][0], 
                           sort_epoch_end_time = fparams['event_sort_analysis_win'][1], 
                           sort_method = fparams['user_sort_method'])
        # finds corresponding interesting roi (roi's to mark with an arrow) order after sorting
        interesting_rois = np.in1d(sorted_roi_order, fparams['interesting_rois']).nonzero()[0] 
else:
    sorted_roi_order = range(num_rois)
    interesting_rois = fparams['interesting_rois']

if not all_nan_rois[0].size == 0:
    set_diff_keep_order = lambda main_list, remove_list : [i for i in main_list if i not in remove_list]
    sorted_roi_order = set_diff_keep_order(sorted_roi_order, all_nan_rois)
    interesting_rois = [i for i in fparams['interesting_rois'] if i not in all_nan_rois]
    
roi_order_path = os.path.join(fparams['fdir'], fparams['fname'] + '_roi_order.pkl')
with open(roi_order_path, 'wb') as handle:
     pickle.dump(sorted_roi_order, handle, protocol=pickle.HIGHEST_PROTOCOL)
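The `np.in1d(...).nonzero()[0]` line above converts ROI IDs into row positions of the sorted heatmap. A minimal sketch of that mapping with hypothetical ROI IDs follows; it uses `np.isin`, which NumPy recommends over `np.in1d` (deprecated as of NumPy 2.0). Once all-NaN ROIs are dropped the rows shift, so recomputing the positions against the filtered order (last line) is one way to keep the arrows aligned; the cell above instead resets `interesting_rois` to the raw ROI IDs. `arrow_rows` and `drop_rois_keep_order` are illustrative names only.

```python
import numpy as np

def arrow_rows(roi_order, rois_of_interest):
    """Positions (heatmap rows) at which rois_of_interest appear in roi_order."""
    return np.nonzero(np.isin(roi_order, rois_of_interest))[0]

def drop_rois_keep_order(roi_order, rois_to_drop):
    """Remove rois_to_drop from roi_order without changing the remaining order."""
    drop = set(int(r) for r in rois_to_drop)
    return [int(r) for r in roi_order if int(r) not in drop]

sorted_roi_order = [4, 0, 3, 1, 2]   # hypothetical result of sorting 5 ROIs
rois_of_interest = [0, 2]            # hypothetical ROIs to mark with arrows
all_nan = [3]                        # hypothetical ROI that is NaN everywhere

print(arrow_rows(sorted_roi_order, rois_of_interest).tolist())   # [1, 4]
kept = drop_rois_keep_order(sorted_roi_order, all_nan)
print(kept)                                                      # [4, 0, 1, 2]
print(arrow_rows(kept, rois_of_interest).tolist())               # [1, 3]
```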

In [None]:
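def sort_rois(data_dict, fparams, event_data, roi_info):
    """
    Function version of the ROI-sorting cell above. Saves the ROI order to a pickle file
    and returns interesting_rois. Still relies on the notebook globals
    trial_start_end_sec and all_nan_rois.
    """
    # if flag is true, sort ROIs by their trial-averaged activity (peak time or peak value) between time zero and the end of the trial window
    if fparams['flag_sort_rois']:
        if not fparams['roi_sort_cond']: # if no condition to sort by specified, use first condition
            fparams['roi_sort_cond'] = list(data_dict.keys())[0]
        if fparams['roi_sort_cond'] not in data_dict: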
            sorted_roi_order = range(roi_info['num_rois'])
            interesting_rois = fparams['interesting_rois']
            print('Specified condition to sort by doesn\'t exist! ROIs are in default sorting.')
        else:
            # returns new order of rois sorted using the data and method supplied in the specified window
            sorted_roi_order = sort_heatmap_peaks(data_dict[fparams['roi_sort_cond']]['ztrial_avg_data'], event_data['tvec'], 
                            sort_epoch_start_time=0, 
                            sort_epoch_end_time = trial_start_end_sec[-1], 
                            sort_method = fparams['user_sort_method'])
            # finds corresponding interesting roi (roi's to mark with an arrow) order after sorting
            interesting_rois = np.in1d(sorted_roi_order, fparams['interesting_rois']).nonzero()[0] 
    else:
        sorted_roi_order = range(roi_info['num_rois'])
        interesting_rois = fparams['interesting_rois']

    if not all_nan_rois[0].size == 0:
        set_diff_keep_order = lambda main_list, remove_list : [i for i in main_list if i not in remove_list]
        sorted_roi_order = set_diff_keep_order(sorted_roi_order, all_nan_rois)
        interesting_rois = [i for i in fparams['interesting_rois'] if i not in all_nan_rois]
        
    roi_order_path = os.path.join(fparams['fdir'], fparams['fname'] + '_roi_order.pkl')
    with open(roi_order_path, 'wb') as handle:
        pickle.dump(sorted_roi_order, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return interesting_rois

In [None]:
def plot_trial_avg_heatmap(data_in, conditions, tvec, event_bound_ratio, clims, sorted_roi_order = None, 
                           rois_oi = None, save_fig = False, axis_label_size=15):
    
    """
    Plot the z-scored, trial-averaged activity (ROIs x time) of each condition as a heatmap.
    Removing all_nan_rois is technically unnecessary because the averages use nanmean.
    """
    
    num_subplots = len(conditions)
    n_columns = np.min([num_subplots, 3.0])
    n_rows = int(np.ceil(num_subplots/n_columns))

    # set imshow extent to replace x and y axis ticks/labels (replace samples with time)
    plot_extent = [tvec[0], tvec[-1], num_rois, 0 ]

    fig, ax = plt.subplots(nrows=n_rows, ncols=int(n_columns), figsize = (n_columns*5, n_rows*4))
    if not isinstance(ax, np.ndarray): # wrap a single Axes so that ax[0] and ax.flat below both work
        ax = np.array([ax])

    for idx, cond in enumerate(conditions):

        # determine subplot location index
        if n_rows == 1:
            subplot_index = idx
        else:
            subplot_index = np.unravel_index(idx, (n_rows, int(n_columns))) # turn int index to a tuple of array coordinates

        # prep labels; plot x and y labels for first subplot
        if subplot_index == (0, 0) or subplot_index == 0 :
            ax[subplot_index].set_ylabel('ROI #', fontsize=axis_label_size)
            ax[subplot_index].set_xlabel('Time [s]', fontsize=axis_label_size);
        ax[subplot_index].tick_params(axis = 'both', which = 'major', labelsize = tick_font_size)
        
        # plot the data
        if sorted_roi_order is not None:
            roi_order = sorted_roi_order
        else:
            roi_order = slice(0, num_rois)
        to_plot = data_in[cond]['ztrial_avg_data'][roi_order,:]

        im = utils.subplot_heatmap(ax[subplot_index], cond, to_plot, clims = clims, extent_=plot_extent)
        ax[subplot_index].axvline(0, color='k', alpha=0.3) # plot vertical line for time zero
        ax[subplot_index].annotate('', xy=(event_bound_ratio[0], -0.01), xycoords='axes fraction', 
                                       xytext=(event_bound_ratio[1], -0.01), 
                                       arrowprops=dict(arrowstyle="-", color='g'))
        if rois_oi is not None:
            for ROI_OI in rois_oi:
                ax[subplot_index].annotate('', xy=(1.005, 1-(ROI_OI/num_rois)-0.015), xycoords='axes fraction', 
                                           xytext=(1.06, 1-(ROI_OI/num_rois)-0.015), 
                                           arrowprops=dict(arrowstyle="->", color='k'))
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    cbar = fig.colorbar(im, ax = ax, shrink = 0.7)
    cbar.ax.set_ylabel('Z-Score Activity', fontsize=13)
    
    # hide empty subplot
    for a in ax.flat[num_subplots:]:
        a.axis('off')
    
    if save_fig:
        fig.savefig(os.path.join(save_dir,'trial_avg_heatmap.png')); 
        fig.savefig(os.path.join(save_dir,'trial_avg_heatmap.pdf'));

plot_trial_avg_heatmap(data_dict, behav_data['conditions'], event_data['tvec'], event_data['event_bound_ratio'], clims = clims_z,
                       sorted_roi_order = sorted_roi_order, rois_oi = interesting_rois, save_fig = fparams['flag_save_figs'])
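The heatmap places the green event bar and the ROI arrows in axes-fraction coordinates (0 to 1 across each axis). The sketch below shows that conversion, assuming `event_bound_ratio` holds the event start/end as fractions of the x-axis and that `utils.subplot_heatmap` draws with `imshow` using the given extent and the default `origin='upper'`; both are assumptions, since neither is defined in this section. The arrow code above uses a fixed `-0.015` offset, which centers an arrow on its row only when `0.5/num_rois` is about 0.015 (roughly 33 ROIs); `row_center_to_axes_fraction` is a hypothetical size-independent alternative.

```python
import numpy as np

def time_to_axes_fraction(t, tvec):
    """x position of time t in axes-fraction units (0 = left, 1 = right),
    assuming the x extent runs exactly from tvec[0] to tvec[-1]."""
    return float((t - tvec[0]) / (tvec[-1] - tvec[0]))

def row_center_to_axes_fraction(row, num_rows):
    """y position (0 = bottom, 1 = top) of the centre of heatmap row `row`,
    assuming extent=[x0, x1, num_rows, 0] so that row 0 is drawn at the top."""
    return 1.0 - (row + 0.5) / num_rows

tvec = np.linspace(-2.0, 6.0, 81)     # hypothetical 10 Hz window, -2 s to 6 s
event_window = (0.0, 4.0)             # hypothetical event duration in seconds
bound_ratio = [time_to_axes_fraction(t, tvec) for t in event_window]
print(bound_ratio)                                    # [0.25, 0.75]
print(round(row_center_to_axes_fraction(0, 10), 3))  # 0.95
print(round(row_center_to_axes_fraction(9, 10), 3))  # 0.05
```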



## Plot trial- and ROI-averaged traces

In [None]:
line_shades = []
fig, axs = plt.subplots(1,1, figsize = (10,6))
for idx, cond in enumerate(conditions):
    line_color = cmap_lines(idx)
    # first trial avg the data
    trial_avg = np.nanmean(data_dict[cond]['zdata'], axis=0)
    
    # z-score trial-avg data for each respective ROI
    # apply zscore function to each row of data
    app_axis = 1 
    zscore_trial_avg = np.apply_along_axis(utils.zscore_, app_axis, trial_avg, baseline_svec)
    
    # take avg/std across ROIs
    zscore_roi_trial_avg = np.nanmean(zscore_trial_avg, axis=0)
    zscore_roi_trial_std = np.nanstd(zscore_trial_avg, axis=0)
     
    to_plot = np.squeeze(zscore_roi_trial_avg)
    to_plot_err = np.squeeze(zscore_roi_trial_std)/np.sqrt(num_rois)
    
    axs.plot(tvec, to_plot, color=line_color)
    if fparams['opto_blank_frame']:
        line = axs.plot(tvec[t0_sample:event_end_sample], to_plot[t0_sample:event_end_sample], marker='.', color=line_color)
    else:
        line = axs.plot(tvec[t0_sample:event_end_sample], to_plot[t0_sample:event_end_sample], color=line_color)
    
    if fparams['flag_roi_trial_avg_errbar']:
        shade = axs.fill_between(tvec, to_plot - to_plot_err, to_plot + to_plot_err, color = line_color,
                     alpha=0.2) # this plots the shaded error bar
        line_shades.append((line[0], shade)) # legend entry: line + shaded error
    else:
        line_shades.append(line[0])

axs.set_ylabel('Z-score Activity', fontsize=axis_label_size)
axs.set_xlabel('Time [s]', fontsize=axis_label_size);
axs.legend(line_shades, conditions, fontsize=15)
axs.axvline(0, color='0.5', alpha=0.65) # plot vertical line for time zero
axs.annotate('', xy=(event_bound_ratio[0], -0.01), xycoords='axes fraction', 
                               xytext=(event_bound_ratio[1], -0.01), 
                               arrowprops=dict(arrowstyle="-", color='g'))
axs.tick_params(axis = 'both', which = 'major', labelsize = tick_font_size+3)
axs.autoscale(enable=True, axis='both', tight=True)

#axs.set_ylim([-1.5, 10])

if fparams['flag_save_figs']:
    fig.savefig(os.path.join(save_dir,'roi_trial_avg_trace.png'))
    fig.savefig(os.path.join(save_dir,'roi_trial_avg_trace.pdf'))
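The cell above z-scores each ROI's trial average against the baseline samples (`utils.zscore_` with `baseline_svec`), then averages across ROIs and divides the across-ROI standard deviation by `sqrt(num_rois)`. Because `num_rois` also counts all-NaN ROIs, the shaded error can come out somewhat narrow when such ROIs exist. The sketch below counts only valid ROIs at each sample. `zscore_to_baseline` is a hypothetical stand-in written as a common baseline z-score; it is not necessarily what `utils.zscore_` does. The (ROIs x samples) layout is also an assumption.

```python
import warnings
import numpy as np

def zscore_to_baseline(trace, baseline_idx):
    """Z-score a 1D trace using the mean and std of its baseline samples.
    Hypothetical stand-in; utils.zscore_ may be implemented differently."""
    base = trace[baseline_idx]
    return (trace - np.nanmean(base)) / np.nanstd(base)

def roi_mean_sem(trial_avg, baseline_idx):
    """trial_avg: (n_rois, n_samples) trial-averaged data (assumed layout).
    Returns the across-ROI mean and SEM of the baseline-z-scored traces,
    counting only non-NaN ROIs at each sample."""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)   # all-NaN ROIs warn
        z = np.apply_along_axis(zscore_to_baseline, 1, trial_avg, baseline_idx)
        n_valid = np.sum(~np.isnan(z), axis=0)            # ROIs per sample
        mean = np.nanmean(z, axis=0)
        sem = np.nanstd(z, axis=0) / np.sqrt(n_valid)
    return mean, sem

trial_avg = np.array([
    [0.0, 2.0, 0.0, 2.0, 5.0, 5.0],      # ROI 0: baseline mean 1, std 1
    [1.0, 1.0, 3.0, 3.0, 9.0, 9.0],      # ROI 1: baseline mean 2, std 1
    [np.nan] * 6,                        # ROI 2: all NaN
])
baseline_idx = np.arange(4)              # hypothetical: first 4 samples are baseline
mean, sem = roi_mean_sem(trial_avg, baseline_idx)
print(mean.tolist())   # [-1.0, 0.0, 0.0, 1.0, 5.5, 5.5]
```

Here the SEM is divided by sqrt(2) rather than sqrt(3), since only two ROIs contribute.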

In [None]:
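def heatmaps_across_rois_plot(fparams, event_data, roi_info, behav_data):
    """
    Function version of the cell above: plots trial- and ROI-averaged z-score traces
    (despite the name, no heatmaps). Relies on the notebook globals data_dict,
    cmap_lines, axis_label_size and tick_font_size.
    """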
    line_shades = []
    fig, axs = plt.subplots(1,1, figsize = (10,6))
    for idx, cond in enumerate(behav_data['conditions']):
        line_color = cmap_lines(idx)
        # first trial avg the data
        trial_avg = np.nanmean(data_dict[cond]['zdata'], axis=0)
        
        # z-score trial-avg data for each respective ROI
        # apply zscore function to each row of data
        app_axis = 1 
        zscore_trial_avg = np.apply_along_axis(utils.zscore_, app_axis, trial_avg, event_data['baseline_svec'])
        
        # take avg/std across ROIs
        zscore_roi_trial_avg = np.nanmean(zscore_trial_avg, axis=0)
        zscore_roi_trial_std = np.nanstd(zscore_trial_avg, axis=0)
        
        to_plot = np.squeeze(zscore_roi_trial_avg)
        to_plot_err = np.squeeze(zscore_roi_trial_std)/np.sqrt(roi_info['num_rois'])
        
        axs.plot(event_data['tvec'], to_plot, color=line_color)
        if fparams['opto_blank_frame']:
            line = axs.plot(event_data['tvec'][event_data['t0_sample']:event_data['event_end_sample']], to_plot[event_data['t0_sample']:event_data['event_end_sample']], marker='.', color=line_color)
        else:
            line = axs.plot(event_data['tvec'][event_data['t0_sample']:event_data['event_end_sample']], to_plot[event_data['t0_sample']:event_data['event_end_sample']], color=line_color)
        
        if fparams['flag_roi_trial_avg_errbar']:
            shade = axs.fill_between(event_data['tvec'], to_plot - to_plot_err, to_plot + to_plot_err, color = line_color,
                        alpha=0.2) # this plots the shaded error bar
            line_shades.append((line[0], shade)) # legend entry: line + shaded error
        else:
            line_shades.append(line[0])

    axs.set_ylabel('Z-score Activity', fontsize=axis_label_size)
    axs.set_xlabel('Time [s]', fontsize=axis_label_size);
    axs.legend(line_shades, behav_data['conditions'], fontsize=15)
    axs.axvline(0, color='0.5', alpha=0.65) # plot vertical line for time zero
    axs.annotate('', xy=(event_data['event_bound_ratio'][0], -0.01), xycoords='axes fraction', 
                                xytext=(event_data['event_bound_ratio'][1], -0.01), 
                                arrowprops=dict(arrowstyle="-", color='g'))
    axs.tick_params(axis = 'both', which = 'major', labelsize = tick_font_size+3)
    axs.autoscale(enable=True, axis='both', tight=True)

    #axs.set_ylim([-1.5, 10])

    if fparams['flag_save_figs']:
        fig.savefig(os.path.join(fparams['save_dir'],'roi_trial_avg_trace.png'))
        fig.savefig(os.path.join(fparams['save_dir'],'roi_trial_avg_trace.pdf'))

In [None]:
# function for finding the index of the closest entry in an array to a provided value
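def find_nearest_idx(array, value):
    """
    For a pandas Series, returns (index label, 0-relative position, entry value);
    for any other array-like, returns (0-relative position, entry value).
    """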

    if isinstance(array, pd.Series):
        idx = (np.abs(array - value)).idxmin()
        return idx, array.index.get_loc(idx), array[idx] # series index, 0-relative index, entry value
    else:
        array = np.asarray(array)
        idx = (np.abs(array - value)).argmin()
        return idx, array[idx]
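The next cell uses `find_nearest_idx` to convert `fparams['event_sort_analysis_win']` from seconds to sample indices and then slices `[start:end]`, which leaves out the sample nearest the window end. If that end sample should be included, one option is to add 1 to the stop index, as in the hypothetical helper below. Whether to include it depends on how the analysis window is meant to be read; this is a sketch, not the notebook's behavior.

```python
import numpy as np

def window_to_samples(tvec, window):
    """Convert a (start, end) window in seconds to (start_idx, stop_idx) for slicing.

    stop_idx is one past the sample nearest `end`, so tvec[start_idx:stop_idx]
    includes that end sample (a plain [start:end] slice would drop it).
    """
    tvec = np.asarray(tvec)
    start_idx = int(np.abs(tvec - window[0]).argmin())
    end_idx = int(np.abs(tvec - window[1]).argmin())
    return start_idx, end_idx + 1

tvec = np.arange(-1.0, 3.0, 0.5)          # hypothetical: -1.0, -0.5, ..., 2.5 s
start, stop = window_to_samples(tvec, (0.0, 1.0))
print(tvec[start:stop].tolist())          # [0.0, 0.5, 1.0]
print(tvec[start:stop - 1].tolist())      # [0.0, 0.5]  (end sample excluded)
```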

In [None]:
### Quantification of ROI-, trial-, and time-averaged data

analysis_window = fparams['event_sort_analysis_win']
analysis_win_samps = [ find_nearest_idx(tvec, time)[0] for time in analysis_window ]

to_plot = []
to_plot_err = []

fig, axs = plt.subplots(1,1, figsize = (5,5))
for idx, cond in enumerate(conditions):
    line_color = cmap_lines(idx)
    # first trial avg the data
    trial_avg = np.nanmean(data_dict[cond]['zdata'], axis=0)
    
    # z-score trial-avg data for each respective ROI
    # apply zscore function to each row of data
    apply_axis = 1 
    zscore_trial_avg = np.apply_along_axis(utils.zscore_, apply_axis, trial_avg, baseline_svec)
    
    # average across samples in the analysis window (the end sample is excluded by the slice)
    zscore_trial_time_avg = np.nanmean(zscore_trial_avg[:,analysis_win_samps[0]:analysis_win_samps[1],:], axis=1)
    
    # take avg/std across ROIs
    zscore_roi_trial_time_avg = np.nanmean(zscore_trial_time_avg, axis=0)
    zscore_roi_trial_time_std = np.nanstd(zscore_trial_time_avg, axis=0)
     
    to_plot.append(zscore_roi_trial_time_avg[0])
    to_plot_err.append(zscore_roi_trial_time_std[0]/np.sqrt(len(zscore_trial_time_avg)))
    
barlist = axs.bar(conditions, to_plot, yerr=to_plot_err, align='center', alpha=0.5, ecolor='black', capsize=10 )
for idx in range(len(conditions)):
    barlist[idx].set_color(cmap_lines(idx))
axs.set_ylabel('Z-score Activity', fontsize=13)
axs.set_title('ROI-, Trial-, Time-averaged Quant', fontsize=15)
axs.yaxis.grid(True)
axs.tick_params(axis = 'both', which = 'major', labelsize = tick_font_size)
axs.tick_params(axis = 'x', which = 'major', rotation = 45)
# tighten the layout before saving
plt.tight_layout()

if fparams['flag_save_figs']:
    fig.savefig(os.path.join(save_dir,'roi_trial_time_avg_bar.png')); 
    fig.savefig(os.path.join(save_dir,'roi_trial_time_avg_bar.pdf'));
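The bar values above come from averaging each ROI over the analysis window and then taking the mean and standard error across ROIs; the error divides by `len(zscore_trial_time_avg)`, which includes any all-NaN ROIs. The sketch below is a NaN-aware variant under the assumption that each condition's z-scored trial average is a 2D (ROIs x samples) array (the cell itself indexes an extra trailing dimension). `window_mean_sem` and the condition names are hypothetical.

```python
import numpy as np

def window_mean_sem(z_trial_avg, start, stop):
    """Mean and SEM across ROIs of each ROI's average over samples [start, stop).

    z_trial_avg is assumed to be (n_rois, n_samples) z-scored, trial-averaged data.
    ROIs whose window average is NaN are left out of both the mean and the count.
    """
    per_roi = np.nanmean(z_trial_avg[:, start:stop], axis=1)   # average over time
    per_roi = per_roi[~np.isnan(per_roi)]                      # drop all-NaN ROIs
    return per_roi.mean(), per_roi.std() / np.sqrt(per_roi.size)

# Hypothetical data: 2 ROIs x 4 samples per condition.
z_by_condition = {
    "cond_A": np.array([[0.0, 2.0, 4.0, 6.0], [2.0, 2.0, 2.0, 2.0]]),
    "cond_B": np.array([[1.0, 1.0, 1.0, 1.0], [3.0, 3.0, 3.0, 3.0]]),
}
bars = {name: window_mean_sem(z, 1, 3) for name, z in z_by_condition.items()}
print({name: (round(float(m), 3), round(float(s), 3)) for name, (m, s) in bars.items()})
# {'cond_A': (2.5, 0.354), 'cond_B': (2.0, 0.707)}
```

The resulting means and SEMs could be passed to `axs.bar(...)` as the heights and `yerr`, in the same order as the condition names.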

In [None]:
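def time_averaged_quantification_plot(data_dict, fparams, event_data, behav_data):
    """
    Function version of the cell above. Relies on the notebook globals
    cmap_lines and tick_font_size.
    """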
    analysis_window = fparams['event_sort_analysis_win']
    analysis_win_samps = [ find_nearest_idx(event_data['tvec'], time)[0] for time in analysis_window ]

    to_plot = []
    to_plot_err = []

    fig, axs = plt.subplots(1,1, figsize = (5,5))
    for idx, cond in enumerate(behav_data['conditions']):
        line_color = cmap_lines(idx)
        # first trial avg the data
        trial_avg = np.nanmean(data_dict[cond]['zdata'], axis=0)
        
        # z-score trial-avg data for each respective ROI
        # apply zscore function to each row of data
        apply_axis = 1 
        zscore_trial_avg = np.apply_along_axis(utils.zscore_, apply_axis, trial_avg, event_data['baseline_svec'])
        
        # average across samples in the analysis window (the end sample is excluded by the slice)
        zscore_trial_time_avg = np.nanmean(zscore_trial_avg[:,analysis_win_samps[0]:analysis_win_samps[1],:], axis=1)
        
        # take avg/std across ROIs
        zscore_roi_trial_time_avg = np.nanmean(zscore_trial_time_avg, axis=0)
        zscore_roi_trial_time_std = np.nanstd(zscore_trial_time_avg, axis=0)
        
        to_plot.append(zscore_roi_trial_time_avg[0])
        to_plot_err.append(zscore_roi_trial_time_std[0]/np.sqrt(len(zscore_trial_time_avg)))
        
    barlist = axs.bar(behav_data['conditions'], to_plot, yerr=to_plot_err, align='center', alpha=0.5, ecolor='black', capsize=10 )
    for idx in range(len(behav_data['conditions'])):
        barlist[idx].set_color(cmap_lines(idx))
    axs.set_ylabel('Z-score Activity', fontsize=13)
    axs.set_title('ROI-, Trial-, Time-averaged Quant', fontsize=15)
    axs.yaxis.grid(True)
    axs.tick_params(axis = 'both', which = 'major', labelsize = tick_font_size)
    axs.tick_params(axis = 'x', which = 'major', rotation = 45)
    # tighten the layout before saving
    plt.tight_layout()

    if fparams['flag_save_figs']:
        fig.savefig(os.path.join(fparams['save_dir'],'roi_trial_time_avg_bar.png')); 
        fig.savefig(os.path.join(fparams['save_dir'],'roi_trial_time_avg_bar.pdf'));