In [None]:
import os
import numpy as np
import glob
import pickle
import json
import seaborn as sns
import matplotlib.ticker as ticker
import pandas as pd
from sklearn.preprocessing import StandardScaler

import matplotlib.pyplot as plt
import matplotlib
#important for text to be detected when importing saved figures into illustrator
matplotlib.rcParams['pdf.fonttype']=42
matplotlib.rcParams['ps.fonttype']=42
plt.rcParams["font.family"] = "Arial"

import utils

In [None]:
def tvec2samp(tvec, time):
    """
    Return the index of the sample in `tvec` whose time is closest to `time`.

    Ties resolve to the earliest such sample (np.argmin semantics).
    """
    distances = np.abs(tvec - time)
    return np.argmin(distances)

# function to sort ROIs based on activity in certain epoch
def sort_heatmap_peaks(data, tvec, sort_epoch_start_time, sort_epoch_end_time, sort_method = 'peak_time'):
    """
    Return an ROI ordering based on each ROI's activity inside a time epoch.

    Parameters
    ----------
    data : np.ndarray
        2-D array indexed as data[roi, sample].
    tvec : array-like
        Time (s) of every sample, i.e. of every column of `data`.
    sort_epoch_start_time, sort_epoch_end_time : float
        Epoch bounds in seconds. Each is snapped to the closest sample in `tvec`;
        the end sample itself is excluded from the epoch.
    sort_method : {'peak_time', 'max_value', 'mean_value'}
        'peak_time'  : ascending by the sample at which the epoch maximum occurs (earliest peak first)
        'max_value'  : descending by the epoch maximum
        'mean_value' : descending by the epoch mean

    NaN samples (e.g. frames blanked during stimulation) are ignored by every method.

    Returns
    -------
    np.ndarray of int
        Row indices of `data` in sorted order.

    Raises
    ------
    ValueError
        If the epoch contains no samples or `sort_method` is not recognized.
    """
    tvec = np.asarray(tvec)

    # find start/end samples for epoch (closest sample to each requested time)
    sort_epoch_start_samp = int(np.argmin(np.abs(tvec - sort_epoch_start_time)))
    sort_epoch_end_samp = int(np.argmin(np.abs(tvec - sort_epoch_end_time)))
    if sort_epoch_end_samp <= sort_epoch_start_samp:
        raise ValueError(
            f'Sort epoch [{sort_epoch_start_time}, {sort_epoch_end_time}] s contains no samples')

    epoch_data = data[:, sort_epoch_start_samp:sort_epoch_end_samp]

    if sort_method == 'peak_time':
        # treat NaNs as -inf so a blanked sample can never be picked as the peak
        # (plain np.argmax returns the index of the first NaN)
        epoch_peak_samp = np.argmax(np.where(np.isnan(epoch_data), -np.inf, epoch_data), axis=1)
        return np.argsort(epoch_peak_samp)
    if sort_method == 'max_value':
        epoch_max = np.nanmax(epoch_data, axis=1)
        return np.argsort(epoch_max)[::-1]
    if sort_method == 'mean_value':
        epoch_mean = np.nanmean(epoch_data, axis=1)
        return np.flip(np.argsort(epoch_mean))

    raise ValueError(
        f"Unknown sort_method {sort_method!r}; options are 'peak_time', 'max_value', 'mean_value'")

def is_all_nans(vector):
    """
    Return True when every element of `vector` (a numpy array or pandas Series) is NaN.

    Used to find ROIs whose activity is entirely NaN so they can be excluded.
    """
    values = vector.values if isinstance(vector, pd.Series) else vector
    return np.isnan(values).all()

# fixed font sizes shared by the figures in this notebook: axis labels, then tick labels
axis_label_size, tick_font_size = 15, 14

In [None]:
"""
USER-DEFINED VARIABLES
"""

