In [17]:
import scipy.io
import numpy as np
import os
from typing import Dict, Tuple, List, Union, Optional
import pandas as pd
import warnings

In [18]:
def loadmat(filename: str) -> Dict:
    '''
    Read a MATLAB .mat file into a dictionary of Python objects.

    The file is read with ``simplify_cells=True`` so that the loader returns
    a simplified, nested structure instead of MATLAB-specific objects.

    Parameters:
    filename (str): The name of the .mat file to load.

    Returns:
    dict: The contents of the .mat file, keyed by variable name.
    '''
    return scipy.io.loadmat(filename, simplify_cells=True)

def import_bpod_data_files(input_path: str, min_file_size: int = 200000) -> Tuple[Dict[int, Dict], int, List[str], List[str]]:
    '''
    Load every sufficiently large '.mat' file in a folder and convert it to Python format.

    Entries are visited in sorted name order. A file is loaded when its name ends
    with '.mat' and its size is strictly greater than ``min_file_size`` bytes.

    Parameters:
    input_path (str): The path to the folder containing the '.mat' files.
        A trailing path separator is optional.
    min_file_size (int): Files of this size (in bytes) or smaller are skipped.
        Defaults to 200000.

    Returns:
    Tuple[Dict[int, Dict], int, list, list]: The loaded sessions keyed by a running
    index starting at 0, the number of sessions loaded, the sorted listing of ALL
    entries in ``input_path`` (not only the loaded files), and one date string per
    loaded session taken from characters [-19:-4] of its file name.
    '''
    behav_path = sorted(os.listdir(input_path))
    behav_data = {}
    session_dates = []

    for file_name in behav_path:
        if not file_name.endswith('.mat'):
            continue

        # os.path.join works whether or not input_path ends with a separator
        file_path = os.path.join(input_path, file_name)
        if os.stat(file_path).st_size <= min_file_size:
            continue

        # Sessions are numbered consecutively in load order
        behav_data[len(behav_data)] = loadmat(file_path)
        session_dates.append(file_name[-19:-4])

    return behav_data, len(behav_data), behav_path, session_dates


In [19]:
## All other helper functions used below

def _collect_port_event_times(session_data: Dict, event_name: str) -> List:
    """Return the session-aligned times of one event type, concatenated over all trials."""
    event_times = []
    for trial_index in range(session_data['nTrials']):
        trial_events = session_data['RawEvents']['Trial'][trial_index]['Events']
        if event_name not in trial_events:
            continue

        # A trial stores either a single time or a sequence of times for an event;
        # atleast_1d handles both, whatever the scalar type is.
        offsets = np.atleast_1d(np.asarray(trial_events[event_name], dtype=float))
        trial_start_timestamp = session_data['TrialStartTimestamp'][trial_index]
        event_times.extend(trial_start_timestamp + offsets)
    return event_times


def extract_poke_times(behavior_data: Dict) -> Tuple[List, List, List]:
    """
    Extracts all port in/out times across the session for ports 1 to 8.
    Each time is shifted by the start timestamp of its trial so that all
    times are expressed on the session timeline.

    Parameters:
    behavior_data (dict): The dictionary containing behavior data for the session.

    Returns:
    Tuple: Lists of all port in times, port out times, and corresponding port references.
    The lists are grouped by port (port 1 first), not sorted by time. When a port has an
    unequal number of in and out events, error_check_and_fix is applied to that port's lists.
    """
    session_data = behavior_data['SessionData']

    all_port_in_times = []
    all_port_out_times = []
    all_port_references = []

    for port in range(1, 9):
        port_in_times = _collect_port_event_times(session_data, f'Port{port}In')
        port_out_times = _collect_port_event_times(session_data, f'Port{port}Out')

        # An in/out count mismatch means an event was dropped for this port
        if len(port_in_times) != len(port_out_times):
            port_in_times, port_out_times = error_check_and_fix(port_in_times, port_out_times)

        all_port_references.extend([port] * len(port_in_times))
        all_port_in_times.extend(port_in_times)
        all_port_out_times.extend(port_out_times)

    return all_port_in_times, all_port_out_times, all_port_references

def error_check_and_fix(port_in_times: List, port_out_times: List) -> Tuple[List, List]:
    """
    Pads the shorter of the port in / port out lists with the string 'nan'.

    The longer list is used as reference. The shorter list is scanned over its
    original length and a 'nan' is inserted wherever the ordering of in and out
    times shows a missing event. If no such position is found, a single 'nan'
    is appended instead. Both lists are modified in place.

    Parameters:
    port_in_times (List): The list of port in times.
    port_out_times (List): The list of port out times.

    Returns:
    Tuple: The (possibly padded) port in times and port out times lists. A message is
    printed when the two lists still differ in length afterwards.
    """
    n_in = len(port_in_times)
    n_out = len(port_out_times)
    inserted = False

    if n_in > n_out:
        # Missing out event: an out time at or after the NEXT in time cannot belong here
        for position in range(n_out):
            if port_out_times[position] >= port_in_times[position + 1]:
                port_out_times.insert(position, 'nan')
                inserted = True
        if not inserted:
            port_out_times.append('nan')

    elif n_out > n_in:
        # Missing in event: an in time at or after its paired out time cannot belong here
        for position in range(n_in):
            if port_in_times[position] >= port_out_times[position]:
                port_in_times.insert(position, 'nan')
                inserted = True
        if not inserted:
            port_in_times.append('nan')

    if len(port_in_times) != len(port_out_times):
        print('Dropped event not fixed!!!!')

    return port_in_times, port_out_times

def remove_dropped_in_events(port_in_times: List, port_out_times: List, port_references: List) -> Tuple[List, List, List]:
    """
    Removes every event whose port in time is the placeholder string 'nan'.

    The entry at the same position is removed from all three lists so that they
    stay aligned. The lists are modified in place.

    Parameters:
    port_in_times (List): The list of port in times.
    port_out_times (List): The list of port out times.
    port_references (List): The list of port references.

    Returns:
    Tuple: The cleaned port in times, port out times, and port references lists,
    in the same order as the parameters.
    """
    dropped_positions = [
        position
        for position, in_time in enumerate(port_in_times)
        if isinstance(in_time, str) and in_time == 'nan'
    ]

    # Delete from the back so earlier positions are not shifted by a removal
    for position in reversed(dropped_positions):
        del port_in_times[position]
        del port_out_times[position]
        del port_references[position]

    # Return order matches the parameter order (in, out, references)
    return port_in_times, port_out_times, port_references

