In [None]:
import os
import numpy as np
import h5py
import pandas as pd
import pickle
import numpy as np
from PIL import Image
import glob

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
plt.rcParams['text.usetex'] = False
# 'text.latex.unicode' was deprecated and later removed from matplotlib; assigning an
# unknown rcParam raises KeyError, so only set it on versions that still define it
if 'text.latex.unicode' in plt.rcParams:
    plt.rcParams['text.latex.unicode'] = False
import matplotlib
# important for text to be detectable (editable) when importing saved figures into illustrator:
# fonttype 42 embeds TrueType fonts instead of Type 3 glyph outlines
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
plt.rcParams["font.family"] = "Arial"

import s2p_plot_rois_activity_funcs
import utils

In [None]:
# USER DEFINED VARIABLES
fdir = r'D:\bruker_data\vj_ofc_imageactivate_02_200\vj_ofc_imageactivate_02_200_006' # NOTE: the root folder name must match the basename of the _sima_masks.npy file
fname_events = 'framenumberforevents_vj_ofc_imageactivate_02_200_006.pickle' # .csv, .pkl or .pickle file located inside fdir
preprocess_mode = 'sima' # one of: 's2p', 'sima', 'napeca'
fs = 30 # sampling rate in frames per second; converts the second-based windows below into frame samples

trial_start_end = [-2, 8] # [start, end] times (in seconds, relative to event onset) included in the visualization
flag_normalization = 'zscore' # options: 'zscore', None. NOTE(review): not read by the cells visible here; z-scoring runs whenever a baseline window is set
baseline_window = [-2, 0] # [start, end] seconds relative to event onset used for z-scoring; must lie inside trial_start_end
viz_window = [0, 5] # [start, end] seconds relative to event onset averaged into the activity map; must lie inside trial_start_end

# define number of ROIs to visualize; can be:
#   1) a list of select rois,
#   2) an integer (n) indicating n first rois to plot, or
#   3) None or 'all' which plots all valid ROIs
rois_to_plot = None # np.arange(5) #[0,2,3,6]

# fail fast on a mistyped option instead of hitting a NameError several cells later;
# substring matching mirrors how the later cells dispatch on preprocess_mode
if not any(mode in preprocess_mode for mode in ('s2p', 'sima', 'napeca')):
    raise ValueError(f"preprocess_mode must contain 's2p', 'sima' or 'napeca'; got {preprocess_mode!r}")
if flag_normalization not in ('zscore', None):
    raise ValueError(f"flag_normalization must be 'zscore' or None; got {flag_normalization!r}")

In [None]:
# set paths
# normpath strips a trailing separator so the basename is the session folder name, not ''
fname = os.path.basename(os.path.normpath(fdir))
if 's2p' in preprocess_mode:
    path_dict = s2p_plot_rois_activity_funcs.s2p_dir(fdir)
    path_dict = s2p_plot_rois_activity_funcs.define_paths_roi_plots(path_dict, None, None, None)
elif any(x in preprocess_mode for x in ['sima', 'napeca']):
    roi_mask_path = os.path.join(fdir, f'{fname}_sima_masks.npy')
    sima_h5_path = os.path.join(fdir, f'{fname}_sima_mc.h5')
else:
    # without this, the data paths are silently left undefined and later cells fail with NameError
    raise ValueError(f"Unrecognized preprocess_mode {preprocess_mode!r}; expected 's2p', 'sima' or 'napeca'")

events_file_path = os.path.join(fdir, fname_events)

# output folder for saved figures (created inside the session folder if missing)
fig_save_dir = os.path.join(fdir, 'figs')
if not os.path.isdir(fig_save_dir):
    os.mkdir(fig_save_dir)

In [None]:
# load projection image(s) used as backgrounds / references for the activity maps
proj_manual = {}

if 's2p' in preprocess_mode:
    s2p_data_dict = s2p_plot_rois_activity_funcs.load_s2p_data_roi_plots(path_dict)
    plot_vars = s2p_plot_rois_activity_funcs.plotting_rois(s2p_data_dict, path_dict)
    # the original key f'{proj_type}_img' referenced an undefined `proj_type` (NameError);
    # ops['meanImg'] is a mean projection, so use the same key as the sima branch
    proj_manual['mean_img'] = s2p_data_dict['ops']['meanImg']
    # NOTE(review): this branch does not define `data`, which the trial-extraction cell
    # requires — TODO load the s2p registered movie here