def define_params(method = 'single'):
    """
    Build the dictionary of user-defined analysis and plotting parameters.

    Parameters
    ----------
    method : str
        'single'   : use the parameters hard-coded below (defaults point at the repo's sample data).
        'f2a'      : delegate to files_to_analyze_event.define_fparams() (project module, imported on demand).
        'root_dir' : placeholder; not implemented yet.

    Returns
    -------
    dict
        Parameter dictionary read by the rest of the notebook.

    Raises
    ------
    NotImplementedError
        If method is 'root_dir'.
    ValueError
        If method is not one of the options above.
    """
    if method == 'f2a':
        # load the predefined list of files in files_to_analyze_event; imported here because
        # only this option needs it (it was previously referenced without ever being imported)
        import files_to_analyze_event
        return files_to_analyze_event.define_fparams()

    if method == 'root_dir':
        # fail loudly instead of returning an empty dict that only breaks later with a KeyError
        raise NotImplementedError("define_params(method='root_dir') is not implemented yet")

    if method != 'single':
        raise ValueError(f"Unknown method {method!r}; options are 'single', 'f2a', 'root_dir'")

    fparams = {}

    # default sample data generously provided by Cat Zamorano
    fparams['fname_signal'] = 'heatmap_slice_s2p_neuropil_corrected_signals.npy'
    fparams['fname_events'] = 'heatmap_slice_events.csv' # can set to None if you want to plot the signals only
    # fdir is the root path of the data. Currently, the abspath phrase points to sample data from the repo.
    # To specify a path that is on your local computer, use this string format: r'your_root_path', where you should copy/paste
    # your path between the single quotes (important to keep the r to render as a complete raw string). See example below:
    # r'C:\Users\stuberadmin\Documents\GitHub\NAPE_imaging_postprocess\napeca_post\sample_data'
    fparams['fdir'] = os.path.abspath('./sample_data/heatmap_slice')
    fparams['fname'] = os.path.split(fparams['fdir'])[1]
    fparams['flag_save_figs'] = False
    fparams['flag_close_figs_after_save'] = False

    # set the sampling rate (samples per second)
    fparams['fs'] = 0.46

    # session info
    fparams['opto_blank_frame'] = False # if PMTs were blanked during stim, set stim times to nan (instead of 0)

    # analysis and plotting arguments
    fparams['num_rois'] = 'all' # set to 'all' if want to show all cells
    fparams['selected_conditions'] = None # set to None if want to include all conditions from behav data
    fparams['flag_normalization'] = 'zscore' # options: 'dff', 'zscore', 'dff_perc', None
    fparams['baseline_epoch'] = [0, 60] # time window (in seconds) used as the z-score baseline; list [start, end]
    fparams['event_sort_analysis_win'] = [200, 400] # time window (in seconds) for sorting cells; list [start, end]
    fparams['interesting_rois'] = []

    # ROI sorting; if flag_sort_rois is set to True, ROIs are sorted by activity in the fparams['event_sort_analysis_win'] window
    fparams['flag_sort_rois'] = True
    if fparams['flag_sort_rois']:
        fparams['user_sort_method'] = 'mean_value' # peak_time, mean_value or max_value
        fparams['roi_sort_cond'] = 'damgo' # for roi-resolved heatmaps, which condition to sort ROIs by

    return fparams

# parameter source; options are 'single', 'f2a', 'root_dir'
fparams = define_params(method='single')

# one plotting color per event condition, assigned in condition order
cond_colors = ['b', 'c', 'r', 'm', 'g', 'y']

# data keys, heatmap colormap and y-axis label all depend on whether the traces are z-scored
if fparams['flag_normalization'] == 'zscore':
    data_trial_resolved_key, data_trial_avg_key = 'zdata', 'ztrial_avg_data'
    cmap_, ylabel = None, 'Z-score Activity'
else:
    data_trial_resolved_key, data_trial_avg_key = 'data', 'trial_avg_data'
    cmap_, ylabel = 'inferno', 'Activity'

In [None]:
# input paths: the signal file (and its extension) plus the optional event file
signals_fpath = os.path.join(fparams['fdir'], fparams['fname_signal'])
fext = os.path.splitext(fparams['fname_signal'])[-1]

if fparams['fname_events']:
    events_file_path = os.path.join(fparams['fdir'], fparams['fname_events'])

# output directory for figures, created via utils.check_exist_dir (trailing ; suppresses cell output)
save_dir = os.path.join(fparams['fdir'], 'event_rel_analysis')
utils.check_exist_dir(save_dir);