def sort_by_time(port_in_times: List, port_out_times: List, port_references: List) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Reorders the three event lists by ascending port in time.

    When the number of out times differs from the number of in times, a single NaN
    is appended to the out times before they are reordered.

    Parameters:
    port_in_times (List): The list of port in times.
    port_out_times (List): The list of port out times.
    port_references (List): The list of port references.

    Returns:
    Tuple: The sorted port in times, port out times, and port references as numpy arrays
    (in and out times as floats).
    """
    order = np.argsort(port_in_times)

    if len(port_out_times) == len(order):
        out_values = port_out_times
    else:
        out_values = port_out_times + [np.nan]

    return (
        np.array(port_in_times, dtype=float)[order],
        np.array(out_values, dtype=float)[order],
        np.array(port_references)[order],
    )

def extract_reward_timestamps(behavior_data: Dict) -> List[float]:
    '''
    Collects the session-aligned start time of the 'Reward' state of every trial that has one.

    Parameters:
    behavior_data (Dict): The behavioral data dictionary.

    Returns:
    List[float]: One timestamp per trial containing a 'Reward' state: the trial start
    timestamp plus the first value stored for that state. Trials without a 'Reward'
    state contribute nothing.
    '''
    session = behavior_data['SessionData']
    reward_timestamps = []

    for trial in range(session['nTrials']):
        states = session['RawEvents']['Trial'][trial]['States']
        if 'Reward' not in states:
            continue

        # Shift the within-trial offset onto the session timeline
        trial_start_timestamp = session['TrialStartTimestamp'][trial]
        reward_timestamps.append(trial_start_timestamp + states['Reward'][0])

    return reward_timestamps


def find_rewarded_event_indices(sorted_in_timestamps: List[float], 
                                sorted_port_references: List[int], 
                                reward_timestamps: List[float]) -> List[int]:
    '''
    Matches reward timestamps, in order, to port 7 poke events.

    Events are scanned in order. For each port 7 event, NaN reward timestamps are
    skipped; if the event's in time is at or after the current reward timestamp the
    event is marked as rewarded and the next reward timestamp becomes current.

    Parameters:
    sorted_in_timestamps (List[float]): List of sorted poke in timestamps.
    sorted_port_references (List[int]): List of port references corresponding to the poke in timestamps.
    reward_timestamps (List[float]): List of reward timestamps.

    Returns:
    List[int]: Indices of rewarded events in sorted_in_timestamps and sorted_port_references.
    '''
    rewarded_event_indices = []
    n_rewards = len(reward_timestamps)
    cursor = 0  # position of the next reward timestamp to match

    for event_index, port_number in enumerate(sorted_port_references):
        if port_number != 7 or cursor >= n_rewards:
            continue

        # Move past reward timestamps that are NaN
        while cursor < n_rewards and np.isnan(reward_timestamps[cursor]):
            cursor += 1

        if cursor < n_rewards and sorted_in_timestamps[event_index] >= reward_timestamps[cursor]:
            rewarded_event_indices.append(event_index)
            cursor += 1

    return rewarded_event_indices

def align_trigger_to_index(triggers: List[float], 
                           trigger_indices: List[int], 
                           all_timestamps: List[float]) -> List[Union[float, str]]:
    '''
    Places each trigger value at its index in a list as long as all_timestamps.

    Parameters:
    triggers (List[float]): List of trigger timestamps.
    trigger_indices (List[int]): List of indices corresponding to trigger timestamps.
    all_timestamps (List[float]): List of all timestamps; only its length is used.

    Returns:
    List[Union[float, str]]: List with each trigger at its index and the string
                              'NaN' at every other position.
    '''
    aligned_triggers = ['NaN' for _ in range(len(all_timestamps))]

    for trigger_value, target_position in zip(triggers, trigger_indices):
        aligned_triggers[target_position] = trigger_value

    return aligned_triggers

def extract_trial_timestamps(behavior_data):
    """
    Returns the start timestamp of every trial in the session.

    Args:
        behavior_data: The complete behavioral data dictionary.

    Returns:
        A list with the first 'nTrials' entries of 'TrialStartTimestamp'.
    """
    session = behavior_data['SessionData']
    return [session['TrialStartTimestamp'][trial] for trial in range(session['nTrials'])]


def extract_trial_end_times(behavior_data):
    """
    Extracts trial end times from behavioral data.

    The end of a trial is its start timestamp plus the last value stored for the
    'ExitSeq' state. Trials without an 'ExitSeq' state get the string 'NaN' so the
    result has exactly one entry per trial.

    NOTE: this function is defined a second time further down in this file; the
    later definition is the one in effect once the whole file has run.

    Args:
        behavior_data: The complete behavioral data dictionary.

    Returns:
        A list of trial end times, one entry per trial.
    """
    session_data = behavior_data['SessionData']

    all_end_times = []
    for trial in range(session_data['nTrials']):
        states = session_data['RawEvents']['Trial'][trial]['States']
        if 'ExitSeq' in states:
            trial_start_timestamp = session_data['TrialStartTimestamp'][trial]
            all_end_times.append(trial_start_timestamp + states['ExitSeq'][-1])
        else:
            # Keep the list index-aligned with the trials
            all_end_times.append('NaN')

    return all_end_times


def extract_trial_end_times(behavior_data):
    """
    Computes one end time per trial from the 'ExitSeq' state.

    Args:
        behavior_data: The complete behavioral data dictionary.

    Returns:
        A list with one entry per trial: the trial start timestamp plus the last
        value stored for 'ExitSeq', or the string 'NaN' when the trial has no
        'ExitSeq' state.
    """
    session = behavior_data['SessionData']

    def end_time_of(trial):
        states = session['RawEvents']['Trial'][trial]['States']
        if 'ExitSeq' not in states:
            return 'NaN'
        return session['TrialStartTimestamp'][trial] + states['ExitSeq'][-1]

    return [end_time_of(trial) for trial in range(session['nTrials'])]

def determine_trial_id(sorted_port_in_times: np.ndarray, trial_end_timestamps: List[float]) -> List[int]:
    """
    Assigns a 1-based trial id to each port event.

    Events are visited in order with a running trial counter starting at 1. While the
    counter still refers to a known trial end, an event that is not at or before that
    end moves the counter forward by one. The counter advances by at most one per
    event, and it stops changing once it is past the last known trial end.

    Args:
        sorted_port_in_times: Sorted numpy array of times when a port event starts.
        trial_end_timestamps: List of times when each trial ends.

    Returns:
        A list of trial ids, one for each port event.
    """
    n_trials = len(trial_end_timestamps)
    trial_ids = []
    trial = 1

    for event_time in sorted_port_in_times:
        within_known_trials = trial <= n_trials
        if within_known_trials and not (event_time <= trial_end_timestamps[trial - 1]):
            trial += 1
        trial_ids.append(trial)

    return trial_ids


def find_trial_start_indices(trial_ids):
    """
    Finds the positions at which the trial id changes.

    Args:
        trial_ids: List of trial ids for each port event.

    Returns:
        A list starting with 0, followed by every index whose trial id differs
        from the id just before it.
    """
    change_positions = [
        position
        for position in range(1, len(trial_ids))
        if trial_ids[position] != trial_ids[position - 1]
    ]
    return [0] + change_positions


def align_trial_start_end_timestamps(
    trial_ids: list, 
    trial_start_indices: list, 
    trial_start_timestamps: list
) -> list:
    """
    Repeats each trial's timestamp for every event that belongs to that trial.

    Args:
        trial_ids: List of trial IDs, the length of which defines the iteration count.
        trial_start_indices: List of indices where a new trial starts in the list of trial IDs.
        trial_start_timestamps: List of timestamps corresponding to each start index.

    Returns:
        A list with one timestamp per entry of trial_ids. Events of a trial that has no
        corresponding timestamp get numpy's nan.

    Raises:
        ValueError: If trial_ids is shorter than trial_start_indices or trial_start_timestamps.

    A warning is issued when the two trial-level lists differ in length by more than 5.
    """
    n_events = len(trial_ids)
    n_starts = len(trial_start_indices)
    n_timestamps = len(trial_start_timestamps)

    if n_events < max(n_starts, n_timestamps):
        raise ValueError("Length of trial_ids cannot be less than either trial_start_indices or trial_start_timestamps.")

    aligned_trial_timestamps = []
    trial_position = 0
    for event_position in range(n_events):
        # Step to the next trial when this event is that trial's first event
        upcoming = trial_position + 1
        if upcoming < n_starts and event_position == trial_start_indices[upcoming]:
            trial_position = upcoming

        if trial_position < n_timestamps:
            aligned_trial_timestamps.append(trial_start_timestamps[trial_position])
        else:
            aligned_trial_timestamps.append(np.nan)

    difference = abs(n_timestamps - n_starts)
    if difference > 5:
        warnings.warn(f"Difference between trial_start_timestamps and trial_start_indices exceeds 5: {difference}")

    return aligned_trial_timestamps

def find_trial_start_and_poke1_camera_indices(camera_trigger_states: np.ndarray) -> Tuple[List[int], List[int]]:
    """
    Find indices in the camera timestamps where the trial starts and the first poke happens.

    A transition is any index (from 1 onwards) whose trigger state differs from the
    state just before it. The 1st, 3rd, 5th, ... transitions are returned as trial
    starts and the 2nd, 4th, 6th, ... transitions as first pokes.

    Args:
        camera_trigger_states (np.ndarray): Array of trigger states from the camera.

    Returns:
        Tuple[List[int], List[int]]: Lists of indices where trial starts and the first poke happens.
        Both lists are empty when the trigger state never changes (including empty input).
    """
    states = np.asarray(camera_trigger_states)

    # Compare each sample with its predecessor; index 0 has no predecessor and is never a transition
    ttl_change_indices = list(np.flatnonzero(states[1:] != states[:-1]) + 1)

    trial_start_camera_indices = ttl_change_indices[0::2]
    poke1_camera_indices = ttl_change_indices[1::2]

    return trial_start_camera_indices, poke1_camera_indices


def generate_aligned_trial_end_camera_timestamps(trial_start_camera_indices: List[int], trial_ids: List[int], trial_start_indices: List[int], camera_timestamps: np.ndarray) -> List[Union[float, str]]:
    """
    Generate aligned timestamps for the end of trials based on camera timestamps.

    The camera index at which the following trial starts is used as the end of a
    trial, so the first start index is dropped before the camera timestamps are looked
    up and aligned with align_trial_start_end_timestamps. The events of the last
    trial, which has no following trial, are filled with the string 'NaN'.

    Args:
        trial_start_camera_indices (List[int]): List of indices where each trial starts.
        trial_ids (List[int]): List of trial ids for each port event.
        trial_start_indices (List[int]): List of start indices for each trial.
        camera_timestamps (np.ndarray): Array of camera timestamps.

    Returns:
        List[Union[float, str]]: List of aligned trial end timestamps.
    """
    # Every start index except the first marks the end of the preceding trial
    end_indices = list(trial_start_camera_indices[1:])
    aligned = align_trial_start_end_timestamps(trial_ids, trial_start_indices, camera_timestamps[end_indices])

    # Number of events that belong to the final trial
    last_trial_length = len(trial_ids) - trial_start_indices[-1]
    if len(aligned) == len(trial_ids):
        aligned = aligned[:-last_trial_length]

    return aligned + ['NaN'] * last_trial_length


def align_firstpoke_camera_timestamps(trial_ids: List[int], trial_start_indices: List[int], trial_start_timestamps: List[float], all_port_references_sorted: List[float]) -> List[Union[float, str]]:
    """
    Align the timestamps of the first poke with the camera timestamps.

    Events are scanned in order with a counter of trials already matched. An event
    on port 2 whose trial id is greater than that counter consumes the next
    timestamp; every other event gets the string 'NaN'. When the timestamps run out,
    the remaining matches also get 'NaN'.

    Args:
        trial_ids (List[int]): List of trial ids for each port event.
        trial_start_indices (List[int]): List of start indices for each trial (not used).
        trial_start_timestamps (List[float]): List of timestamps to distribute, one per matched trial.
        all_port_references_sorted (List[float]): Sorted list of all port references.

    Returns:
        List[Union[float, str]]: List of aligned first poke timestamps, one entry per event.
    """
    n_timestamps = len(trial_start_timestamps)
    trial_timestamps_aligned = []
    matched_trials = 0

    for index, trial_id in enumerate(trial_ids):
        value = 'NaN'
        if all_port_references_sorted[index] == 2.0 and trial_id > matched_trials:
            matched_trials += 1
            # Guard every overflow, not only the first one past the end
            if matched_trials <= n_timestamps:
                value = trial_start_timestamps[matched_trials - 1]
        trial_timestamps_aligned.append(value)

    return trial_timestamps_aligned


In [20]:
### handle data for optogenetics experiments

def handle_opto_stim_data(behavior_data, trial_settings, session_index, trial_ids):
    """
    Builds per-event optostim columns for one session.

    When optostim is disabled (GUI 'OptoStim' is not 1) both outputs are lists of the
    string 'NaN', one entry per trial id. Otherwise the session's 'OptoStim' variable,
    cut to the number of executed trials, and the stimulated port are aligned to the
    trial ids.

    Parameters:
    behavior_data (dict): The behavior data dictionary, indexed by session.
    trial_settings (dict): The trial settings dictionary.
    session_index (int): The current session index.
    trial_ids (list): List of trial ids.

    Returns:
    optotrials_aligned (list): The list of aligned optostim trial data.
    optotrials_port_aligned (list): The list of aligned optostim port data.
    """
    gui = trial_settings['GUI']

    if gui['OptoStim'] != 1:
        # No optostim so fill both columns with placeholders
        return ['NaN'] * len(trial_ids), ['NaN'] * len(trial_ids)

    # Opto settings as a one-row dataframe.
    # NOTE(review): this dataframe is built but neither used nor returned below.
    settings_columns = {
        name: [gui[name]]
        for name in ('StimPoke', 'PulsePower', 'OptoChance', 'PulseDuration', 'PulseInterval', 'TrainDuration')
    }
    settings_columns['TrainDelay'] = [gui['TrainDelay']] if 'TrainDelay' in gui else [None]
    opto_settings = pd.DataFrame(settings_columns)

    session_variables = behavior_data[session_index]['SessionData']['SessionVariables']

    # Keep only the trials that were actually executed (last trial id = trial count)
    executed_optotrials = session_variables['OptoStim'][:trial_ids[-1]]
    # NOTE(review): align_opto_data is not defined in the visible part of this file - confirm it exists
    optotrials_aligned = align_opto_data(trial_ids, executed_optotrials)

    if gui['StimPoke'] == 5:
        # Column index of each entry equal to 1, plus 1 to get a 1-based port number
        optotrials_port = np.where(session_variables['PortStimulated'] == 1)[1] + 1
    else:
        optotrials_port = [gui['StimPoke']] * len(trial_ids)

    optotrials_port_aligned = align_data_to_trial_ids(trial_ids, optotrials_port)
    return optotrials_aligned, optotrials_port_aligned

def align_data_to_trial_ids(trial_ids: List[int], data: List[int]) -> List[int]:
    """
    This function aligns the given per-trial data to the per-event trial ids.

    The data position starts at 0 and moves forward by one every time the trial id
    differs from the previous event's id. Each event receives the data item at the
    current position, or NaN once the position is past the end of the data.

    Args:
        trial_ids (List[int]): The list of trial ids, one per event.
        data (List[int]): The list of data to align, one item per trial.

    Returns:
        List[int]: The list of aligned data, one entry per trial id.
    """
    n_items = len(data)
    aligned_data = []
    data_position = 0

    for index, trial_id in enumerate(trial_ids):
        # A change of trial id moves on to the next data item
        if index > 0 and trial_id != trial_ids[index - 1]:
            data_position += 1

        # The same bounds check applies to every event, including the first
        if data_position < n_items:
            aligned_data.append(data[data_position])
        else:
            aligned_data.append(float('nan'))

    return aligned_data





In [21]:

def find_camera_timestamps(session_date: str, camera_directory: str, animal_id: str) -> Tuple[bool, Union[str, None]]:
    """
    Searches for timestamp files for a given animal and session date in the camera directory.

    The file is looked up in ``<camera_directory>/<animal_id>/<ddmmyy>``. Among the
    '.csv' files there, taken in sorted name order, the first one whose name ends in a
    time (the 8 characters before '.csv', underscores removed) that is earlier than the
    session time is returned. '.csv' files whose name does not end in such a time are ignored.

    Args:
        session_date (str): The date and time of the session. Characters 2-7 are read as
            'yymmdd' and characters 9-14 as 'HHMMSS', e.g. 'yyyymmdd_HHMMSS'.
        camera_directory (str): The path to the directory where camera files are stored.
        animal_id (str): The ID of the animal.

    Returns:
        Tuple[bool, Union[str, None]]: A tuple with a boolean indicating whether the timestamp file exists,
        and the path to the timestamp file, if it exists. If no timestamp file is found, the path is None.
    """
    # Format the session date in 'ddmmyy' format
    formatted_date = session_date[6:8] + session_date[4:6] + session_date[2:4]

    session_date_directory = os.path.join(camera_directory, animal_id, formatted_date)
    if not os.path.isdir(session_date_directory):
        return False, None

    session_start_time = int(session_date[9:15])

    # Sorted so that the chosen file does not depend on directory listing order
    for filename in sorted(os.listdir(session_date_directory)):
        if not filename.endswith('.csv'):
            continue

        try:
            file_time = int(filename[-12:-4].replace("_", ""))
        except ValueError:
            # Not a timestamp file name; ignore it
            continue

        # Only files created before the session start time qualify
        if file_time < session_start_time:
            return True, os.path.join(session_date_directory, filename)

    return False, None


### Timestamp preprocessing:

def load_camera_timestamps(input_path: str) -> pd.DataFrame:
    """
    Load camera timestamps from a file.

    The file is read as space-separated text without a header row. The fields are
    named 'Trigger', 'Timestamp' and 'blank', and the field at position 2 is used
    as the row index.

    Args:
        input_path (str): Path to the file containing camera timestamps.

    Returns:
        pd.DataFrame: Dataframe containing camera timestamps.
    """
    # NOTE(review): index_col=2 turns the third field into the index; confirm against a
    # real timestamp file that a 'blank' column still exists for the `del` below.
    camera_timestamps = pd.read_csv(input_path, sep=' ', header=None, names=['Trigger', 'Timestamp', 'blank'], index_col=2)
    del camera_timestamps['blank']
    return camera_timestamps

def convert_time(time: int) -> float:
    """
    Decode a packed camera timestamp into seconds.

    Bits 25-31 of the value hold a 7-bit whole-second counter and bits 12-24 hold a
    13-bit cycle counter that is divided by 8000 to give the fraction of a second.

    Args:
        time (int): The timestamp to be converted.

    Returns:
        float: The timestamp converted into seconds.
    """
    whole_seconds = (time >> 25) & 0x7F
    cycle_count = (time >> 12) & 0x1FFF
    return whole_seconds + cycle_count / 8000.

def uncycle(time: np.ndarray) -> np.ndarray:
    """
    Remove the wrap-around from a time array.

    Every time a value is smaller than the one before it, a wrap is counted, and
    128 is added to each value for every wrap that occurred up to that point.

    Args:
        time (np.ndarray): Time array to be uncycled.

    Returns:
        np.ndarray: Uncycled time array.
    """
    # The first sample has no predecessor, so it can never be a wrap
    wrapped = np.concatenate(([False], np.diff(time) < 0))
    wraps_so_far = np.cumsum(wrapped)
    return time + wraps_so_far * 128

def convert_uncycle_timestamps(camera_timestamps: pd.DataFrame) -> np.ndarray:
    """
    Convert the timestamps into seconds and uncycle them.

    The whole 'Timestamp' column is converted in one vectorised call and read by
    position, so the result does not depend on the DataFrame's index labels.

    Args:
        camera_timestamps (pd.DataFrame): DataFrame containing camera timestamps,
            with 'Trigger' and 'Timestamp' columns.

    Returns:
        np.ndarray: Uncycled timestamps in seconds, relative to the first timestamp.
        An empty array is returned for an empty DataFrame.

    Raises:
        ValueError: If any row has a 'Trigger' value that is not greater than 0.
    """
    if len(camera_timestamps) == 0:
        return np.array([], dtype=float)

    if not (camera_timestamps['Trigger'] > 0).all():
        raise ValueError('Timestamps are broken')

    raw_timestamps = camera_timestamps['Timestamp'].to_numpy()
    uncycled_timestamps = uncycle(convert_time(raw_timestamps))

    # make first timestamp 0 and the others relative to this
    return uncycled_timestamps - uncycled_timestamps[0]

def check_timestamps(timestamps: np.ndarray, frame_rate: int) -> None:
    """
    Report and plot deviations from the expected frame rate.

    The instantaneous rate between consecutive timestamps (1 / difference) is
    computed. Rates more than 5 below or above frame_rate are counted and printed as
    dropped frames, and a histogram of all rates is drawn.

    Args:
        timestamps (np.ndarray): Array of timestamps.
        frame_rate (int): Frame rate in frames per second.
    """
    instantaneous_rates = 1 / np.diff(timestamps)

    too_slow = instantaneous_rates < frame_rate - 5
    too_fast = instantaneous_rates > frame_rate + 5
    frames_dropped = np.sum(too_slow | too_fast)
    print(f'Frames dropped = {frames_dropped}')

    # NOTE(review): `plt` is not imported in the visible part of this file - confirm that
    # matplotlib.pyplot is imported as plt before this function is called.
    plt.suptitle(f'Frame rate = {frame_rate}fps', color = 'red')
    plt.hist(instantaneous_rates, bins=100)
    plt.xlabel('Frequency')
    plt.ylabel('Number of frames')

def find_trigger_states(camera_timestamps_raw: pd.DataFrame) -> np.ndarray:
    """
    Determine the trigger states from the raw camera timestamps.

    The 'Trigger' value of the first row (by position, independent of the index
    labels) is taken as the down state. Every row with that value maps to 0 and
    every other row maps to 1.

    Args:
        camera_timestamps_raw (pd.DataFrame): DataFrame containing raw camera timestamps.

    Returns:
        np.ndarray: Float array of trigger states (0 or 1), one per row. Empty when
        the input has no rows.
    """
    trigger_values = np.asarray(camera_timestamps_raw['Trigger'])
    if trigger_values.size == 0:
        return np.ones(0)

    # Positional access: the DataFrame index need not start at 0 or be unique
    down_state = trigger_values[0]
    return np.where(trigger_values == down_state, 0.0, 1.0)


In [22]:
## TODO: check that all functions work with a single session, then decide how to combine them for multiple sessions


In [45]:
# Run configuration. The names match the parameters of process_animal_data defined
# below - presumably they are passed to it; verify against the call site.

# IDs of the animals to process
animal_ids = ["EJT244","SP110", "SP111"]
# Folder with the raw behavioral data of each animal
input_directory = '/home/sthitapati/Documents/sequence_data/bpod_raw_data/'
# Folder where processed data is saved
output_directory = '/home/sthitapati/Documents/sequence_data/output/'
# Folder with the camera timestamp files of each animal
camera_directory =  '/home/sthitapati/Documents/sequence_data/FlyCap_SeqTracking/'
# Whether existing processed data should be replaced
replace_existing = False 

In [24]:
def process_animal_data(
    animal_ids: List[str],
    input_directory: str,
    output_directory: str,
    camera_directory: Optional[str] = None,
    replace_existing: bool = False
) -> None:
    """
    Function to process data for each animal and each session.

    For every animal, all Bpod session files are loaded with
    `import_bpod_data_files`. For every session an output folder
    `<output_directory>/<animal>/Preprocessed/<NN>_<date>_<weekday>` is
    created if missing. Sessions whose folder already exists are skipped
    unless `replace_existing` is True. For each processed session the poke,
    reward, trial, optogenetic and training-level data are extracted and
    aligned to poke events.

    NOTE: this function is work in progress. The aligned data are computed
    but this function itself does not yet write anything into the session
    folder, and `camera_directory` is not used yet.

    Args:
        animal_ids (List[str]): List of animal IDs.
        input_directory (str): Directory containing raw behavioral data for each animal.
        output_directory (str): Directory where processed data will be saved.
        camera_directory (Optional[str]): Directory containing the camera timestamp files for each animal, if available.
        replace_existing (bool): If True, existing processed data will be replaced. Defaults to False.

    Returns:
        None
    """
    for current_animal_id in animal_ids:
        print('Processing data for: ' + current_animal_id)

        # The trailing slash is required: import_bpod_data_files joins the
        # folder and the file name by plain string concatenation.
        current_input_path = os.path.join(input_directory, current_animal_id, 'Sequence_Automated', 'Session Data/')

        # Load behavioural data; the returned file listing is not needed here.
        behavior_data, total_sessions, _, session_dates = import_bpod_data_files(current_input_path)

        for session_index in range(total_sessions):
            session = behavior_data[session_index]
            session_variables = session['SessionData']['SessionVariables']

            # Unique identifier: date/time from the file name plus three
            # characters sliced from the .mat header (the recorded outputs
            # below show weekday abbreviations such as 'Fri').
            session_date = session_dates[session_index] + '_' + str(session['__header__'])[-25:-22]

            # Zero-pad the session number to at least two digits.
            save_path = os.path.join(
                output_directory, current_animal_id, 'Preprocessed', f'{session_index:02d}_{session_date}'
            )

            if os.path.isdir(save_path):
                # Already processed: only redo the work when asked to.
                if not replace_existing:
                    continue
            else:
                os.makedirs(save_path)

            # TLevel holds one training level per trial; levels are 1-based
            # row indices into the TrainingLevels table.
            trial_training_levels = session_variables['TLevel']
            training_level_table = session_variables['TrainingLevels']

            # Final reward amount for each trial (column 4 of the level row),
            # taken from the current session's own table.
            final_reward_amounts = [
                training_level_table[level - 1][4] for level in trial_training_levels
            ]

            # fetch trial_settings
            trial_settings = session['SessionData']['TrialSettings'][0]

            # LED intensities per port.
            # NOTE(review): ports 2-5 here but ports 1-4 for rewards below - confirm this is intended.
            led_intensities = pd.DataFrame({
                'Port2': session_variables['LEDIntensitys']['port2'],
                'Port3': session_variables['LEDIntensitys']['port3'],
                'Port4': session_variables['LEDIntensitys']['port4'],
                'Port5': session_variables['LEDIntensitys']['port5']
            })

            # Reward amounts per port.
            reward_amounts = pd.DataFrame({
                'Port1': session_variables['RewardAmount']['port1'],
                'Port2': session_variables['RewardAmount']['port2'],
                'Port3': session_variables['RewardAmount']['port3'],
                'Port4': session_variables['RewardAmount']['port4']
            })

            # TODO: Save out reward amounts?
            # TODO: Save out LED intensities

            # Extract PortIn times for each port and check for errors
            port_in_times, port_out_times, port_references = extract_poke_times(session)

            # Remove 'nan' values (these represent times when part of the event was dropped by Bpod for some reason)
            fixed_port_in_times, fixed_port_out_times, fixed_port_references = remove_dropped_in_events(
                port_in_times, port_out_times, port_references
            )

            # Resort these times for consistent chronology
            sorted_port_in_times, sorted_port_out_times, sorted_port_references = sort_by_time(
                fixed_port_in_times, fixed_port_out_times, fixed_port_references
            )

            # Extract reward timestamps
            reward_timestamps = extract_reward_timestamps(session)

            # Find indices corresponding to rewarded events and align them to poke events.
            # This uses the reward timestamps before their NaN entries are removed.
            rewarded_event_indices = find_rewarded_event_indices(
                sorted_port_in_times, sorted_port_references, reward_timestamps
            )

            # Remove 'NaN' entries from reward timestamps
            reward_timestamps = np.asarray(reward_timestamps)
            reward_timestamps = list(reward_timestamps[np.logical_not(np.isnan(reward_timestamps))])

            # Align reward timestamps to the corresponding poke events
            aligned_reward_timestamps = align_trigger_to_index(
                reward_timestamps, rewarded_event_indices, sorted_port_references
            )

            # Trial start / end timestamps
            trial_start_timestamps = extract_trial_timestamps(session)
            trial_end_timestamps = extract_trial_end_times(session)

            # Assign a trial ID to every poke and find where each trial starts
            trial_ids = determine_trial_id(sorted_port_in_times, trial_end_timestamps)
            trial_start_indices = find_trial_start_indices(trial_ids)

            # Align trial start / end timestamps to poke events
            aligned_trial_start_timestamps = align_trial_start_end_timestamps(
                trial_ids, trial_start_indices, trial_start_timestamps
            )
            aligned_trial_end_timestamps = align_trial_start_end_timestamps(
                trial_ids, trial_start_indices, trial_end_timestamps
            )

            # handle optogenetic stimulation
            optotrials_aligned, optotrials_port_aligned = handle_opto_stim_data(
                behavior_data, trial_settings, session_index, trial_ids
            )

            # Per-trial intermediate rewards (columns 0-3) and LED intensities
            # (columns 6-9) of the training level used on that trial.
            intermediate_rewards_data = [
                list(training_level_table[level - 1][0:4]) for level in trial_training_levels
            ]
            led_intensities_data = [
                list(training_level_table[level - 1][6:10]) for level in trial_training_levels
            ]

            # Align intermediate rewards and LED intensities data with trial start indices
            aligned_led_intensities = align_trial_start_end_timestamps(trial_ids, trial_start_indices, led_intensities_data)
            aligned_intermediate_rewards = align_trial_start_end_timestamps(trial_ids, trial_start_indices, intermediate_rewards_data)

            # Align training level for each trial
            training_levels = align_data_to_trial_ids(trial_ids, trial_training_levels)

           
            

In [25]:
# Run the pipeline for every animal in `animal_ids` with the directories set in the configuration cell.
# NOTE(review): replace_existing is passed as the literal False, not the `replace_existing` variable.
process_animal_data(animal_ids=animal_ids, input_directory=input_directory, output_directory=output_directory, camera_directory=camera_directory, replace_existing=False)

Processing data for: EJT244
Processing data for: SP110
Processing data for: SP111


In [46]:
# Construct the path for the current animal's data
# (step-by-step walkthrough of the pipeline for one animal: the third entry of animal_ids)
current_animal_id = animal_ids[2]
# 'Session Data/' keeps its trailing slash because import_bpod_data_files
# builds file paths by concatenating input_path + file name.
current_input_path = os.path.join(input_directory, current_animal_id, 'Sequence_Automated', 'Session Data/')

print(current_input_path)

/home/sthitapati/Documents/sequence_data/bpod_raw_data/SP111/Sequence_Automated/Session Data/


In [47]:
# Load Behavioural data using the import_bpod_data_files function
# Returns: dict of sessions keyed by integer index, number of sessions loaded,
# the sorted listing of every file in the folder, and one date string per session.
behavior_data, total_sessions, path, session_dates = import_bpod_data_files(current_input_path)


In [48]:
# Key into behavior_data selecting the single session inspected by the cells below.
session_index = 3

In [49]:
# Calculate final reward amount for the session
# One entry per item in TLevel: the training level is used as a 1-based row
# index into TrainingLevels, and column 4 of that row is collected.
final_reward_amounts = []
for item in behavior_data[session_index]['SessionData']['SessionVariables']['TLevel']:
    training_level = item
    final_reward_amounts.append(behavior_data[session_index]['SessionData']['SessionVariables']['TrainingLevels'][training_level-1][4])

# final_reward_amounts

In [50]:
# fetch trial_settings 
# Takes element 0 of TrialSettings - presumably the settings of the first trial; TODO confirm.
trial_settings = behavior_data[session_index]['SessionData']['TrialSettings'][0]
# trial_settings


In [51]:
# Save out LED intensities and reward amounts on their own:
# NOTE(review): LED intensities are read for ports 2-5 while reward amounts
# below are read for ports 1-4 - confirm this asymmetry is intended.
led_intensities = pd.DataFrame({
    'Port2': behavior_data[session_index]['SessionData']['SessionVariables']['LEDIntensitys']['port2'],
    'Port3': behavior_data[session_index]['SessionData']['SessionVariables']['LEDIntensitys']['port3'],
    'Port4': behavior_data[session_index]['SessionData']['SessionVariables']['LEDIntensitys']['port4'],
    'Port5': behavior_data[session_index]['SessionData']['SessionVariables']['LEDIntensitys']['port5']
})
# print(led_intensities)
# Create a DataFrame for reward amounts for each port:
reward_amounts = pd.DataFrame({
    'Port1': behavior_data[session_index]['SessionData']['SessionVariables']['RewardAmount']['port1'],
    'Port2': behavior_data[session_index]['SessionData']['SessionVariables']['RewardAmount']['port2'],
    'Port3': behavior_data[session_index]['SessionData']['SessionVariables']['RewardAmount']['port3'],
    'Port4': behavior_data[session_index]['SessionData']['SessionVariables']['RewardAmount']['port4']
})
# print(reward_amounts)

In [52]:
# Extract PortIn times for each port and check for errors
port_in_times, port_out_times, port_references = extract_poke_times(behavior_data[session_index])

# Sanity check: print the length of each returned sequence
# (presumably parallel, one entry per poke event - TODO confirm).
print(len(port_in_times))
print(len(port_out_times))
print(len(port_references))


2870
2870
2870


In [53]:
# Remove 'nan' values (these represent times when part of the event was dropped by Bpod for some reason)
fixed_port_in_times, fixed_port_out_times, fixed_port_references = remove_dropped_in_events(port_in_times, port_out_times, port_references)

# Sanity check: lengths after removal (compare with the previous cell to see how many events were dropped).
print(len(fixed_port_in_times))
print(len(fixed_port_out_times))
print(len(fixed_port_references))

2870
2870
2870


In [54]:
# Resort these times for consistent chronology
sorted_port_in_times, sorted_port_out_times, sorted_port_references = sort_by_time(fixed_port_in_times, fixed_port_out_times, fixed_port_references)

# Sanity check: sorting should not change the number of events.
print(len(sorted_port_in_times))
print(len(sorted_port_out_times))
print(len(sorted_port_references))


2870
2870
2870


In [55]:
# Extract reward timestamps:
reward_timestamps = extract_reward_timestamps(behavior_data[session_index])

# Find indices corresponding to rewarded events and align them to poke events:
# (called with the reward timestamps before their NaN entries are stripped below)
rewarded_event_indices = find_rewarded_event_indices(sorted_port_in_times, sorted_port_references, reward_timestamps)

# Remove 'NaN' entries from reward timestamps:
# convert to an array, keep only the non-NaN entries, then go back to a plain list
reward_timestamps = np.asarray(reward_timestamps)
reward_timestamps = reward_timestamps[np.logical_not(np.isnan(reward_timestamps))]
reward_timestamps = list(reward_timestamps)

# Align reward timestamps to the corresponding poke events:
aligned_reward_timestamps = align_trigger_to_index(reward_timestamps, rewarded_event_indices, sorted_port_references)

# print(len(aligned_reward_timestamps))
# print(len(reward_timestamps))

In [56]:
# Align reward timestamps to the corresponding poke events:
# NOTE(review): this repeats the call at the end of the previous cell with the same arguments.
aligned_reward_timestamps = align_trigger_to_index(reward_timestamps, rewarded_event_indices, sorted_port_references)

# Sanity check: print the length of the aligned result.
print(len(aligned_reward_timestamps))


2870


In [57]:

# Extract trial start timestamps:
trial_start_timestamps = extract_trial_timestamps(behavior_data[session_index])

# Sanity check: number of trial start timestamps (presumably one per trial - TODO confirm).
print(len(trial_start_timestamps))


260


In [58]:

# Extract trial end times:
trial_end_timestamps = extract_trial_end_times(behavior_data[session_index])
# Sanity check: number of trial end timestamps, to compare with the number of trial starts.
print(len(trial_end_timestamps))


260


In [59]:
# Inspect length and container type of the two inputs passed to determine_trial_id in the next cell.
print(len(sorted_port_in_times))
print(type(sorted_port_in_times))
print(len(trial_end_timestamps))
print(type(trial_end_timestamps))

2870
<class 'numpy.ndarray'>
260
<class 'list'>


In [60]:

# Determine trial IDs:
# (computed from the sorted poke-in times and the trial end timestamps)
trial_ids = determine_trial_id(sorted_port_in_times, trial_end_timestamps)
# Sanity check: number of trial IDs, to compare with the number of poke events.
print(len(trial_ids))


2870


In [61]:

# Find trial start indices:
trial_start_indices = find_trial_start_indices(trial_ids)
# Sanity check: number of start indices, to compare with the number of trials.
print(len(trial_start_indices))


260


In [62]:

# Align trial start timestamps to poke events:
aligned_trial_start_timestamps = align_trial_start_end_timestamps(trial_ids, trial_start_indices, trial_start_timestamps)
# Sanity check: length of the aligned result, to compare with the number of poke events.
print(len(aligned_trial_start_timestamps))

# Align trial end timestamps to poke events:
# (same helper as above, fed with the trial end timestamps)
aligned_trial_end_timestamps = align_trial_start_end_timestamps(trial_ids, trial_start_indices, trial_end_timestamps)
print(len(aligned_trial_end_timestamps))

2870
2870


In [63]:
# handle optogenetic stimulation
# Note: the helper receives the whole behavior_data dict plus session_index, not a single session.
optotrials_aligned, optotrials_port_aligned = handle_opto_stim_data(behavior_data, trial_settings, session_index, trial_ids)

# Sanity check: lengths of both returned sequences.
print(len(optotrials_aligned))
print(len(optotrials_port_aligned))

2870
2870


In [64]:
# Create empty lists to store intermediate rewards and LED intensities data for each trial
intermediate_rewards_data = []
led_intensities_data = []

# Iterate over 'TLevel' items in SessionVariables
# Each item is used as a 1-based row index into the TrainingLevels table.
for tlevel_item in behavior_data[session_index]['SessionData']['SessionVariables']['TLevel']:
    tlevel = tlevel_item

    # Append intermediate rewards and LED intensities data for the current trial
    # (columns 0-3 of the level row are taken as intermediate rewards, columns 6-9 as LED intensities)
    intermediate_rewards_data.append(
        list(behavior_data[session_index]['SessionData']['SessionVariables']['TrainingLevels'][tlevel-1][0:4])
    )
    led_intensities_data.append(
        list(behavior_data[session_index]['SessionData']['SessionVariables']['TrainingLevels'][tlevel-1][6:10])
    )

# Align intermediate rewards and LED intensities data with trial start indices
aligned_led_intensities = align_trial_start_end_timestamps(trial_ids, trial_start_indices, led_intensities_data)
aligned_intermediate_rewards = align_trial_start_end_timestamps(trial_ids, trial_start_indices, intermediate_rewards_data)

In [65]:
# Sanity check: lengths of the aligned LED intensity and intermediate reward data.
print(len(aligned_led_intensities))
print(len(aligned_intermediate_rewards))

2870
2870


In [66]:
# Align training level for each trial
# (the TLevel values are spread over the trial IDs via align_data_to_trial_ids)
training_levels = align_data_to_trial_ids(trial_ids, behavior_data[session_index]['SessionData']['SessionVariables']['TLevel'])
# Sanity check: length of the aligned training levels.
print(len(training_levels))

2870


In [67]:
# Inspect the type of a session date entry (import_bpod_data_files slices it out of the file name).
type(session_dates[0])

str

In [72]:
# For every loaded session, print its identifier and whether a camera
# timestamp file could be located for it.
for session_id, session_start in enumerate(session_dates):
    # Identifier = date string from the file name + three characters sliced from the .mat header.
    header_text = str(behavior_data[session_id]['__header__'])
    filedate = session_start + '_' + header_text[-25:-22]
    print(f"date: {filedate}")

    timestamps_exist, timestamp_file_path = find_camera_timestamps(
        session_date=session_start,
        camera_directory=camera_directory,
        animal_id=current_animal_id,
    )

    print(timestamps_exist)
    print(timestamp_file_path)


date: 20230421_115225_Fri
False
None
date: 20230422_152143_Sat
False
None
date: 20230423_192319_Sun
False
None
date: 20230424_135325_Mon
False
None
date: 20230425_134730_Tue
False
None
date: 20230426_113747_Wed
False
None
date: 20230427_160502_Thu
False
None
date: 20230428_120644_Fri
False
None
date: 20230502_121059_Tue
False
None
date: 20230503_121831_Wed
False
None
date: 20230504_133913_Thu
False
None
date: 20230509_115531_Tue
False
None
date: 20230510_130339_Wed
False
None
date: 20230511_142715_Thu
False
None
date: 20230512_125459_Fri
False
None
date: 20230516_155130_Tue
False
None
date: 20230517_104733_Wed
False
None
date: 20230518_171602_Thu
False
None
date: 20230519_154240_Fri
False
None
date: 20230523_181113_Tue
False
None
date: 20230524_140934_Wed
False
None
