In [None]:
import os
import numpy as np
import glob
import pickle
import pandas as pd
from collections import defaultdict
import utils
import matplotlib.pyplot as plt
import warnings
import json

In [None]:
"""
USER-DEFINED VARIABLES
"""

def define_params(method = 'single'):
    """
    Build the dictionary of user-defined analysis parameters.

    Parameters
    ----------
    method : str
        Parameter preset to build. Only 'single' is implemented here; any other value
        returns an empty dict.

    Returns
    -------
    dict
        The user-editable settings below plus the derived 'norm_mode' entry
        ('zdata', 'dff_data' or 'data'), which load_and_average_data uses as the key
        to read from each session pickle.
    """

    fparams = {}

    if method == 'single':

        # fdir_parent is the root path of the data.
        # To specify a path on your local computer, use this string format: r'your_root_path', where you copy/paste
        # your path between the single quotes (keep the r so the path is treated as a raw string). See example below:
        # r'C:\Users\stuberadmin\Documents\GitHub\NAPE_imaging_postprocess\napeca_post\sample_data'
        fparams['fdir_parent'] = r'D:\olympus_data\NAPECA_TEST_DATA'
        fparams['fname'] = os.path.split(fparams['fdir_parent'])[1]

        fparams['flag_close_figs_after_save'] = True
        fparams['figs_save_dir'] = r'D:\olympus_data\NAPECA_TEST_DATA'
        fparams['flag_save_figs'] = True

        # set the sampling rate
        fparams['fs'] = 5 # intended to be overwritten by the json fs variable (if it exists) that is saved in preprocessing; see update_fs

        # set to None if want to include all conditions from behav data;
        # otherwise, set to list of conditions, eg. ['plus', 'minus']
        fparams['selected_conditions'] = None

        # trial windowing and normalization
        fparams['trial_start_end'] = [-2, 8] # primary visualization window relative to event onset; [start, end] times (in seconds)
        fparams['flag_normalization'] = 'zscore' # options: 'zscore', 'dff', None
        fparams['specific_baseline'] = False
        fparams['baseline_start_end'] = [-2, -0.2] # baseline window (in seconds) for baseline normalization; either [start, end] or a single number (end time; start defaults to the trial start). -0.2 avoids grabbing a sample that includes the event
        fparams['event_dur'] = 2 # duration of stim/event in seconds; displays a line below main plot indicating event duration
        fparams['event_sort_analysis_win'] = [0, 5] # time window (in seconds) for sorting cells; list [start, end]

        # ROI sorting; if flag_sort_rois is set to True, ROIs are sorted by activity in the fparams['event_sort_analysis_win'] window
        fparams['flag_sort_rois'] = False
        if fparams['flag_sort_rois']:

            fparams['user_sort_method'] = 'max_value' # peak_time or max_value
            fparams['roi_sort_cond'] = 'plus' # for roi-resolved heatmaps, which condition to sort ROIs by

        fparams['interesting_rois'] = []

        # errorbar and saving figures
        fparams['flag_roi_trial_avg_errbar'] = True # toggle to show error bar on roi- and trial-averaged traces
        fparams['flag_trial_avg_errbar'] = True # toggle to show error bars on the trial-avg traces
        ########## DONE WITH USER-DEFINED VARIABLES; DON'T MODIFY CODE BELOW

        # flag_normalization may be None (a documented option); calling .lower() on None
        # would raise, so fall back to an empty string, which selects the raw 'data' key
        norm_flag = (fparams['flag_normalization'] or '').lower()
        if norm_flag == 'zscore':
            fparams['norm_mode'] = 'zdata'
        elif norm_flag == 'dff':
            fparams['norm_mode'] = 'dff_data'
        else:
            fparams['norm_mode'] = 'data'

    return fparams

# Build the parameter dictionary used by the cells below. Only the 'single' branch is
# implemented in define_params; any other method value currently yields an empty dict.
fparams = define_params(method = 'single') # options are 'single', 'f2a', 'root_dir'
# filled in by run_single_session with per-session info ('num_rois', 'all_nan_rois')
runtime_params = {}

In [None]:
class update_lims:
    """
    Track running [min, max] limits while iterating over data.

    Usage:
        tracker = update_lims(first_lims)   # first_lims is a [low, high] pair
        tracker.update(new_lims)            # widen the limits if new_lims exceeds them
        tracker.output()                    # -> current [low, high] list
    """

    def __init__(self, lims):
        # copy into a fresh list: the caller's sequence is never mutated by update(),
        # and immutable inputs such as tuples are accepted
        self.lims = list(lims)

    def update(self, new_lims):
        """Expand the stored limits so that they also cover new_lims ([low, high])."""
        if self.lims[0] > new_lims[0]:
            self.lims[0] = new_lims[0]

        if self.lims[1] < new_lims[1]:
            self.lims[1] = new_lims[1]

    def output(self):
        """Return the current limits as a list."""
        return self.lims
    
    