elif any(x in preprocess_mode for x in ['sima', 'napeca']):
    # load video data: open h5, take the first dataset key, read it fully into memory;
    # the context manager closes the file even if reading raises
    with h5py.File(sima_h5_path, 'r') as h5:
        first_key = list(h5.keys())[0]
        # NOTE(review): astype('int16') wraps values > 32767 if the stored data is uint16 — confirm source dtype
        data = np.squeeze(np.array(h5[first_key])).astype('int16')

    # projections across axis 0 (the frame axis)
    proj_manual = {'mean_img': np.mean(data, axis = 0),
                   'max_img': np.max(data, axis = 0),
                   'std_img': np.std(data, axis = 0) }

In [None]:
# locate and load the event-frame file; conditions map to lists of event frame numbers
glob_event_files = glob.glob(events_file_path) # look for a file in specified directory
if not glob_event_files:
    # previously this only printed, then crashed with IndexError on glob_event_files[0]
    raise FileNotFoundError(f'{events_file_path} not detected. Please check if path is correct.')
event_file = glob_event_files[0]

# dispatch on the file extension rather than a substring anywhere in the path
event_ext = os.path.splitext(event_file)[1].lower()
if event_ext == '.csv':
    event_times = utils.df_to_dict(event_file)
elif event_ext in ('.pkl', '.pickle'):
    # SECURITY: pickle.load can execute arbitrary code — only load trusted files
    with open(event_file, "rb") as f:
        event_times = pickle.load(f, fix_imports=True, encoding='latin1') # latin1 b/c original pickle made in python 2
else:
    raise ValueError(f'Unsupported event file type {event_ext!r} for {event_file}; expected .csv, .pkl or .pickle')
event_frames = event_times # utils.dict_time_to_samples(event_times, fs)

# identify conditions to analyze
all_conditions = event_frames.keys()
conditions = [ condition for condition in all_conditions if len(event_frames[condition]) > 0 ] # keep conditions that have events

In [None]:
### create variables that reference samples and times for slicing and plotting the data

trial_start_end_sec = np.array(trial_start_end) # trial windowing in seconds relative to ttl-onset/trial-onset
baseline_start_end_sec = np.array(baseline_window)

# baseline and viz windows index into trial-aligned data, so they must lie inside the trial window
assert trial_start_end_sec[0] <= baseline_start_end_sec[0] <= baseline_start_end_sec[1] <= trial_start_end_sec[1], \
    'baseline_window must lie inside trial_start_end'
assert trial_start_end_sec[0] <= viz_window[0] <= viz_window[1] <= trial_start_end_sec[1], \
    'viz_window must lie inside trial_start_end'

# convert times to samples and get sample vector for the trial
trial_begEnd_samp = trial_start_end_sec*fs # turn trial start/end times to samples
trial_svec = np.arange(trial_begEnd_samp[0], trial_begEnd_samp[1])
# and for baseline period: indices are offset from the TRIAL start because the extracted
# trial data begins at trial_begEnd_samp[0] (previously offset from the baseline start,
# which is only correct when both windows start at the same time)
baseline_begEnd_samp = baseline_start_end_sec*fs
baseline_svec = (np.arange(baseline_begEnd_samp[0], baseline_begEnd_samp[1]+1, 1) - trial_begEnd_samp[0]).astype('int')
# viz window, also as indices relative to the trial start
viz_window_samp = (np.array(viz_window) - trial_start_end_sec[0])*fs
viz_window_svec = np.arange(viz_window_samp[0], viz_window_samp[1]).astype('int')

# calculate time vector for plot x axes (num_samples_trial+1 points spanning the trial window)
num_samples_trial = len( trial_svec )
tvec = np.round(np.linspace(trial_start_end_sec[0], trial_start_end_sec[1], num_samples_trial+1), 2)


In [None]:
start_end_samp = trial_begEnd_samp
baseline_start_end_samp = baseline_begEnd_samp

# create sample vector for baseline epoch if argument exists (for zscoring);
# indices are relative to the first sample of each extracted trial, i.e. start_end_samp[0]
if baseline_start_end_samp is not None:
    baseline_svec = (np.arange(baseline_start_end_samp[0], baseline_start_end_samp[1] + 1, 1) -
                     start_end_samp[0]).astype('int')

data_dict = {}
data_end_sample = data.shape[0] # number of frames in the session (loop-invariant)