In [None]:
# functions to normalize traces
def calc_dff_percentile(activity_vec, perc=25):
    """
    Compute dF/F of a trace, using a low percentile of the trace itself as the baseline F.

    NaN samples (e.g. frames blanked during stimulation) are ignored when computing the
    baseline, consistent with calc_zscore's nan-aware statistics; they stay NaN in the output.
    With np.percentile a single NaN would turn the whole returned trace into NaN.

    Parameters
    ----------
    activity_vec : np.ndarray
        1-D activity trace.
    perc : float
        Percentile (0-100) of the trace used as the baseline F.

    Returns
    -------
    np.ndarray
        (activity_vec - F_perc) / F_perc, same shape as activity_vec.
    """
    perc_activity = np.nanpercentile(activity_vec, perc)
    return (activity_vec - perc_activity) / perc_activity

def calc_zscore(data, baseline_samples):
    """
    Z-score `data` against the mean and standard deviation of its baseline samples.

    `baseline_samples` indexes the last axis of `data`; NaNs in the baseline are ignored.
    """
    baseline = data[..., baseline_samples]
    return (data - np.nanmean(baseline)) / np.nanstd(baseline)

In [None]:
# load time-series data; the axis=1 reductions below treat rows as ROIs and columns as samples
signals = utils.load_signals(signals_fpath)
# np.where returns a tuple; element [0] holds the indices of ROIs whose activity is entirely NaN
all_nan_rois = np.where(np.apply_along_axis(is_all_nans, 1, signals))

flag_stim = False
if fparams['opto_blank_frame']:
    glob_stim_files = glob.glob(os.path.join(fparams['fdir'], "{}*_stimmed_frames.pkl".format(fparams['fname'])))
    if not glob_stim_files:
        print('Note: No stim preprocessed meta data detected.')
    else:
        try:
            # NOTE(review): pickle.load can execute arbitrary code; only load stim files you trust
            with open(glob_stim_files[0], "rb") as stim_file:
                stim_frames = pickle.load(stim_file)
            if not np.issubdtype(signals.dtype, np.floating):
                signals = signals.astype(float) # integer arrays cannot hold NaN
            signals[:, stim_frames['samples']] = np.nan # blank out stimmed frames
            flag_stim = True
            print('Detected stim data; replaced stim samples with NaNs')
        except (OSError, KeyError, IndexError, pickle.UnpicklingError) as err:
            # stim blanking is best-effort, but report why it failed instead of claiming no data exists
            print(f'Note: could not apply stim meta data from {glob_stim_files[0]} ({err!r}); stim frames were not blanked.')

if fparams['flag_normalization'] == 'dff':
    signal_to_plot = np.apply_along_axis(utils.calc_dff, 1, signals)
elif fparams['flag_normalization'] == 'dff_perc':
    signal_to_plot = np.apply_along_axis(calc_dff_percentile, 1, signals)
elif fparams['flag_normalization'] == 'zscore':
    if fparams['baseline_epoch']:
        # convert the baseline window from seconds to sample indices
        baseline_edge_samples = np.array(fparams['baseline_epoch'])*fparams['fs']
        signal_to_plot = np.apply_along_axis(calc_zscore, 1, signals, np.arange(baseline_edge_samples[0], baseline_edge_samples[1]).astype('int'))
    else:
        # no baseline window given: z-score against the whole session
        signal_to_plot = np.apply_along_axis(calc_zscore, 1, signals, np.arange(0, signals.shape[1]))
else:
    signal_to_plot = signals

# per-ROI and global [min, max]; nan-aware so blanked frames do not turn every value into NaN
min_max = [list(min_max_tup) for min_max_tup in zip(np.nanmin(signal_to_plot, axis=1), np.nanmax(signal_to_plot, axis=1))]
min_max_all = [np.nanmin(signal_to_plot), np.nanmax(signal_to_plot)]
min_max_all = [np.min(signal_to_plot), np.max(signal_to_plot)]

In [None]:
# 'all' means show every ROI
if fparams['num_rois'] == 'all':
    fparams['num_rois'] = signals.shape[0]