def subplot_loc(idx, num_rows, num_col):
    """
    Convert a flat, incrementing subplot index into an index usable on a subplot axes array.

    Parameters
    ----------
    idx : int
        Zero-based flat subplot index.
    num_rows : int
        Number of subplot rows.
    num_col : int or float
        Number of subplot columns (cast to int).

    Returns
    -------
    int or tuple
        `idx` itself when there is a single row; otherwise the zero-based (row, column)
        coordinates, e.g. idx=3 in a 2x2 grid gives (1, 1).
    """
    # the original body referenced n_rows / n_columns, which are not parameters of this function
    if num_rows == 1:
        subplot_index = idx
    else:
        subplot_index = np.unravel_index(idx, (num_rows, int(num_col))) # turn int index to a tuple of array coordinates
    return subplot_index


def get_cmap(n, name='plasma'):
    '''Returns a function that maps each index in 0, 1, ..., n-1 to a distinct 
    RGB color; the keyword argument name must be a standard mpl colormap name.'''
    # plt.cm.get_cmap was deprecated and later removed from matplotlib; pyplot's own
    # get_cmap(name, lut) accepts the same arguments and is still available
    return plt.get_cmap(name, n)


def is_all_nans(vector):
    """
    Report whether every entry of a vector is NaN.

    Accepts a pandas Series (its underlying values are inspected) or anything
    np.isnan can handle. Used to identify and exclude all-nan rois.
    """
    values = vector.values if isinstance(vector, pd.Series) else vector
    nan_mask = np.isnan(values)
    return nan_mask.all()

# declare some fixed constant variables
axis_label_size = 15
tick_font_size = 14

# Pick data keys, colormap and y-axis label from the normalization setting.
# flag_normalization may be None (a documented option), so guard the substring test;
# lower-casing keeps this consistent with how define_params derives 'norm_mode'.
if 'zscore' in (fparams['flag_normalization'] or '').lower():
    data_trial_resolved_key = 'zdata'
    data_trial_avg_key = 'ztrial_avg_data'
    cmap_ = None
    ylabel = 'Z-score Activity'
else:
    data_trial_resolved_key = 'data'
    data_trial_avg_key = 'trial_avg_data'
    cmap_ = 'inferno'
    ylabel = 'Activity'

In [None]:
def load_events(fparams_in, fpath):
    """
    Load event times from a csv or pickle file and convert them to sample indices.

    Parameters
    ----------
    fparams_in : dict
        Must contain 'fs', which is passed to utils.dict_time_to_samples.
    fpath : str
        Path or glob pattern of the events file; the first match is used.

    Returns
    -------
    tuple
        (event_times, event_frames): the loaded event structure and the result of
        utils.dict_time_to_samples(event_times, fparams_in['fs']).

    Raises
    ------
    FileNotFoundError
        If nothing matches `fpath`.
    ValueError
        If the matched file is neither .csv nor .pkl/.pickle.
    """

    glob_event_files = glob.glob(fpath) # look for a file in specified directory
    if len(glob_event_files) == 0:
        # previously this only printed (using an undefined name) and then crashed on the [0] index below
        raise FileNotFoundError(f'{fpath} not detected. Please check if path is correct.')

    events_fpath = glob_event_files[0]
    # decide by file extension rather than by substring anywhere in the path,
    # so a folder name containing 'csv' or 'pkl' cannot misroute the loader
    extension = os.path.splitext(events_fpath)[1].lower()

    if extension == '.csv':
        event_times = utils.df_to_dict(events_fpath)
    elif extension in ('.pkl', '.pickle'):
        # NOTE: pickle can execute arbitrary code on load; only open event files you trust
        with open(events_fpath, "rb") as handle:
            event_times = pickle.load(handle, fix_imports=True, encoding='latin1') # latin1 b/c original pickle made in python 2
    else:
        raise ValueError(f'Unsupported events file type: {events_fpath} (expected .csv, .pkl or .pickle)')

    event_frames = utils.dict_time_to_samples(event_times, fparams_in['fs'])

    return event_times, event_frames

def get_conditions(fparams_in, event_dict):
    """
    Choose which conditions to analyze.

    Conditions with at least one event are collected from `event_dict` in sorted order.
    If fparams_in['selected_conditions'] is truthy, that list is returned instead.
    """
    nonempty_sorted = sorted(
        name for name, events in event_dict.items() if len(events) > 0  # keep conditions that have events
    )

    user_selection = fparams_in['selected_conditions']
    return user_selection if user_selection else nonempty_sorted

