# Bruker raw ome-tiff preparation for preprocessing pipeline

In [None]:
# packages for raw video to h5 processing
import numpy as np
import os
import glob
from scipy import signal
import h5py
import warnings
import multiprocessing as mp

import matplotlib.pyplot as plt
from PIL import Image
from PIL.TiffTags import TAGS
import tifffile as tiff
from lxml.html.soupparser import fromstring
from lxml.etree import tostring
from copy import copy, deepcopy

import bruker_marked_pts_process

# more packages for xml meta and analog input processing
from functools import partial
import xml.etree.ElementTree as ET
import pandas as pd
import pickle
import re

import utils_bruker

In [None]:
### functions for raw video conversion to h5


def uint16_scale(img):
    """Linearly stretch ``img`` so its values span 0..65535.

    The minimum value is shifted to 0 and the maximum is scaled to 65535;
    singleton dimensions are then squeezed out.  The result is a float array
    -- callers cast it themselves (e.g. ``.astype('uint16')``).

    Parameters
    ----------
    img : array_like
        Image or image stack of any numeric dtype.

    Returns
    -------
    numpy.ndarray
        Squeezed float64 array.  A constant image (max == min) maps to all
        zeros instead of producing NaNs from a 0/0 division; an empty input
        is returned as an empty (squeezed) float array.
    """
    # work in float64 so ``img - min`` cannot overflow narrow integer dtypes
    # (e.g. int16 data spanning more than 32767 counts wraps around)
    img = np.asarray(img, dtype=np.float64)
    if img.size == 0:
        return np.squeeze(img)

    tmp = img - np.min(img)  # shift values such that there are no negatives
    peak = np.amax(tmp)
    if peak == 0:
        # constant image: nothing to stretch
        return np.squeeze(np.zeros_like(tmp))

    ratio = peak / 65535.0
    return np.squeeze(tmp / ratio)


def read_shape_tiff(data_path):
    """Read a tiff, stretch it to the full uint16 range, and return it with its shape.

    Parameters
    ----------
    data_path : str
        Path to the tiff file.

    Returns
    -------
    tuple
        ``(data, data.shape)`` where ``data`` is the uint16 image/stack
        produced by ``uint16_scale`` (singleton dimensions squeezed).
    """
    scaled = uint16_scale(tiff.imread(data_path))
    data = scaled.astype('uint16')
    return data, data.shape

def get_ometif_xy_shape(fpath):
    """Return the array shape of the first page (key 0) of an OME-TIFF file.

    Only page 0 is read, so the full stack is not loaded into memory.
    """
    first_page = tiff.imread(fpath, key=0, is_ome=True)
    return first_page.shape


def get_tif_meta(tif_path):
    """Return a tiff's tag directory as a dict.

    Keys are the human-readable tag names from ``PIL.TiffTags.TAGS`` when
    known, otherwise the raw numeric tag id.  Values are taken from
    ``img.tag`` (Pillow's legacy tag directory) unchanged.

    Parameters
    ----------
    tif_path : str
        Path to the tiff file.
    """
    with Image.open(tif_path) as img:
        # keys() works on both Python 2 and 3; iterkeys() does not exist on
        # Python 3 mapping types
        return {TAGS.get(key, key): img.tag[key] for key in img.tag.keys()}


def check_if_meta_tif(path):
    """Return True if ``path`` is the tiff that carries the recording's xml metadata.

    The tiff's 'ImageDescription' tag is parsed as xml after dropping its
    first 21 characters.  The file counts as the metadata tiff when the last
    ``image`` element found in that xml has an ``id`` attribute.
    """
    description = get_tif_meta(path)['ImageDescription'][0]
    # NOTE(review): the first 21 characters are skipped before parsing --
    # presumably a non-xml prefix; confirm against a sample file
    root_meta = fromstring(str(description[21:]))

    image_attrib = []
    for image_node in root_meta.iter('image'):
        image_attrib = image_node.attrib  # later matches overwrite earlier ones
    return 'id' in image_attrib



def assert_bruker(fpath):
    """Raise if ``fpath`` was not written by Bruker's Prairie View software.

    Parameters
    ----------
    fpath : str
        Path to a tiff file.

    Raises
    ------
    AssertionError
        If the tiff's 'Software' tag is missing or does not mention
        'Prairie'.  Raised explicitly rather than via ``assert`` so the check
        still runs under ``python -O``.
    """
    software = get_tif_meta(fpath).get('Software', ('',))
    if not software or 'Prairie' not in str(software[0]):
        raise AssertionError("This is not a bruker file!")
    
    
def load_save_composite_frames(save_object, glob_list, chunked_frame_idx, save_format):
    """Read single-frame tiffs chunk by chunk and write them to an open output.

    Parameters
    ----------
    save_object : tifffile.TiffWriter or h5py.Dataset
        Open destination.  For 'tif' each frame is written as a new page; for
        'h5' frames are assigned into rows ``start:end`` of the dataset.
    glob_list : list of str
        Per-frame tiff paths, in the order they should be written.
    chunked_frame_idx : list of numpy.ndarray
        Consecutive, non-empty runs of frame indices into ``glob_list``
        (e.g. from ``np.array_split``).
    save_format : {'tif', 'h5'}
        Type of ``save_object``.

    Raises
    ------
    ValueError
        If ``save_format`` is neither 'tif' nor 'h5'.
    """
    if save_format not in ('tif', 'h5'):
        raise ValueError("save_format must be 'tif' or 'h5', got {!r}".format(save_format))

    n_chunks = len(chunked_frame_idx)
    # go through each chunk, load frames in chunk, and append to file
    for idx, chunk_frames in enumerate(chunked_frame_idx):
        print('Processing chunk {} out of {} chunks'.format(str(idx + 1), str(n_chunks)))
        start_idx = chunk_frames[0]
        end_idx = chunk_frames[-1] + 1

        data_to_save = tiff.imread(glob_list[start_idx:end_idx], key=0, is_ome=True)
        # tifffile may return a 2-D (y, x) array for a one-file chunk; restore
        # the frame axis so iteration / slice assignment sees exactly one frame
        if data_to_save.ndim == 2:
            data_to_save = data_to_save[np.newaxis, ...]

        if save_format == 'tif':
            # newer tifffile names this method write(); older releases only have save()
            write_page = getattr(save_object, 'write', None) or save_object.save
            for frame in data_to_save:
                write_page(frame, photometric='minisblack')
        else:
            # https://stackoverflow.com/questions/25655588/incremental-writes-to-hdf5-with-h5py
            save_object[start_idx:end_idx] = data_to_save

            