total_session_time = signals.shape[1]/fparams['fs']
# time (s) of each sample: sample k is acquired at k/fs. np.linspace(0, total_session_time, n_samples)
# would space samples total_session_time/(n_samples-1) apart and drift by up to one sample period
# relative to the event times drawn on the plots.
tvec = np.round(np.arange(signals.shape[1])/fparams['fs'], 2)

In [None]:
# load behavioral data and trial info; event_times maps condition -> event times in whole seconds.
# Without an event file both containers stay empty so the plotting cells still run.
event_times = {}
conditions = []
if fparams['fname_events']:

    glob_event_files = glob.glob(events_file_path) # look for a file in specified directory
    if not glob_event_files:
        raise FileNotFoundError(f'{events_file_path} not detected. Please check if path is correct.')
    event_file = glob_event_files[0]
    event_ext = os.path.splitext(event_file)[1].lower()

    if event_ext == '.csv':
        raw_event_times = utils.df_to_dict(event_file)
    elif event_ext == '.pkl':
        # NOTE(review): pickle.load can execute arbitrary code; only load event files you trust
        with open(event_file, "rb") as event_fh:
            raw_event_times = pickle.load(event_fh, fix_imports=True, encoding='latin1') # latin1 b/c original pickle made in python 2
    else:
        raise ValueError(f'Unsupported event file type {event_ext!r} for {event_file}; expected .csv or .pkl')
    event_frames = utils.dict_time_to_samples(raw_event_times, fparams['fs'])

    # a list (not dict_keys) so later cells can index conditions[0]
    if fparams['selected_conditions']:
        conditions = list(fparams['selected_conditions'])
    else:
        conditions = list(event_frames.keys())
    for cond in conditions: # convert event samples to time in seconds
        event_times[cond] = (np.array(event_frames[cond])/fparams['fs']).astype('int')

In [None]:
# if flag is true, sort ROIs by activity in fparams['event_sort_analysis_win'] using fparams['user_sort_method']
if fparams['flag_sort_rois']:
    if not fparams['roi_sort_cond'] and len(conditions) > 0: # if no condition to sort by specified, use first condition
        fparams['roi_sort_cond'] = list(conditions)[0]
    if fparams['roi_sort_cond'] not in conditions:
        sorted_roi_order = range(fparams['num_rois'])
        print('Specified condition to sort by doesn\'t exist! ROIs are in default sorting.')
    else:
        # returns new order of rois sorted using the data and method supplied in the specified window
        # NOTE(review): the sort uses the whole signal in the window; roi_sort_cond only gates whether sorting happens
        sorted_roi_order = sort_heatmap_peaks(signal_to_plot, tvec,
                           sort_epoch_start_time=fparams['event_sort_analysis_win'][0],
                           sort_epoch_end_time=fparams['event_sort_analysis_win'][1],
                           sort_method=fparams['user_sort_method'])
else:
    sorted_roi_order = range(fparams['num_rois'])

# drop all-NaN ROIs while keeping order. all_nan_rois is the tuple returned by np.where, so compare
# against its index array: membership in the tuple itself compares each ROI with the whole array
# and raises "truth value is ambiguous" once more than one ROI is all-NaN.
nan_roi_set = set(all_nan_rois[0].tolist())
if nan_roi_set:
    sorted_roi_order = [roi for roi in sorted_roi_order if roi not in nan_roi_set]

# final display order, limited to num_rois entries (the unsorted range already is)
roi_order = list(sorted_roi_order)[:fparams['num_rois']]

# rows (positions in the display order) of the user-flagged ROIs, i.e. the rois to mark with an arrow;
# computed the same way for sorted, unsorted and NaN-filtered orders
interesting_roi_set = set(fparams['interesting_rois'])
interesting_rois = [pos for pos, roi in enumerate(roi_order) if roi in interesting_roi_set]