def calc_event_time_vars(fparams_in):
    """
    Derive sample indices and the time vector used for trial windowing and plotting.

    Parameters
    ----------
    fparams_in : dict
        Reads 'trial_start_end', 'baseline_start_end', 'fs' and 'event_dur'.
        'baseline_start_end' may be a [start, end] sequence (list, tuple or array) or a
        single number, in which case the baseline runs from the trial start to that time.

    Returns
    -------
    dict
        Keys: 'trial_start_end_sec', 'trial_begEnd_samp', 'baseline_begEnd_samp', 'tvec',
        't0_sample', 'event_end_sample', 'event_bound_ratio'.

    Raises
    ------
    TypeError
        If 'baseline_start_end' is neither a sequence nor a number.
    """

    # named `timing` rather than `dict` so the builtin is not shadowed inside the function
    timing = {}

    timing['trial_start_end_sec'] = np.array(fparams_in['trial_start_end']) # trial windowing relative to ttl-onset/trial-onset, in seconds

    baseline = fparams_in['baseline_start_end']
    if isinstance(baseline, (list, tuple, np.ndarray)):
        baseline_start_end_sec = np.array(baseline)
    elif isinstance(baseline, (int, float, np.integer, np.floating)):
        baseline_start_end_sec = np.array([timing['trial_start_end_sec'][0], baseline])
    else:
        # previously any other type fell through and surfaced later as an UnboundLocalError
        raise TypeError(
            "fparams['baseline_start_end'] must be a [start, end] sequence or a number, "
            f"got {type(baseline).__name__}"
        )

    # convert times to samples and get sample vector for the trial
    timing['trial_begEnd_samp'] = np.round(timing['trial_start_end_sec']*fparams_in['fs']).astype('int') # turn trial start/end times to samples
    # and for baseline period
    timing['baseline_begEnd_samp'] = np.round(baseline_start_end_sec*fparams_in['fs']).astype('int')

    # calculate time vector for plot x axes
    trial_svec = np.arange(timing['trial_begEnd_samp'][0], timing['trial_begEnd_samp'][1])
    num_samples_trial = len( trial_svec )
    # NOTE(review): tvec has num_samples_trial + 1 entries; presumably matches the sample count
    # produced by utils.extract_trial_data — confirm
    timing['tvec'] = np.round(np.linspace(timing['trial_start_end_sec'][0], timing['trial_start_end_sec'][1], num_samples_trial+1), 2)

    # find samples and calculations for time 0 for plotting
    timing['t0_sample'] = utils.get_tvec_sample(timing['tvec'], 0) # grabs the sample index of a given time from a vector of times
    timing['event_end_sample'] = int(np.round(timing['t0_sample']+fparams_in['event_dur']*fparams_in['fs']))
    timing['event_bound_ratio'] = [(timing['t0_sample'])/num_samples_trial , timing['event_end_sample']/num_samples_trial] # fraction of total samples for event start and end; only used for plotting line indicating event duration

    return timing

def run_single_session(fparams_in, runtime_params):
    """
    Run the event-related extraction for one session.

    Parameters
    ----------
    fparams_in : dict
        Session parameters; reads 'fdir', 'fname_signal', 'fname_events' and
        'specific_baseline', plus whatever calc_event_time_vars, load_events and
        get_conditions read.
    runtime_params : dict
        Updated in place with 'num_rois' (first dimension of the signals) and
        'all_nan_rois' (np.where result for rows whose activity is entirely NaN).

    Returns
    -------
    tuple
        (runtime_params, conditions)
    """

    ### declare paths
    signals_fpath = os.path.join(fparams_in['fdir'], fparams_in['fname_signal'])
    events_file_path = os.path.join(fparams_in['fdir'], fparams_in['fname_events'])

    save_dir = os.path.join(fparams_in['fdir'], 'event_rel_analysis')

    utils.check_exist_dir(save_dir) # make the save directory

    ### create variables that reference samples and times for slicing and plotting the data
    timing_dict = calc_event_time_vars(fparams_in)

    ### load data
    signals = utils.load_signals(signals_fpath)

    runtime_params['num_rois'] = signals.shape[0]
    runtime_params['all_nan_rois'] = np.where(np.apply_along_axis(is_all_nans, 1, signals)) # find rois with activity as all nans

    ### load behavioral data and trial info
    # only the sample-converted events are needed below
    _, event_frames = load_events(fparams_in, events_file_path)
    conditions = get_conditions(fparams_in, event_frames)

    # MAIN data processing step: extract event-centered data.
    # Per the original notes, the extracted trial data are saved as a pickle file in the
    # event_rel_analysis subfolder (save_dir). The return value is not used here; the call
    # is made for that saved output. The previously computed but unused line-colormap
    # (get_cmap) was dropped from this function.
    utils.extract_trial_data(signals, timing_dict['tvec'], timing_dict['trial_begEnd_samp'], event_frames,
                             conditions, fparams_in['specific_baseline'],
                             baseline_start_end_samp = timing_dict['baseline_begEnd_samp'], save_dir=save_dir)

    return runtime_params, conditions