for condition in conditions:

    data_dict[condition] = {}

    # get rid of trials that are outside of the session bounds with respect to time
    cond_frame_events = np.asarray(utils.remove_trials_out_of_bounds(data_end_sample, event_frames[condition],
                                                                     start_end_samp[0], start_end_samp[1]))

    num_trials_cond = len(cond_frame_events)
    if num_trials_cond == 0:
        # without this guard the previous condition's extracted_trial_dat would be z-scored again
        print(f'{condition}: no trials within session bounds; skipping')
        continue

    # make an array where the trial-window sample offsets are repeated in the y axis for n trials
    if num_trials_cond == 1:
        # keep a leading trial axis so the layout matches the multi-trial case: (trials, samples)
        svec_tile = np.arange(start_end_samp[0], start_end_samp[1] + 1)[np.newaxis, :]
    else:
        svec_tile = utils.make_tile(start_end_samp[0], start_end_samp[1], num_trials_cond)
    num_trial_samps = svec_tile.shape[1]

    # now make a repeated matrix of each trial's ttl on sample in the x dimension
    ttl_repmat = np.repeat(cond_frame_events[:, np.newaxis], num_trial_samps, axis=1).astype('int')
    # calculate actual trial sample indices by adding the TTL onset repeated matrix and the trial window template
    trial_sample_mat = np.round(ttl_repmat + svec_tile).astype('int')

    # extract frames in trials and reshape to (trials, samples, <last two dims of data>)
    reshape_dim = svec_tile.shape + data.shape[-2:]
    extracted_trial_dat = data[trial_sample_mat.ravel(), ...].reshape(reshape_dim)

    # save normalized data
    if baseline_start_end_samp is not None:
        # z-score along the samples axis (axis 1) using the baseline indices
        # NOTE(review): squeeze(axis=2) assumes utils.zscore_ returns an array with an extra
        # singleton dimension — confirm against utils.zscore_
        data_dict[condition]['zdata'] = np.squeeze(np.apply_along_axis(utils.zscore_, 1,
                                                                       extracted_trial_dat,
                                                                       baseline_svec), axis=2)

In [None]:
import seaborn as sns
from matplotlib.colors import ListedColormap

# discrete colormap built from seaborn's reversed RdBu diverging palette, sampled at 100 colors
cmap = ListedColormap(sns.color_palette(palette="RdBu_r", n_colors=100))

In [None]:
# one activity map per condition: z-scored data averaged over trials (axis 0) and over the
# viz-window samples (axis 1). The original cell used the loop variable `condition` leaked
# from the extraction loop, i.e. it silently plotted whichever condition ran last.
for plot_condition in conditions:
    cond_data = data_dict.get(plot_condition, {})
    if 'zdata' not in cond_data:
        continue # no valid trials or no baseline normalization for this condition

    zscore_map = np.mean(cond_data['zdata'][:, viz_window_svec, ...], axis=(0, 1))

    fig, ax = plt.subplots(1, 1, figsize = (10,10))
    # symmetric color limits of +/-2 (z-score units) centered on 0 for the diverging colormap
    im = ax.imshow(zscore_map, cmap=cmap, vmin = -2, vmax = 2)
    ax.set_title(f'{plot_condition}: mean z-score, {viz_window[0]} to {viz_window[1]} s')

    cbar = fig.colorbar(im, ax = ax, shrink = 0.5)
    cbar.ax.set_ylabel('Norm Fluorescence', fontsize = 13)

In [None]:
# import multiprocessing

# import numpy as np

# def parallel_apply_along_axis(func1d, axis, arr, *args, **kwargs):
#     """
#     Like numpy.apply_along_axis(), but takes advantage of multiple
#     cores.
#     """        
#     # Effective axis where apply_along_axis() will be applied by each
#     # worker (any non-zero axis number would work, so as to allow the use
#     # of `np.array_split()`, which is only done on axis 0):
#     effective_axis = 1 if axis == 0 else axis
#     if effective_axis != axis:
#         arr = arr.swapaxes(axis, effective_axis)

#     # Chunks for the mapping (only a few chunks):
#     chunks = [(func1d, effective_axis, sub_arr, args, kwargs)
#               for sub_arr in np.array_split(arr, multiprocessing.cpu_count())]

#     pool = multiprocessing.Pool()
#     individual_results = pool.map(unpacking_apply_along_axis, chunks)
#     # Freeing the workers:
#     pool.close()
#     pool.join()

#     return np.concatenate(individual_results)

# def unpacking_apply_along_axis(all_args):
#     (func1d, axis, arr, args, kwargs) = all_args
#     return np.apply_along_axis(func1d, axis, arr, *args, **kwargs)

# tmp = parallel_apply_along_axis(utils.zscore_, 1, extracted_trial_dat, baseline_svec)