def main_ometif_to_composite(fdir, fname, save_type='h5', num_frames=None):
    """Stack a folder of per-frame Bruker OME-TIFFs into one h5 or tif file.

    Parameters
    ----------
    fdir : str
        Folder containing the per-frame ``*.tif`` files; the output is
        written here too.
    fname : str
        Output base name; '.h5' or '.tif' is appended.
    save_type : {'h5', 'tif'}
        'h5' writes a uint16 dataset 'imaging' of shape (frames, y, x);
        'tif' writes a BigTIFF with one page per frame.
    num_frames : int, optional
        Number of frames (from the start) to convert; a falsy value means
        all files in the folder.

    Raises
    ------
    IOError
        If ``fdir`` contains no ``.tif`` files.
    ValueError
        If ``save_type`` is unknown or ``num_frames`` exceeds the number of files.
    AssertionError
        If the first file is not a Bruker tiff (raised by ``assert_bruker``).
    """
    if save_type not in ('h5', 'tif'):
        raise ValueError("save_type must be 'h5' or 'tif', got {!r}".format(save_type))

    save_fname = os.path.join(fdir, fname)
    output_path = os.path.abspath(save_fname + '.' + save_type)
    # glob order is filesystem dependent: sort so frames are written in filename
    # order (assumes zero-padded frame numbers in the names -- TODO confirm).
    # The composite itself is excluded so a re-run with save_type='tif' doesn't ingest it.
    glob_list = sorted(f for f in glob.glob(os.path.join(fdir, "*.tif"))
                       if os.path.abspath(f) != output_path)
    if not glob_list:
        raise IOError('No .tif files found in {}'.format(fdir))

    # get frame info
    if not num_frames:
        num_frames = len(glob_list)
    elif num_frames > len(glob_list):
        raise ValueError('num_frames ({}) exceeds the {} tif files in {}'.format(
            num_frames, len(glob_list), fdir))
    frame_shape = get_ometif_xy_shape(glob_list[0])
    print(str(num_frames) + ' total frame')

    # split data into chunks when loading to reduce memory footprint
    chunk_size = 10000.0
    n_chunks = int(np.ceil(num_frames / chunk_size))
    chunked_frame_idx = np.array_split(np.arange(num_frames), n_chunks)

    assert_bruker(glob_list[0])
    print('Processing Bruker data')

    # prepare handles to write data to
    if save_type == 'tif':
        handle = tiff.TiffWriter(save_fname + '.tif', bigtiff=True)
        save_object = handle
    else:
        handle = h5py.File(save_fname + '.h5', 'w')
        save_object = handle.create_dataset(
            'imaging', (num_frames, frame_shape[0], frame_shape[1]),
            maxshape=(None, frame_shape[0], frame_shape[1]), dtype='uint16')

    try:
        load_save_composite_frames(save_object, glob_list, chunked_frame_idx, save_type)
    finally:
        # always close: the TiffWriter must be closed to finalize the file, and
        # the h5 file must not be left open if loading fails part-way
        handle.close()

### functions for meta data xml processing and analog processing

# load in recording/tseries main xml and grab frame period
def bruker_xml_get_2p_fs(xml_path):
    """Return the 2p frame rate, ``1 / framePeriod``, from a Bruker t-series xml.

    Scans the children of the first ``PVStateShard`` element and uses the
    'value' attribute of the first child whose serialized xml mentions
    ``framePeriod``.  Presumably framePeriod is in seconds, making the
    result frames per second -- confirm.

    Parameters
    ----------
    xml_path : str
        Path to the recording's main xml.

    Returns
    -------
    float or None
        The frame rate, or None when no child mentions ``framePeriod``.

    Raises
    ------
    IndexError
        If the root has no ``PVStateShard`` child.
    """
    xml_parse = ET.parse(xml_path).getroot()
    for child in list(xml_parse.findall('PVStateShard')[0]):
        child_xml = ET.tostring(child)
        # Python 3 returns bytes here; decode so the substring test runs on text
        if isinstance(child_xml, bytes):
            child_xml = child_xml.decode('utf-8')
        if 'framePeriod' in child_xml:
            return 1.0 / float(child.attrib['value'])
    return None

        
# takes bruker xml data, parses for each frame's timing and cycle
def bruker_xml_make_frame_info_df(xml_path):
    """Build a per-frame table of timing and cycle number from a Bruker t-series xml.

    Each ``Sequence/Frame`` element contributes one row.

    Parameters
    ----------
    xml_path : str
        Path to the recording's main xml.

    Returns
    -------
    pandas.DataFrame
        Columns 'rel_time' and 'abs_time' (floats, from the frame's
        relativeTime/absoluteTime attributes) and 'cycle_num' (int, parsed
        from 'Cycle<digits>' in the frame's first File filename).  Empty
        input gives an empty frame with the same columns.
    """
    xml_parse = ET.parse(xml_path).getroot()
    # collect rows first; growing a DataFrame cell-by-cell with .loc is very slow
    rows = []
    for type_tag in xml_parse.findall('Sequence/Frame'):
        frame_fname = type_tag.findall('File')[0].attrib['filename']
        rows.append({
            'rel_time': float(type_tag.attrib['relativeTime']),
            'abs_time': float(type_tag.attrib['absoluteTime']),
            'cycle_num': int(re.findall(r'Cycle(\d+)', frame_fname)[0]),
        })
    return pd.DataFrame(rows, columns=['rel_time', 'abs_time', 'cycle_num'])