## Run single session event-related analysis on all subfolders/sessions in parent directory (fdir_parent)

In [None]:
def _pick_session_file(subdir_path, filenames, extension, keyword):
    """Pick the file ending in `extension` from `filenames`, preferring names that contain `keyword`."""
    candidates = [fname for fname in filenames if fname.endswith(extension)]
    if not candidates:
        return None

    preferred = [fname for fname in candidates if keyword in fname]
    if preferred:
        if len(preferred) > 1:
            warnings.warn(f"Multiple '{keyword}' {extension.lstrip('.')} files found in {subdir_path}; using {preferred[0]}")
        chosen = preferred[0]
    else:
        # keep the legacy fallback (use a file with the right extension anyway) but say so
        warnings.warn(f"No '{keyword}' {extension.lstrip('.')} file found in {subdir_path}")
        chosen = candidates[-1]

    return os.path.join(subdir_path, chosen)


def collect_files(parent_directory):
    """
    Collect npy and csv files from subdirectories one level down in the given parent directory.

    For each subdirectory, a file whose name contains 'signals' (npy) / 'events' (csv) is
    preferred. If no such file exists but another npy / csv does, that file is used and a
    warning is issued. Directory listings are sorted so the choice does not depend on
    filesystem ordering.

    Parameters:
        parent_directory (str): Path to the parent directory.

    Returns:
        dict: A dictionary where keys are subdirectory names and values are
              dictionaries with file paths to 'signals' npy and 'events' csv files
              (None when no file with that extension exists).
    """
    result = {}

    # List all immediate subdirectories of the parent directory
    for subdir_name in sorted(os.listdir(parent_directory)):
        subdir_path = os.path.join(parent_directory, subdir_name)

        # Skip if it's not a directory
        if not os.path.isdir(subdir_path):
            continue

        filenames = sorted(os.listdir(subdir_path))

        # Store the results (including partial findings)
        result[subdir_name] = {
            'signals': _pick_session_file(subdir_path, filenames, '.npy', 'signals'),
            'events': _pick_session_file(subdir_path, filenames, '.csv', 'events')
        }

    return result

def update_fs(fparams):
    """
    Overwrite fparams['fs'] in place with the 'fs' entry of <fdir_parent>/<fname>.json.

    Nothing happens when that json file does not exist or has no 'fs' key.
    Returns None.
    """
    sidecar_path = os.path.join(fparams['fdir_parent'], fparams['fname'] + '.json')
    if not os.path.exists(sidecar_path):
        return

    sidecar = utils.open_json(sidecar_path)
    if 'fs' in sidecar:
        fparams['fs'] = sidecar['fs']
            
# Write the parameter dictionary as JSON into the parent data directory, so the settings
# used for this group analysis are recorded next to the data.
# NOTE(review): update_fs() is defined above but is not called in any visible cell, so the
# 'fs' written here is the value set in define_params — confirm whether update_fs(fparams)
# was meant to run before this dump.
with open(os.path.join(fparams['fdir_parent'], 'group_event_rel_analysis.json'), 'w') as fp:
    json.dump(fparams, fp)


In [None]:
file_mapping = collect_files(fparams['fdir_parent'])

# run automatic single-session analysis
for subdir, files in file_mapping.items():
    print(f"{subdir}: {files}")

    # collect_files stores None when a subfolder has no npy/csv file; skip those
    # sessions instead of crashing on os.path.dirname(None)
    if files['signals'] is None or files['events'] is None:
        warnings.warn(f"Skipping {subdir}: signals and/or events file not found.")
        continue

    fparam = fparams.copy() # copy is required, otherwise the original fparams would be modified by the assignments below
    fparam['fdir'] = os.path.dirname(files['signals'])
    fparam['fname_signal'] = os.path.basename(files['signals'])
    fparam['fname_events'] = os.path.basename(files['events'])

    fparam['fname'] = os.path.split(fparams['fdir_parent'])[1]

    # NOTE: runtime_params and conditions are overwritten on every iteration, so after the
    # loop they describe only the last session that was processed
    runtime_params, conditions = run_single_session(fparam, runtime_params)

## Perform group level analysis after single session analysis has completed

In [None]:
def find_pkl_files(root_dir):
    """
    Recursively list .pkl files whose name contains 'event_data_dict'.

    Walks `root_dir` and all of its subdirectories and returns the matching file
    paths in the order os.walk encounters them.
    """
    return [
        os.path.join(folder, fname)
        for folder, _subdirs, fnames in os.walk(root_dir)
        for fname in fnames
        if fname.endswith('.pkl') and 'event_data_dict' in fname
    ]