roi_order_path = os.path.join(fparams['fdir'], fparams['fname'] + '_roi_order.pkl')
with open(roi_order_path, 'wb') as handle:
    pickle.dump(sorted_roi_order, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# calculate all the color limits for heatmaps; useful for locking color limits across different heatmap subplots
def generate_clims(data_in, norm_type, scaling=1):
    """
    Return [low, high] display limits for `data_in`, ignoring NaNs.

    For z-scored data ('zscore') the limits are symmetric about 0: +/- the larger of |min| and |max|,
    multiplied by `scaling`. For any other norm_type they are the raw [min, max] and `scaling` is unused.
    """
    lowest, highest = np.nanmin(data_in), np.nanmax(data_in)
    if norm_type != 'zscore':
        return [lowest, highest]
    bound = np.max(np.abs([lowest, highest]))
    return [-bound*scaling, bound*scaling]

In [None]:
# whole-session heatmap of the displayed ROIs (rows in roi_order, at most num_rois) with event times overlaid
rois_to_plot = np.asarray(roi_order)[:fparams['num_rois']]
to_plot = signal_to_plot[rois_to_plot, :]

# set imshow extent to replace x and y axis ticks/labels (samples -> time in s, rows -> ROI rank);
# use the number of plotted rows so rows are not stretched when all-NaN ROIs were dropped
plot_extent = [0, total_session_time, to_plot.shape[0], 0]

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))

ax.set_ylabel('ROI #', fontsize=axis_label_size)
ax.set_xlabel('Time (s)', fontsize=axis_label_size)
ax.tick_params(axis='both', which='major', labelsize=tick_font_size)

im = utils.subplot_heatmap(ax, 'Whole-session Heatmap', to_plot, cmap=cmap_, clims=generate_clims(signal_to_plot, fparams['flag_normalization'], 0.5), extent_=plot_extent)

# dashed vertical line at every event; only the first line of each condition carries a label so the
# legend gets exactly one correctly colored entry per condition (plt.legend(conditions) would label
# the first N lines, which can all belong to the same condition)
for cond_idx, cond in enumerate(conditions):
    cond_color = cond_colors[cond_idx % len(cond_colors)]
    for event_idx, event_time in enumerate(event_times[cond]):
        ax.axvline(event_time, color=cond_color, alpha=1, linewidth=2, linestyle=(0, (5, 5)),
                   label=cond if event_idx == 0 else '_nolegend_')

if len(conditions) > 0:
    ax.legend()

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
cbar = fig.colorbar(im, ax=ax, shrink=0.7)
cbar.ax.set_ylabel(ylabel, fontsize=13)

if fparams['flag_save_figs']:
    fig.savefig(os.path.join(save_dir, 'trial_avg_heatmap.png'))
    fig.savefig(os.path.join(save_dir, 'trial_avg_heatmap.pdf'))

In [None]:
# one trace figure per displayed ROI. iROI is the ROI's rank in the display order (matching the heatmap
# rows) and roi_idx its row in signal_to_plot; iterating roi_order directly avoids indexing past its end
# once all-NaN ROIs have been removed.
roi_trace_ylims = generate_clims(signal_to_plot, fparams['flag_normalization']) # same limits for every ROI

for iROI, roi_idx in enumerate(list(roi_order)[:fparams['num_rois']]):

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))

    ax.plot(tvec, signal_to_plot[roi_idx, :], label='_no_legend')

    ax.set_title('ROI # {}'.format(str(iROI)), fontsize=axis_label_size)
    ax.set_ylabel(ylabel, fontsize=axis_label_size)
    ax.set_xlabel('Time (s)', fontsize=axis_label_size)
    ax.tick_params(axis='both', which='major', labelsize=tick_font_size)

    # dashed vertical line at every event; label only the first line per condition to avoid
    # one duplicate legend entry per event
    for cond_idx, cond in enumerate(conditions):
        cond_color = cond_colors[cond_idx % len(cond_colors)]
        for event_idx, event_time in enumerate(event_times[cond]):
            ax.axvline(event_time, color=cond_color, alpha=1, linewidth=2, linestyle=(0, (5, 5)),
                       label=cond if event_idx == 0 else '_nolegend_')

    if len(conditions) > 0:
        ax.legend()
    ax.autoscale(enable=True, axis='both', tight=True)
    ax.set_ylim(roi_trace_ylims)

    if fparams['flag_save_figs']:
        fig.savefig(os.path.join(save_dir, 'roi_{}_activity.png'.format(str(iROI))))
        fig.savefig(os.path.join(save_dir, 'roi_{}_activity.pdf'.format(str(iROI))))

    if fparams['flag_close_figs_after_save']:
        plt.close(fig)