# loads and parses the analog/voltage recording's xml and grabs sampling rate
def bruker_analog_xml_get_fs(xml_fpath):
    """Return the analog (voltage recording) sampling rate from its xml.

    Reads the text of the ``Rate`` element inside the root's first
    ``Experiment`` child and returns it as a float.
    """
    experiment_node = ET.parse(xml_fpath).getroot().findall('Experiment')[0]
    rate_text = experiment_node.find('Rate').text
    return float(rate_text)


# concatenate the analog input csv files if there are multiple cycles
def _cycle_num_from_fname(path):
    """Return the integer following 'Cycle' in the file name of ``path``."""
    return int(re.findall(r'Cycle(\d+)', os.path.basename(path))[0])


def bruker_concatenate_analog(fname, fpath, ttl_threshold=0.07):
    """Concatenate Bruker voltage-recording CSVs across cycles and flag TTL onsets.

    Every ``*_VoltageRecording_*.csv`` in ``fpath`` (except a previously
    written '..._full.csv') is read in cycle order and stacked.  Added columns:
    'Time(s)' (per-cycle time in s), 'cycle_num' (float, from 'Cycle<digits>'
    in the file name) and 'cumulative_time_ms' (time continuing across
    cycles).  Column names are then cleaned: stripped, lower-cased, spaces and
    '(' replaced by '_', ')' removed (e.g. 'Time(s)' -> 'time_s').  For every
    column containing 'input', a boolean '<column>_diff' marks samples whose
    next sample rises by more than ``ttl_threshold``.

    Parameters
    ----------
    fname : str
        Recording base name, used for the output file name.
    fpath : str
        Folder containing the voltage recording CSVs and xmls.
    ttl_threshold : float, optional
        Minimum sample-to-sample increase that counts as a TTL onset
        (default 0.07, the previously hard-coded value).

    Returns
    -------
    pandas.DataFrame
        The concatenated table, also written to
        ``<fpath>/<fname>_VoltageRecording_full.csv``.

    Raises
    ------
    IOError
        If no voltage recording CSVs are found.
    """
    # skip a previously written '<fname>_VoltageRecording_full.csv'; test only the
    # file name so a 'full' elsewhere in the folder path doesn't drop every file.
    # Sort by cycle number: glob order is arbitrary and cumulative time must
    # accumulate cycle by cycle.
    glob_analog_csv = sorted(
        (f for f in glob.glob(os.path.join(fpath, "*_VoltageRecording_*.csv"))
         if 'full' not in os.path.basename(f)),
        key=_cycle_num_from_fname)
    if not glob_analog_csv:
        raise IOError('No *_VoltageRecording_*.csv files found in {}'.format(fpath))
    glob_analog_xml = glob.glob(os.path.join(fpath, "*_VoltageRecording_*.xml"))

    # the xmls hold each cycle's sampling rate; warn when cycles disagree
    analog_xml_fs = set(map(bruker_analog_xml_get_fs, glob_analog_xml))
    if len(analog_xml_fs) > 1:
        warnings.warn('Sampling rate is not consistent across cycles!')

    cycle_dfs = []
    last_cumulative_ms = None
    for cycle_path_csv in glob_analog_csv:
        cycle_df = pd.read_csv(cycle_path_csv)
        cycle_df['Time(s)'] = cycle_df['Time(ms)'] / 1000.0
        cycle_df['cycle_num'] = float(_cycle_num_from_fname(cycle_path_csv))

        # time restarts every cycle: offset each cycle past the previous cycle's
        # last sample (+1 so the two samples don't share a timestamp).
        # NOTE(review): the gap is 1 ms regardless of the analog sampling rate -- confirm.
        offset_ms = 0 if last_cumulative_ms is None else last_cumulative_ms + 1
        cycle_df['cumulative_time_ms'] = cycle_df['Time(ms)'].values + offset_ms
        if len(cycle_df):
            last_cumulative_ms = cycle_df['cumulative_time_ms'].iloc[-1]
        cycle_dfs.append(cycle_df)

    # DataFrame.append was removed in pandas 2.0; concat once instead
    analog_concat = pd.concat(cycle_dfs, ignore_index=True)

    # clean up column names; regex=False so '(' and ')' are replaced literally
    analog_concat.columns = (analog_concat.columns.str.strip().str.lower()
                             .str.replace(' ', '_', regex=False)
                             .str.replace('(', '_', regex=False)
                             .str.replace(')', '', regex=False))

    # threshold the sample-to-sample diff of each analog input for event onsets
    analog_column_names = [column for column in analog_concat.columns if 'input' in column]
    for analog_column_name in analog_column_names:
        onsets = np.diff(analog_concat[analog_column_name].values) > ttl_threshold
        # pad with False so the flag column matches the frame length
        analog_concat[analog_column_name + '_diff'] = np.append(onsets, False)

    save_full_csv_path = os.path.join(fpath, fname + '_VoltageRecording_full.csv')
    analog_concat.to_csv(save_full_csv_path, index=False)

    return analog_concat


# function for finding the index of the closest entry in an array to a provided value
def find_nearest_idx(array, value):
    """Locate the entry of ``array`` closest to ``value``.

    Parameters
    ----------
    array : pandas.Series or array_like
        Values to search.
    value : scalar
        Target value.

    Returns
    -------
    tuple
        For a pandas Series: ``(index_label, position, entry)``, where
        ``position`` is the 0-based location of ``index_label``.
        Otherwise: ``(position, entry)``.
    """
    if not isinstance(array, pd.Series):
        values = np.asarray(array)
        position = np.abs(values - value).argmin()
        return position, values[position]

    distances = np.abs(array - value)
    label = distances.idxmin()
    return label, array.index.get_loc(label), array[label]