def load_and_average_data(fparams_in, file_path):
    """
    Load a session .pkl file and return trial-averaged data per condition.

    Parameters
    ----------
    fparams_in : dict
        Reads 'norm_mode', the key of the array to average within each condition entry.
    file_path : str
        Path to the pickled session dictionary ({condition: {key: array, ...}}).

    Returns
    -------
    dict
        {condition: mean over axis 0 of content[norm_mode]} for every condition entry that
        contains both 'data' and the requested norm_mode key. Axis 0 is taken to be trials.
    """
    print(f"Loading {file_path}")
    # NOTE: pickle can execute arbitrary code on load; only open session files you trust
    with open(file_path, 'rb') as file:
        session_data = pickle.load(file)

    data_key = fparams_in['norm_mode']
    processed_data = {}
    for condition, content in session_data.items():
        if 'data' not in content:
            continue # entry is not a condition data block
        if data_key not in content:
            # previously this raised KeyError: only 'data' was checked, but data_key was indexed
            warnings.warn(f"Condition '{condition}' in {file_path} has no '{data_key}' entry; skipping.")
            continue
        data_array = np.asarray(content[data_key])  # Convert to NumPy array
        processed_data[condition] = np.mean(data_array, axis=0)  # Average across trials
    return processed_data

def combine_sessions(fparams_in):
    """
    Combine trial-averaged data across multiple sessions.

    Every pickle returned by find_pkl_files(fparams_in['fdir_parent']) is loaded and
    trial-averaged with load_and_average_data.

    Returns
    -------
    dict
        {condition: {session_index: 2-D array}}, where session_index is the position of the
        session file in the find_pkl_files list. A 1-D trace (single-cell data) is promoted
        to shape (1, n) so every session array can be stacked row-wise later.
    """
    structured_data = {}

    pkl_files = find_pkl_files(fparams_in['fdir_parent'])
    for file_idx, file_path in enumerate(pkl_files):
        session_data = load_and_average_data(fparams_in, file_path) # session_data are trial-averaged
        for condition, trial_avg in session_data.items():
            # np.atleast_2d adds the leading axis only when it is missing. The previous
            # np.stack / np.expand_dims branch keyed on len(), which left 1-D traces 1-D
            # and raised IndexError for sessions with zero ROIs.
            structured_data.setdefault(condition, {})[file_idx] = np.atleast_2d(trial_avg)

    return structured_data


def generate_clims(data_in, norm_type):
    """
    Calculate color limits for heatmaps; useful for locking color limits across subplots.

    Parameters
    ----------
    data_in : array-like
        All values to be displayed (NaNs are ignored).
    norm_type : str or None
        Normalization flag. If it contains 'zscore', the limits are made symmetric around 0.

    Returns
    -------
    list
        [low, high]. For z-scored data this is +/- half of the largest absolute extreme.
    """
    # get min and max for all data across conditions
    clims_out = [np.nanmin(data_in), np.nanmax(data_in)]
    # norm_type may be None (no normalization); guard the substring test
    if 'zscore' in (norm_type or '').lower(): # if data are zscored, make limits symmetrical and centered at 0
        clims_max = np.max(np.abs(clims_out)) # then we take the higher of the two magnitudes
        clims_out = [-clims_max*0.5, clims_max*0.5] # and set half of it as the negative and positive limit for plotting
    return clims_out


In [None]:
# Combine the per-session pickles and build the group-level arrays used by the cells below.
# In a notebook kernel __name__ is "__main__", so this block runs when the cell is executed.
if __name__ == "__main__":

    # {condition: {session_index: array}} of trial-averaged data, one array per session
    structured_data = combine_sessions(fparams)

    # Example analyses:
    # 1. Average across cells (axis 0) within each session, then collate the cell-averaged
    #    traces into data_dict_cell_avg[condition]
    data_dict_cell_avg = {}
    for condition, sessions in structured_data.items():
        session_data = np.array([np.mean(data_array, axis=0) for data_array in sessions.values()])
        data_dict_cell_avg[condition] = session_data
        print(f"Condition: {condition}, Avg Across Cells Shape (sessions, time): {data_dict_cell_avg[condition].shape}")
    
    # 2. Combine cells across sessions for cumulative averaging
    data_dict_cumulative_cells = {}
    for condition, sessions in structured_data.items():
        # np.vstack concatenates the session arrays along their first axis, so rows stay
        # grouped by session in the order the sessions were loaded
        session_arrays = [np.array(data_array) for data_array in sessions.values()]
        cumulative_data = np.vstack(session_arrays)  # Ensures session-wise grouping
        data_dict_cumulative_cells[condition] = cumulative_data
        print(f"Condition: {condition}, Cumulative Cells Shape (total_cells, time): {cumulative_data.shape}")

    
    # the downstream sorting/plotting cells read `data_dict`; here it is the cells-pooled version
    data_dict = data_dict_cumulative_cells

## Process data combined across sessions

In [None]:
def tvec2samp(tvec, time):
    """Return the index of the sample in `tvec` (array of times) closest to `time`."""
    return np.argmin(np.abs(tvec - time))


def sort_heatmap_peaks(data, tvec, sort_epoch_start_time, sort_epoch_end_time, sort_method = 'peak_time'):
    """
    Sort ROIs based on their activity in a given epoch.

    Parameters
    ----------
    data : 2-D array
        Activity with one row per ROI; columns are indexed by the samples of `tvec`.
    tvec : array
        Time vector used to convert the epoch times to sample indices.
    sort_epoch_start_time, sort_epoch_end_time : float
        Epoch boundaries in the units of `tvec`; the end sample is exclusive.
    sort_method : str
        'peak_time'  - ascending by the sample at which each ROI peaks within the epoch
        'max_value'  - descending by each ROI's maximum (NaN-ignoring) within the epoch
        'mean_value' - descending by each ROI's mean within the epoch

    Returns
    -------
    ndarray
        ROI indices in sorted order.

    Raises
    ------
    ValueError
        If `sort_method` is not one of the options above.
    """

    # find start/end samples for epoch
    sort_epoch_start_samp = tvec2samp(tvec, sort_epoch_start_time)
    sort_epoch_end_samp = tvec2samp(tvec, sort_epoch_end_time)
    epoch_data = data[:,sort_epoch_start_samp:sort_epoch_end_samp]

    if sort_method == 'peak_time':
        epoch_peak_samp = np.argmax(epoch_data, axis=1)
        final_sorting = np.argsort(epoch_peak_samp)

    elif sort_method == 'max_value':
        time_max = np.nanmax(epoch_data, axis=1)
        final_sorting = np.argsort(time_max)[::-1]

    elif sort_method == 'mean_value':
        epoch_mean = np.mean(epoch_data, axis=1)
        final_sorting = np.flip(np.argsort(epoch_mean))

    else:
        # previously an unknown method fell through to an UnboundLocalError at the return
        raise ValueError(f"Unknown sort_method '{sort_method}'; expected 'peak_time', 'max_value' or 'mean_value'.")

    return final_sorting

In [None]:
timing_dict = calc_event_time_vars(fparams)

if not data_dict:
    raise ValueError('data_dict is empty: no combined session data to sort or plot.')

# ROI bookkeeping is derived from the combined data itself. runtime_params only describes
# the last single session that was run, so its ROI count / NaN indices do not match the
# cells-pooled arrays in data_dict.
first_cond = next(iter(data_dict))

# if flag is true, sort ROIs (usually by average fluorescence within analysis window)
if fparams['flag_sort_rois']:
    if not fparams.get('roi_sort_cond'): # if no condition to sort by specified, use first condition
        fparams['roi_sort_cond'] = first_cond # dict.keys() is not subscriptable in Python 3
    if fparams['roi_sort_cond'] not in data_dict:
        reference_cond = first_cond
        sorted_roi_order = list(range(data_dict[reference_cond].shape[0]))
        print('Specified condition to sort by doesn\'t exist! ROIs are in default sorting.')
    else:
        reference_cond = fparams['roi_sort_cond']
        # sort within the user-defined analysis window, as documented next to
        # fparams['event_sort_analysis_win'] (previously hard-coded to 0 -> end of trial)
        sort_win_start, sort_win_end = fparams['event_sort_analysis_win']
        # returns new order of rois sorted using the data and method supplied in the specified window
        sorted_roi_order = sort_heatmap_peaks(data_dict[reference_cond], timing_dict['tvec'],
                                              sort_epoch_start_time = sort_win_start,
                                              sort_epoch_end_time = sort_win_end,
                                              sort_method = fparams['user_sort_method'])
        sorted_roi_order = [int(roi) for roi in sorted_roi_order]
else:
    reference_cond = first_cond
    sorted_roi_order = None

# drop ROIs whose trial-averaged activity is entirely NaN in the reference condition
all_nan_rois = set(np.flatnonzero(np.isnan(data_dict[reference_cond]).all(axis=1)).tolist())
if all_nan_rois:
    base_order = sorted_roi_order if sorted_roi_order is not None else range(data_dict[reference_cond].shape[0])
    sorted_roi_order = [roi for roi in base_order if roi not in all_nan_rois]

# interesting rois (rois to mark with an arrow) are expressed as row positions in the plotted order
if sorted_roi_order is None:
    interesting_rois = fparams['interesting_rois']
else:
    marked_rois = set(fparams['interesting_rois'])
    interesting_rois = [position for position, roi in enumerate(sorted_roi_order) if roi in marked_rois]