### Take in analog dataframe (contains analog tseries and thresholded boolean) and make dict of 2p frame times for each condition's event
def match_analog_event_to_2p(imaging_info_df, analog_dataframe, rename_ports = None, flag_multicondition_analog = False):
    """Map each analog TTL onset to the nearest 2p frame within the same cycle.

    Parameters
    ----------
    imaging_info_df : pandas.DataFrame
        Per-frame 2p info with 'rel_time' and 'cycle_num' columns (e.g. from
        ``make_imaging_info_df``); its index labels are returned as frame indices.
    analog_dataframe : pandas.DataFrame
        Concatenated analog data with 'time_s', 'cycle_num' and boolean onset
        columns whose names contain 'diff' (e.g. from ``bruker_concatenate_analog``).
    rename_ports : list, optional
        Names indexed by analog port number; the first integer in each diff
        column's name selects the name.  Without it the diff column name is used.
    flag_multicondition_analog : bool, optional
        If True, each channel's lists are stored under an ``'all'`` key of a
        sub-dict so per-condition keys can be added later (``split_analog_channel``).

    Returns
    -------
    analog_event_dict : dict
        Channel name -> list of 2p frame index labels (nearest frame by
        'rel_time' within the onset's cycle), one per onset.
    analog_event_samples : dict
        Channel name -> list of ``analog_dataframe`` row labels of the onsets.

    Raises
    ------
    ValueError
        If an onset's cycle has no 2p frames in ``imaging_info_df``.
    """
    analog_event_dict = {}
    analog_event_samples = {}
    # uses the analog_dataframe argument (the old body read a global 'analog_df')
    all_diff_columns = [c for c in analog_dataframe.columns if 'diff' in c]

    # group frame times by cycle once instead of re-filtering for every event
    rel_time_by_cycle = {cycle: times for cycle, times
                         in imaging_info_df.groupby('cycle_num')['rel_time']}

    for ai_diff in sorted(all_diff_columns):
        if rename_ports:
            ai_port_num = int(re.findall(r'\d+', ai_diff)[0])
            ai_name = rename_ports[ai_port_num]
        else:
            ai_name = ai_diff

        frame_list = []
        sample_list = []
        if flag_multicondition_analog:
            analog_event_dict[ai_name] = {'all': frame_list}
            analog_event_samples[ai_name] = {'all': sample_list}
        else:
            analog_event_dict[ai_name] = frame_list
            analog_event_samples[ai_name] = sample_list

        # rows where the TTL threshold was crossed (== True also matches 1/1.0
        # if the column was re-read from csv with a non-bool dtype)
        onset_mask = analog_dataframe[ai_diff] == True  # noqa: E712
        analog_events = analog_dataframe.loc[onset_mask, ['time_s', 'cycle_num']]

        for sample_idx, analog_event in analog_events.iterrows():
            cycle_times = rel_time_by_cycle.get(analog_event['cycle_num'])
            if cycle_times is None:
                raise ValueError('No 2p frames found for analog cycle {}'.format(analog_event['cycle_num']))
            whole_session_idx, _, _ = find_nearest_idx(cycle_times, analog_event['time_s'])
            frame_list.append(whole_session_idx)
            sample_list.append(sample_idx)

    return analog_event_dict, analog_event_samples

    
# if all behav events of interest (different conditions) are recorded on a single AI channel
# and need to reference the behavioral events csv to split conditions up
def split_analog_channel(ai_to_split, fdir, behav_fname, behav_event_key_path, analog_event_dict,
                         behav_id_of_interest=(101, 102, 103), behav_analog_save_path=None):
    """Split one analog channel's event frames into per-condition lists.

    Used when all behavioral events of interest arrive on a single analog
    input: the behavioral event csv gives the order in which the condition ids
    occurred, and the n-th analog event of the channel is assigned to the n-th
    such id.

    Parameters
    ----------
    ai_to_split : int or str
        Channel to split: an exact key of ``analog_event_dict``, or a value
        whose string form appears in the key (e.g. 2 -> 'input_2_diff').
    fdir : str
        Folder containing ``behav_fname``.
    behav_fname : str
        Headerless csv of (event id, sample) rows.
    behav_event_key_path : str
        Excel file with 'event_id' and 'event_desc' columns.
    analog_event_dict : dict
        Output of ``match_analog_event_to_2p(..., flag_multicondition_analog=True)``;
        the chosen channel's ``'all'`` list is split.  Modified in place.
    behav_id_of_interest : sequence of int, optional
        Condition ids to extract (previously read from an undefined global).
    behav_analog_save_path : str, optional
        If given, the updated dict is pickled to this path.

    Returns
    -------
    dict
        ``analog_event_dict`` (the same object, updated).

    Raises
    ------
    KeyError
        If no channel in ``analog_event_dict`` matches ``ai_to_split``.
    """
    if ai_to_split in analog_event_dict:
        this_ai_to_split = ai_to_split
    else:
        matches = [name for name in analog_event_dict if str(ai_to_split) in name]
        if not matches:
            raise KeyError('No analog channel matching {!r}'.format(ai_to_split))
        this_ai_to_split = matches[0]

    # load id's and samples of behavioral events (output by behavioral program)
    behav_df = pd.read_csv(os.path.join(fdir, behav_fname), names=['id', 'sample'])
    behav_event_keys = pd.read_excel(behav_event_key_path)

    # look names up per id so each condition gets its own name regardless of the
    # row order in the key file; str() also converts py2 unicode names
    id_to_name = dict(zip(behav_event_keys['event_id'], behav_event_keys['event_desc'].map(str)))

    # condition ids of interest, in the order the trials occurred
    trial_ids = list(behav_df.loc[behav_df['id'].isin(behav_id_of_interest), 'id'].values)

    channel_events = analog_event_dict[this_ai_to_split]
    all_events = channel_events['all']
    for behav_event_id in behav_id_of_interest:
        behav_event_name = id_to_name.get(behav_event_id, str(behav_event_id))
        channel_events[behav_event_name] = [all_events[idx] for idx, val in enumerate(trial_ids)
                                            if val == behav_event_id]

    if behav_analog_save_path is not None:
        with open(behav_analog_save_path, 'wb') as handle:
            pickle.dump(analog_event_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return analog_event_dict


# take in data from an analog input and plot detected ttls
def plot_analog_validation(AI_onsets, analog_tseries, analog_fs, save_dir = None):
    """Plot an analog channel and the windows around detected TTL onsets.

    Panels: the full trace; every onset's window from -0.5 s to +3 s
    overlaid; and the first 15 s of the trace.

    Parameters
    ----------
    AI_onsets : sequence of int
        Sample positions in ``analog_tseries`` of detected onsets, e.g. one
        channel's list from ``match_analog_event_to_2p``'s sample dict.
    analog_tseries : pandas.Series or array_like
        The analog trace.  It is indexed by position, which matches label
        indexing for the default RangeIndex produced by ``read_csv``.
    analog_fs : float
        Analog sampling rate (samples per second).
    save_dir : str, optional
        If given, the figure is saved there as 'ttl_validation.png'.

    Returns
    -------
    matplotlib.figure.Figure
    """
    tseries = np.asarray(analog_tseries)
    onsets = np.asarray(AI_onsets, dtype=int).ravel()
    num_AI = len(onsets)

    # window from 0.5 s before to 3 s after each onset, as integer sample offsets
    rel_ind_vec = np.arange(int(round(-0.5 * analog_fs)), int(round(3 * analog_fs)))
    # drop onsets whose window would run off either end of the recording
    in_bounds = (onsets + rel_ind_vec[0] >= 0) & (onsets + rel_ind_vec[-1] < len(tseries))
    # rows are window samples, columns are onsets -> one plotted trace per onset
    AI_value_tile = tseries[rel_ind_vec[:, np.newaxis] + onsets[in_bounds][np.newaxis, :]]

    fig, ax = plt.subplots(1, 3, figsize=(17, 5))

    ax[0].set_title('Full TTL series')
    ax[0].plot(analog_tseries)

    ax[1].set_title('{} ttls detected'.format(num_AI))
    # x-axis in ms relative to the onset (was previously raw sample offsets)
    ax[1].plot(rel_ind_vec * 1000.0 / analog_fs, AI_value_tile)
    ax[1].set_xlabel('Time (ms)')
    ax[1].set_ylabel('Volts')

    # first 15 s, or the whole trace if it is shorter
    svec = np.arange(min(int(round(15 * analog_fs)), len(tseries)))
    ax[2].set_title('Specific window (first 15s)')
    ax[2].plot(svec / float(analog_fs), tseries[svec])
    ax[2].set_xlabel('Seconds')

    if save_dir:
        utils_bruker.check_exist_dir(save_dir)
        fig.savefig(os.path.join(save_dir, 'ttl_validation.png'))
    return fig
        
# functions for analyzing bruker analog 

def make_imaging_info_df(bruker_tseries_xml_path):
    """Parse a Bruker t-series xml into a per-frame timing table.

    Parameters
    ----------
    bruker_tseries_xml_path : str
        Path to the recording's main xml.

    Returns
    -------
    pandas.DataFrame
        One row per ``Sequence/Frame`` element with columns 'rel_time' and
        'abs_time' (floats, from relativeTime/absoluteTime) and 'cycle_num'
        (int, parsed from 'Cycle<digits>' in the frame's first File filename).
    """
    type_tags = ET.parse(bruker_tseries_xml_path).getroot().findall('Sequence/Frame')

    def _frame_record(type_tag):
        # first grab this frame's file name, then regex out the cycle number
        frame_fname = type_tag.findall('File')[0].attrib['filename']
        return [float(type_tag.attrib['relativeTime']),
                float(type_tag.attrib['absoluteTime']),
                int(re.findall(r'Cycle(\d+)', frame_fname)[0])]

    return pd.DataFrame([_frame_record(tag) for tag in type_tags],
                        columns=['rel_time', 'abs_time', 'cycle_num'])

In [None]:
"""

User-defined variables

"""

def define_params(method = 'single'):
    """Return the list of per-recording parameter dicts to preprocess.

    Parameters
    ----------
    method : {'single', 'f2a', 'root_dir'}
        'single' returns the one recording hard-coded below; 'f2a' loads the
        predefined list from the project-local ``files_to_analyze_prepreprocess``
        module; 'root_dir' is not implemented and returns an empty list.

    Returns
    -------
    list of dict
        Each dict has 'fname', 'fdir', 'save_type', 'number_frames',
        'flag_bruker_analog' and 'flag_bruker_stim'.

    Raises
    ------
    ValueError
        For an unknown ``method``.
    """
    if method == 'single':
        return [
            {
                'fname': 'vj_ofc_imageactivate_001_2020903-001',
                'fdir': r'D:\bruker_data\vj_ofc_imageactivate_001_20200903\vj_ofc_imageactivate_001_2020903-001',
                'save_type': 'h5',
                'number_frames': None, # optional; number of frames to analyze; defaults to analyzing whole session (None)

                'flag_bruker_analog': True, # set to true if analog/voltage input signals are present and are of interest
                'flag_bruker_stim': True
            }
        ]

    if method == 'f2a':
        # predefined recordings live in a project-local config module
        import files_to_analyze_prepreprocess
        return files_to_analyze_prepreprocess.define_fparams()

    if method == 'root_dir':
        # not implemented yet: no recordings are returned
        return []

    raise ValueError("method must be 'single', 'f2a' or 'root_dir', got {!r}".format(method))

In [None]:
def single_file_process(fparams):
    """Convert one recording's per-frame OME-TIFFs into a composite file.

    ``fparams`` is one entry of the list returned by ``define_params``; its
    'fdir', 'fname', 'save_type' and 'number_frames' values are forwarded to
    ``main_ometif_to_composite``.
    """
    fdir, fname = fparams['fdir'], fparams['fname']
    main_ometif_to_composite(fdir, fname, fparams['save_type'],
                             num_frames=fparams['number_frames'])


In [None]:

fparams = define_params(method = 'single') # options are 'single', 'f2a', 'root_dir'

num_files = len(fparams)
if num_files == 0:
    raise Exception("No files to analyze!")
print(str(num_files) + ' files to analyze')

# True: convert recordings in parallel (one worker process per file).
# False: convert sequentially in this process (easier to debug).
flag_parallel = False

if flag_parallel:
    num_processes = min(mp.cpu_count(), num_files)
    print('Total CPU cores for parallel processing: ' + str(num_processes))
    # the with-block shuts the workers down once map() has returned, so no
    # idle processes are left behind (the pool used to be created and never closed)
    with mp.Pool(processes=num_processes) as pool:
        pool.map(single_file_process, fparams)
else:
    for fparam in fparams:
        single_file_process(fparam)




# Meta & Behavioral Data Preprocessing

### Define variables and load 2p recording xml

In [None]:
# --- recording to process ---
fname = 'vj_ofc_imageactivate_001_2020903-001'
fdir = r'D:\bruker_data\vj_ofc_imageactivate_001_20200903\vj_ofc_imageactivate_001_2020903-001'

# names for analog input ports 0, 1, 2, 3 (list position = port number)
analog_names = ['stim', 'frames', 'licks', 'rewards']

# --- which processing steps to run ---
flag_bruker_analog, flag_bruker_stim = True, True  # analog/voltage inputs present & of interest; run stim-frame detection

# --- splitting one analog port into several behavioral conditions ---
flag_multicondition_analog = False  # True if a single analog port carries TTLs of several conditions
behav_id_of_interest = [101, 102, 103]  # behavioral event ids of those conditions
ai_to_split = 2  # int port number with the mixed-condition TTLs; used only when flag_multicondition_analog is True

# --- TTL validation plots ---
validation_plots = False  # True to plot TTL traces for visual validation
valid_plot_channel = 'input_2'  # cleaned analog column name; analog inputs are named "input_#"


In [None]:
def bruker_analog_define_paths(fdir, fname, behav_event_key_path=r'D:\bruker_data\Adam\key_event.xlsx'):
    """Collect the input and output paths used for analog/behavior processing.

    Parameters
    ----------
    fdir : str
        Recording folder.
    fname : str
        Recording base name.
    behav_event_key_path : str, optional
        Excel file matching behavioral event names and ids (defaults to the
        previously hard-coded location).

    Returns
    -------
    dict
        'bruker_tseries_xml_path' (main recording xml), 'glob_analog_csv' /
        'glob_analog_xml' (sorted voltage-recording files in ``fdir``),
        'behav_fname' (behavioral event csv name), 'behav_event_key_path',
        and the pickle save paths 'behav_save_path' / 'behav_analog_save_path'.
    """
    # built from the fdir/fname arguments (previously read from the global fparams)
    return {
        'bruker_tseries_xml_path': os.path.join(fdir, fname + '.xml'),
        'glob_analog_csv': sorted(glob.glob(os.path.join(fdir, "*_VoltageRecording_*.csv"))),
        'glob_analog_xml': sorted(glob.glob(os.path.join(fdir, "*_VoltageRecording_*.xml"))),
        # behavioral event identification files
        'behav_fname': fname + '_taste_reactivity.csv',  # each behav event and corresponding sample
        'behav_event_key_path': behav_event_key_path,
        # save paths
        'behav_save_path': os.path.join(fdir, 'framenumberforevents_{}.pkl'.format(fname)),
        'behav_analog_save_path': os.path.join(fdir, 'framenumberforevents_analog_{}.pkl'.format(fname)),
    }

In [None]:
def main_bruker_analog(fparams):
    """Parse 2p timing metadata and align analog / stim events for one recording.

    Settings are looked up in ``fparams`` first and, when a key is absent, in
    the notebook-level variable of the same name (which is how the original
    code received them): 'fdir', 'fname', 'flag_bruker_analog',
    'flag_bruker_stim', and, only when analog processing runs,
    'analog_names', 'flag_multicondition_analog', 'ai_to_split',
    'validation_plots', 'valid_plot_channel'.

    Side effects (when enabled): writes '<fname>_VoltageRecording_full.csv'
    if it does not exist yet (via ``bruker_concatenate_analog``), pickles the
    per-channel 2p frame indices of analog events to the 'behav_save_path'
    from ``bruker_analog_define_paths``, optionally saves TTL validation
    plots, and runs ``bruker_marked_pts_process.main_detect_save_stim_frames``.

    Returns
    -------
    dict
        'fs_2p', 'imaging_info_df', 'tvec_2p' (the per-frame 'rel_time'
        column) and 'num_frames_2p'.

    Raises
    ------
    KeyError
        If a needed setting is neither in ``fparams`` nor a notebook variable.
    IOError
        If analog processing is enabled but no voltage-recording xml exists.
    """
    def setting(key):
        # per-recording value wins; otherwise fall back to the notebook-level variable
        return fparams[key] if key in fparams else globals()[key]

    fdir = setting('fdir')
    fname = setting('fname')
    paths_dict = bruker_analog_define_paths(fdir, fname)
    tseries_xml_path = paths_dict['bruker_tseries_xml_path']

    # timing metadata about the 2p recording from the main xml
    meta_2p_dict = {}
    meta_2p_dict['fs_2p'] = bruker_xml_get_2p_fs(tseries_xml_path)
    # parse per-frame timing/cycle info first: the time vector and frame count derive from it
    imaging_info_df = make_imaging_info_df(tseries_xml_path)
    meta_2p_dict['imaging_info_df'] = imaging_info_df
    meta_2p_dict['tvec_2p'] = imaging_info_df['rel_time']
    meta_2p_dict['num_frames_2p'] = len(imaging_info_df)

    # If analog signals indicating behavioral event onsets are sent from the
    # behavioral DAQ to the bruker GPIO box, the following:
    # 1) parses the analog voltage recording xmls
    # 2) extracts the signals from the csvs
    # 3) extracts the TTL onset times
    # 4) and finally lines up which frame the TTL occurred on.
    if setting('flag_bruker_analog'):
        if not paths_dict['glob_analog_xml']:
            raise IOError('No *_VoltageRecording_*.xml files found in {}'.format(fdir))
        analog_fs = bruker_analog_xml_get_fs(paths_dict['glob_analog_xml'][0])

        # load the concatenated (across cycles) voltage recording, or build it
        volt_rec_full_path = os.path.join(fdir, fname + '_VoltageRecording_full.csv')
        if os.path.exists(volt_rec_full_path):
            analog_df = pd.read_csv(volt_rec_full_path)
        else:
            analog_df = bruker_concatenate_analog(fname, fdir)

        # match analog ttl onsets to the corresponding 2p frame for each analog port
        analog_names = setting('analog_names')
        flag_multicondition = setting('flag_multicondition_analog')
        analog_event_dict, analog_event_samples = match_analog_event_to_2p(
            imaging_info_df, analog_df, rename_ports=analog_names,
            flag_multicondition_analog=flag_multicondition)

        # several conditions signaled on a single analog port: split them up
        if flag_multicondition:
            split_analog_channel(setting('ai_to_split'), fdir, paths_dict['behav_fname'],
                                 paths_dict['behav_event_key_path'], analog_event_dict)

        if setting('validation_plots'):
            valid_plot_channel = setting('valid_plot_channel')
            valid_save_dir = os.path.join(fdir, fname + '_output_images')
            utils_bruker.check_exist_dir(valid_save_dir)
            # analog_event_samples is keyed by (possibly renamed) port: pick the plotted channel's onsets
            port_num = int(re.findall(r'\d+', valid_plot_channel)[0])
            sample_key = analog_names[port_num] if analog_names else valid_plot_channel + '_diff'
            onset_samples = analog_event_samples[sample_key]
            if flag_multicondition:
                onset_samples = onset_samples['all']
            plot_analog_validation(onset_samples, analog_df[valid_plot_channel],
                                   analog_fs, valid_save_dir)

        # save preprocessed behavioral event data
        with open(paths_dict['behav_save_path'], 'wb') as handle:
            pickle.dump(analog_event_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

    if setting('flag_bruker_stim'):
        bruker_marked_pts_process.main_detect_save_stim_frames(fdir, fname, detection_threshold=1.5, flag_plot_mk_pts=False)

    return meta_2p_dict

### Load and process analog voltage recordings

## Load behavioral data and identify the 2p frame of each event onset

Alternatively, if you have a separate event recorder that is synchronized to the 2p microscope (via frame onset TTL from the GPIO output), you can use the following code.

Inputs: `fdir`, `fname`, `behav_fname`, `behav_event_key_path`, and the 2p frame times `tvec_2p`.

In [None]:
behav_fs = 1e3  # samples per second of the behavioral csv's 'sample' column (used to convert samples to seconds)

In [None]:
# behavioral file locations for this recording (these names were never defined
# at notebook level; same naming as in bruker_analog_define_paths)
behav_fname = fname + '_taste_reactivity.csv'  # csv of each behav event id and its sample
behav_event_key_path = r'D:\bruker_data\Adam\key_event.xlsx'  # excel matching event names and id's
behav_save_path = os.path.join(fdir, 'framenumberforevents_{}.pkl'.format(fname))

# load id's and samples (camera samples?) of behavioral events (output by behavioral program)
behav_df = pd.read_csv(os.path.join(fdir, behav_fname), names=['id', 'sample'])
behav_event_keys = pd.read_excel(behav_event_key_path)

In [None]:
# Behavioral camera pulses are synced to 2p frames, so event times are
# normalized to the first camera pulse to line them up with the 2p frames.
try:
    camera_pulse_event_id = behav_event_keys['event_id'][behav_event_keys['event_desc'] == 'camera pulse'].values[0]
    first_cam_pulse_sample = behav_df[behav_df['id'] == camera_pulse_event_id].iloc[0]['sample']
except IndexError:
    # no 'camera pulse' entry in the key file, or no such event recorded;
    # first_cam_pulse_sample is not set here. Other errors (e.g. missing
    # columns) are no longer silenced by a bare except.
    print('No camera pulse events!')

In [None]:
# 2p frame times (s) from the recording's main xml; the loop below matches
# behavioral event times against them
tvec_2p = make_imaging_info_df(os.path.join(fdir, fname + '.xml'))['rel_time']
behav_save_path = os.path.join(fdir, 'framenumberforevents_{}.pkl'.format(fname))
nearest_2p_frame = partial(find_nearest_idx, tvec_2p)

frame_events_dict = {}

# loop through each type of event
for _, row in behav_event_keys.iterrows():
    this_id_name = str(row['event_desc'])
    # samples of this event's rows in the behavioral dataframe
    event_samples = behav_df.loc[behav_df['id'] == row['event_id'], 'sample'].values

    # convert to seconds, zeroed to the first camera frame
    event_times_seconds = (event_samples - first_cam_pulse_sample) / behav_fs

    # nearest 2p frame for each event: (index label, position, frame time) tuples.
    # A list, not a lazy map object, so the values exist when pickled.
    frame_events_dict[this_id_name] = [nearest_2p_frame(t) for t in event_times_seconds]

# save preprocessed behavioral event data
with open(behav_save_path, 'wb') as handle:
    pickle.dump(frame_events_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

## Run this if you performed opto stim

## Plot ROI calcium trace with TTLs

In [None]:
from __future__ import division # make py2 act like py3 where int division turns into float; must be the cell's first statement
import os
import numpy as np
import glob
import pickle
import seaborn as sns
import matplotlib.ticker as ticker
import pandas as pd
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import matplotlib
#important for text to be detected when importing saved figures into illustrator
matplotlib.rcParams['pdf.fonttype']=42
matplotlib.rcParams['ps.fonttype']=42
plt.rcParams["font.family"] = "Arial"

import sys
sys.path.insert(0, r"C:\Users\stuberadmin\Documents\GitHub\NAPE_imaging_analysis\in_development")
import utils

In [None]:

def define_params(method = 'single'):
    """Return analysis and plotting parameters for the ROI-trace section.

    Parameters
    ----------
    method : {'single', 'f2a', 'root_dir'}
        'single' builds the hard-coded recording's parameters below; 'f2a'
        loads them from the project-local ``files_to_analyze_event`` module;
        'root_dir' is not implemented and returns an empty dict.

    Returns
    -------
    dict
        Parameters such as 'fname', 'fdir', 'fs' (imaging sampling rate),
        trial windowing ('trial_start_end', 'baseline_end', 'event_dur', in
        seconds) and plotting/saving flags.

    Raises
    ------
    ValueError
        For an unknown ``method``.
    """
    if method == 'f2a':
        # predefined recordings live in a project-local config module
        import files_to_analyze_event
        return files_to_analyze_event.define_fparams()

    if method == 'root_dir':
        # not implemented yet: no parameters are defined
        return {}

    if method != 'single':
        raise ValueError("method must be 'single', 'f2a' or 'root_dir', got {!r}".format(method))

    fparams = {}
    fparams['fname'] = 'vj_ofc_imageactivate_001_20200828-003'
    fparams['fdir'] = r'D:\bruker_data\vj_ofc_imageactivate_001_20200828\vj_ofc_imageactivate_001_20200828-003'

    # set the sampling rate
    fparams['fs'] = 15

    # trial windowing
    fparams['trial_start_end'] = [-2, 5]
    fparams['baseline_end'] = -0.2
    fparams['event_dur'] = 0.46 # duration of stim/event in seconds

    # session info
    fparams['opto_blank_frame'] = True

    # analysis and plotting arguments
    fparams['flag_npil_corr'] = True # declare which data to load in
    fparams['flag_zscore'] = True # whether or not to z-score data for plots

    # ROI sorting
    fparams['flag_sort_rois'] = False
    if fparams['flag_sort_rois']:
        fparams['user_sort_method'] = 'max_value' # peak_time or max_value
        fparams['roi_sort_cond'] = 'slm_stim' # for roi-resolved heatmaps, which condition to sort ROIs by

    # errorbar and saving figures
    fparams['flag_roi_trial_avg_errbar'] = True # toggle to show error bar on roi- and trial-averaged traces
    fparams['flag_trial_avg_errbar'] = True # toggle to show error bars on the trial-avg traces
    fparams['flag_save_figs'] = True
    fparams['interesting_rois'] = []

    return fparams

fparams = define_params(method = 'single') # options are 'single', 'f2a', 'root_dir'

In [None]:
# glob pattern for the signal file(s): neuropil-corrected signals of this
# recording, or any extracted-signal file in the folder
signals_fpath = os.path.join(
    fparams['fdir'],
    "{}_neuropil_corrected_signals*".format(fparams['fname']) if fparams['flag_npil_corr'] == True
    else "*_extractedsignals*")

In [None]:
# load time-series data
glob_signal_files = glob.glob(signals_fpath)
if len(glob_signal_files) == 1:
    signals = np.squeeze(np.load(glob_signal_files[0]))
else:
    print('Warning: No or multiple signal files detected; using first detected file')

total_samples = len(signals[0,:])
num_rois = signals.shape[0]
full_tvec = np.linspace(0, total_samples/fparams['fs'], total_samples)

analog_tvec = analog_df['time_s']

In [None]:

iROI = 1

fig_combine, ax_combine = plt.subplots(1,1, figsize=(13,7))
# plot against this recording's imaging time vector (tvec_2p belonged to the analog section)
ax_combine.plot(full_tvec, signals[iROI,:])

# analog traces scaled so they are visible next to the fluorescence trace
ax_combine.plot(analog_tvec, analog_df['input_3'].values*400)
ax_combine.plot(analog_tvec, analog_df['input_2'].values*30)
ax_combine.plot(analog_tvec, analog_df['input_0'].values*1000)
ax_combine.legend(['Activity', 'reward', 'licks', 'stim'], fontsize=15)

ax_combine.set_title('ROI {}'.format(iROI), fontsize=20)
ax_combine.set_ylabel('Fluorescence', fontsize=20)
ax_combine.set_xlabel('Time (s)', fontsize=20)
ax_combine.tick_params(axis = 'both', which = 'major', labelsize = 13)

from matplotlib.ticker import FormatStrFormatter
ax_combine.xaxis.set_major_formatter(FormatStrFormatter('%.3f'))

# candidate zoom windows (s): only the last set_xlim call takes effect,
# so keep exactly one active
#ax_combine.set_xlim([12.5, 14.5])
#ax_combine.set_ylim([0, 2200])
#ax_combine.set_xlim([119, 126]) # reward
#ax_combine.set_xlim([68, 75]) # reward
ax_combine.set_xlim([243, 246]) # reward