roi_order_path = os.path.join(fparams['fdir_parent'], fparams['fname'] + '_roi_order.pkl')
with open(roi_order_path, 'wb') as handle:
    pickle.dump(sorted_roi_order, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
def plot_trial_avg_heatmap(data_in, conditions, tvec, event_bound_ratio, cmap, clims, sorted_roi_order = None, 
                           rois_oi = None, save_dir = None, axis_label_size=15):
    """
    Plot one ROI-by-time heatmap per condition, with a shared colorbar.

    Parameters
    ----------
    data_in : dict
        {condition: 2-D array (ROIs x time)}.
    conditions : list
        Conditions to plot, one subplot each (at most 3 columns per row).
    tvec : array
        Time vector; its first and last entries set the x extent.
    event_bound_ratio : [start, end]
        Event start/end as axes fractions; drawn as a green line under each plot.
    cmap : str or None
        Colormap forwarded to utils.subplot_heatmap.
    clims : [low, high]
        Color limits forwarded to utils.subplot_heatmap.
    sorted_roi_order : sequence of int or None
        Row order (possibly a subset) to display; None keeps the original order.
    rois_oi : sequence of int or None
        Row positions (in the displayed order) to mark with an arrow.
    save_dir : str or None
        If given, the figure is saved there as trial_avg_heatmap.png / .pdf.
    axis_label_size : int
        Font size of the axis labels.

    Returns
    -------
    (fig, ax)
        `ax` is the array returned by plt.subplots, or a one-element list for a single subplot.

    Notes
    -----
    Reads the notebook-level globals `tick_font_size` and `ylabel`.
    """

    num_subplots = len(conditions)
    if num_subplots == 0:
        raise ValueError('No conditions supplied to plot_trial_avg_heatmap.')
    n_columns = min(num_subplots, 3)
    n_rows = int(np.ceil(num_subplots/n_columns))

    fig, ax = plt.subplots(nrows=n_rows, ncols=n_columns, figsize = (n_columns*5, n_rows*4))
    if not isinstance(ax,np.ndarray): # this is here to make the code below compatible with indexing a single subplot object
        ax = [ax]

    for idx, cond in enumerate(conditions):

        num_rois = data_in[cond].shape[0]

        # select and order the rows to display
        if sorted_roi_order is not None:
            roi_order = sorted_roi_order
        else:
            roi_order = slice(0, num_rois)
        to_plot = data_in[cond][roi_order,:]
        # sorted_roi_order may drop ROIs (e.g. all-NaN ones), so size the y axis and the
        # arrow positions by the rows actually shown rather than by the full ROI count
        num_rois_shown = to_plot.shape[0]

        # set imshow extent to replace x and y axis ticks/labels (replace samples with time)
        plot_extent = [tvec[0], tvec[-1], num_rois_shown, 0 ]

        # determine subplot location index
        if n_rows == 1:
            subplot_index = idx
        else:
            subplot_index = np.unravel_index(idx, (n_rows, n_columns)) # turn int index to a tuple of array coordinates

        # prep labels; plot x and y labels for first subplot
        if subplot_index == (0, 0) or subplot_index == 0 :
            ax[subplot_index].set_ylabel('ROI #', fontsize=axis_label_size)
            ax[subplot_index].set_xlabel('Time (s)', fontsize=axis_label_size)
        ax[subplot_index].tick_params(axis = 'both', which = 'major', labelsize = tick_font_size)

        # plot the data; use the `cmap` argument (the global cmap_ was used here before,
        # silently ignoring what the caller passed)
        im = utils.subplot_heatmap(ax[subplot_index], cond, to_plot, cmap=cmap, clims=clims, extent_=plot_extent)
        ax[subplot_index].axvline(0, color='k', alpha=0.3) # plot vertical line for time zero
        ax[subplot_index].annotate('', xy=(event_bound_ratio[0], -0.01), xycoords='axes fraction', 
                                       xytext=(event_bound_ratio[1], -0.01), 
                                       arrowprops=dict(arrowstyle="-", color='g'))
        if rois_oi is not None:
            for ROI_OI in rois_oi:
                ax[subplot_index].annotate('', xy=(1.005, 1-(ROI_OI/num_rois_shown)-0.015), xycoords='axes fraction', 
                                           xytext=(1.06, 1-(ROI_OI/num_rois_shown)-0.015), 
                                           arrowprops=dict(arrowstyle="->", color='k'))
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    cbar = fig.colorbar(im, ax = ax, shrink = 0.7)
    cbar.ax.set_ylabel(ylabel, fontsize=13)

    # hide empty subplot
    if num_subplots > 1:
        for a in ax.flat[num_subplots:]:
            a.axis('off')

    if save_dir:
        fig.savefig(os.path.join(save_dir,'trial_avg_heatmap.png'))
        fig.savefig(os.path.join(save_dir,'trial_avg_heatmap.pdf'))

    return fig, ax


In [None]:
# Note: the data_dict passed to plot_trial_avg_heatmap has no second key tier for the
# normalization type (that is decided earlier via fparams['norm_mode']); it is just data_dict[condition].

# honor flag_save_figs, as the trace-plot cell does; None disables saving inside the plot function
heatmap_save_dir = fparams['figs_save_dir'] if fparams['flag_save_figs'] else None

# lock the color limits across subplots using every value that will be displayed
heatmap_clims = generate_clims(np.concatenate([data_dict[cond].flatten() for cond in conditions]),
                               fparams['flag_normalization'])

fig, ax = plot_trial_avg_heatmap(data_dict, conditions, timing_dict['tvec'], timing_dict['event_bound_ratio'], cmap_,
                       clims = heatmap_clims,
                       sorted_roi_order = sorted_roi_order, rois_oi = interesting_rois, save_dir = heatmap_save_dir)

# optional manual tweaks:
#fig.set_size_inches(12, 15)
#ax[0].set_ylim(0,4)

In [None]:
# Plot the ROI- and trial-averaged trace for each condition, optionally with a shaded SEM band.
line_shades = [] # legend handles: (line, shade) when error bars are drawn, otherwise just the line
cmap_lines = get_cmap(len(conditions))

fig, axs = plt.subplots(1,1, figsize = (10,6))
for idx, cond in enumerate(conditions):
    line_color = cmap_lines(idx)

    cond_data = data_dict[cond]

    # take avg/std across ROIs; data are already trial-avged
    roi_trial_avg = np.nanmean(cond_data, axis=0)
    roi_trial_std = np.nanstd(cond_data, axis=0)
    # the NaN-aware mean/std ignore NaN entries, so the SEM denominator has to count only
    # the non-NaN ROIs at each time point (floor of 1 avoids division by zero)
    num_valid_rois = np.maximum(np.sum(~np.isnan(cond_data), axis=0), 1)

    to_plot = np.squeeze(roi_trial_avg)
    to_plot_err = np.squeeze(roi_trial_std/np.sqrt(num_valid_rois))

    line = axs.plot(timing_dict['tvec'], to_plot, color=line_color)[0]

    if fparams['flag_roi_trial_avg_errbar']:
        shade = axs.fill_between(timing_dict['tvec'], to_plot - to_plot_err, to_plot + to_plot_err, color = line_color,
                     alpha=0.2) # this plots the shaded error bar
        line_shades.append((line, shade))
    else:
        # keep a handle per condition so the legend is still drawn without error bars
        line_shades.append(line)

axs.set_ylabel(ylabel, fontsize=axis_label_size)
axs.set_xlabel('Time [s]', fontsize=axis_label_size)
axs.legend(line_shades, conditions, fontsize=15)
axs.axvline(0, color='0.5', alpha=0.65) # plot vertical line for time zero
axs.annotate('', xy=(timing_dict['event_bound_ratio'][0], -0.01), xycoords='axes fraction', 
                               xytext=(timing_dict['event_bound_ratio'][1], -0.01), 
                               arrowprops=dict(arrowstyle="-", color='g'))
axs.tick_params(axis = 'both', which = 'major', labelsize = tick_font_size+3)
axs.autoscale(enable=True, axis='both', tight=True)

#axs.set_ylim([-1.5, 10])

if fparams['flag_save_figs']:
    fig.savefig(os.path.join(fparams['figs_save_dir'],'roi_trial_avg_trace.png'))
    fig.savefig(os.path.join(fparams['figs_save_dir'],'roi_trial_avg_trace.pdf'))

In [None]:
# Scratch check: compare a manually copied session pickle against the current one.
# `save_dir` only exists inside run_single_session, so rebuild it here from the last session
# processed by the single-session loop (fparam).
# NOTE: find_pkl_files matches any .pkl whose name contains 'event_data_dict', so a
# '... - Copy.pkl' left in a session folder is also picked up by combine_sessions.
save_dir = os.path.join(fparam['fdir'], 'event_rel_analysis')
copy_pkl_path = os.path.join(save_dir, 'event_data_dict - Copy.pkl')
main_pkl_path = os.path.join(save_dir, 'event_data_dict.pkl')

if os.path.exists(copy_pkl_path) and os.path.exists(main_pkl_path):
    with open(copy_pkl_path, 'rb') as file:
        copy_dat = pickle.load(file)

    with open(main_pkl_path, 'rb') as file:
        main_dat = pickle.load(file)

    print(copy_dat['minus']['data'].shape)

    print(main_dat['minus']['data'].shape)

    print(np.array_equal(copy_dat['minus']['data'], main_dat['minus']['data']))
else:
    print(f'Skipping pickle comparison: expected both {copy_pkl_path} and {main_pkl_path} to exist.')