# Analyze sessions in batch from Phase 1 of AdaDrive (work in progress)

In [90]:
import sys
 
# setting path
sys.path.append('..')

from mna.utils.rnapp_data_format import read_all_lslpresets, return_metadata_from_name, event_data_from_data
import pickle, os
import pandas as pd
import numpy as np
import seaborn as sns
from IPython.display import display

import matplotlib.pyplot as plt
from mna.sessions.eye_session import process_session_eye
from mna.sessions.eeg_session import process_session_eeg
from mna.sessions.motor_session import process_session_motor
from mna.sessions.ecg_session import process_session_ecg
from os import listdir
from os.path import isfile, join
from mna.utils.rnapp_data_format import read_all_lslpresets, return_metadata_from_name, event_data_from_data
import pickle
from statannotations.Annotator import Annotator
from collections import defaultdict
from scipy import stats
import mne
import glob 
import random
import math
# 1. Read an RN App-converted pkl file and create the metadata and data structure

In [18]:
import matplotlib
matplotlib.use('Agg')

# Prep

In [19]:
from mna.utils.rnapp_data_format import read_all_files

In [20]:
# Input / output locations, relative to the notebook's working directory.
data_dir = "../data/"
lsl_dir = "../mna/LSLPresets/"
output_dir = '../output/batch_analysis/'
# exist_ok=True creates the folder when missing without the check-then-create
# race of a separate isdir() test.
os.makedirs(output_dir, exist_ok=True)
metadata_jsons = read_all_lslpresets(path_to_jsonfiles=lsl_dir)
# Keep only regular files whose name ends with the .pkl extension; a plain
# substring test would also pick up names such as "session.pkl.bak".
onlyfiles = [f for f in listdir(data_dir) if f.endswith('.pkl') and isfile(join(data_dir, f))]

interrupted_sessions = [(13,1), (22,1)]  # presumably (subject id, session) pairs -- TODO confirm
reference_ica = "sbj20ssn03"  # presumably the session used as ICA reference -- TODO confirm
save_data_pkl = True # save data into pickle files
save_ica_plts = False # save ICA components plots
epoch_raw_eeg = False # epoching raw data
motor_events = True
preturn = 1000  # NOTE(review): unit and meaning are not visible in this section
rs = 64 # random seed

# Motor-based break detection (11/16)

### Export re-referenced raw event data

In [22]:
def return_metadata_from_name(stream_name, metadata_jsons):
    """Return the first preset whose 'StreamName' matches *stream_name*.

    A preset matches when its 'StreamName' equals *stream_name* either
    verbatim or after replacing every underscore in *stream_name* with a dot.
    Returns None when no preset matches.
    """
    def _is_match(preset):
        preset_name = preset['StreamName']
        return preset_name == stream_name or preset_name == stream_name.replace("_", ".")

    # filter() is lazy, so scanning stops at the first matching preset.
    return next(filter(_is_match, metadata_jsons), None)
def add_trial_start_time(event_df, offset=0.01):
    """Insert a leading 'trial_start_time' column into *event_df* in place.

    The first trial starts at 0; every later trial starts *offset* seconds
    after the previous trial's 'trial_end_time', because the next trial
    begins immediately after the previous one ends. Returns None.
    """
    n_trials = len(event_df)
    start_times = np.zeros(n_trials)
    # Shift the end times down by one row and add the small offset.
    previous_ends = event_df["trial_end_time"].to_numpy()[:-1]
    start_times[1:] = previous_ends + offset
    event_df.insert(0, "trial_start_time", start_times)
# (ppid, session) pairs whose trial log needs the interruption fix below.
interrupted_id_sessions = [(13,1), (22,1)]
# Derived once instead of on every loop iteration; kept as module-level names
# in case later cells still read them.
interrupted_ids = [p[0] for p in interrupted_id_sessions]
interrupted_sessions = [p[1] for p in interrupted_id_sessions]
all_dfs = []
#for each_file in onlyfiles:
for each_file in ['09_09_2022_14_24_33-Exp_adadrive-Sbj_17-Ssn_01.dats.pkl']:
    input_path = data_dir + each_file

    # Subject / session numbers sit between fixed markers in the file name.
    sbj_id = each_file[each_file.find('Sbj_')+4:each_file.find('-Ssn')]
    ssn_no = each_file[each_file.find('Ssn_')+4:each_file.find('.dats')]

    # Left-pad single-character ids with a zero, e.g. "sbj07" / "ssn01".
    sbj = ("sbj0" if len(sbj_id) < 2 else "sbj") + sbj_id
    ssn = ("ssn0" if len(ssn_no) < 2 else "ssn") + ssn_no

    # NOTE: unpickling can execute arbitrary code -- only load trusted recordings.
    with open(input_path, 'rb') as handle:
        rns_data = pickle.load(handle)

    # Append the matching LSL preset (or None) to every stream's list.
    for key in rns_data.keys():
        rns_data[key].append(return_metadata_from_name(key, metadata_jsons))

    ## Add metadata to data
    # Element [0] supplies the values, [1] the column labels and [2] the preset
    # with the channel names; after the transpose each row is one event.
    event_df = pd.DataFrame(rns_data['Unity_TrialInfo'][0], columns=rns_data['Unity_TrialInfo'][1],
                  index=rns_data['Unity_TrialInfo'][2]['ChannelNames']).T
    event_df = event_df.reset_index().rename(columns={'index': 'trial_end_time'})

    # re-reference: express times relative to the first motor-input timestamp
    session_start_time = rns_data['Unity_MotorInput'][1][0]
    event_df.trial_end_time -= session_start_time

    # chunk data is always paired but offset
    if 'Unity_ChunkInfo' in rns_data:
        chunk_df = pd.DataFrame(rns_data['Unity_ChunkInfo'][0], columns=rns_data['Unity_ChunkInfo'][1],
                              index=rns_data['Unity_ChunkInfo'][2]['ChannelNames']).T
        chunk_df = chunk_df.reset_index().rename(columns={'index': 'chunk_timestamp'})
        # Rows are paired positionally (column-wise concat on the default index).
        event_df = pd.concat([event_df,chunk_df],axis=1)
        #chunk_df['chunk_timestamp'] = chunk_df.index
        #event_df = pd.merge_asof(event_df, chunk_df,
        #                         left_on="trial_start_time",right_index=True,
        #                         direction='nearest', tolerance=1)
    else:
        print("Unity_ChunkInfo not found")

    # Match the (ppid, session) *pair*: testing ppid and session against two
    # independent lists would also accept combinations that were never listed.
    first_trial = event_df.iloc[0]
    if (first_trial.ppid, first_trial.session) in interrupted_id_sessions:
        print('FIXING THE EVENT DF SINCE PID', event_df.iloc[0].ppid, 'SESSION', event_df.iloc[0].session, 'WAS INTERRUPTED')
        # Keep only the first occurrence of every trial key.
        event_df = event_df.loc[~event_df.duplicated(subset=['ppid','session','block','number_in_block','trial'], keep='first'),:].reset_index(drop=True)
        # Rows with ppid == 0 are dropped; the rows from the remembered position
        # onward get the preceding row's session and have its block / trial /
        # damage values added so the counters continue.
        # NOTE(review): last_freak_idx is taken before the ppid == 0 rows are
        # removed and the index is rebuilt, so it lines up with the first
        # following row only when a single ppid == 0 row precedes it -- confirm.
        last_freak_idx = event_df.loc[event_df.ppid == 0].index[-1]
        event_df = event_df[event_df.ppid != 0].reset_index(drop=True)
        event_df.loc[last_freak_idx:,'session'] = event_df.loc[last_freak_idx-1,'session']
        event_df.loc[last_freak_idx:,'block'] = event_df.loc[last_freak_idx:,'block'] + event_df.loc[last_freak_idx-1,'block']
        event_df.loc[last_freak_idx:,'trial'] = event_df.loc[last_freak_idx:,'trial'] + event_df.loc[last_freak_idx-1,'trial']
        event_df.loc[last_freak_idx:,'damage'] = event_df.loc[last_freak_idx:,'damage'] + event_df.loc[last_freak_idx-1,'damage']

    add_trial_start_time(event_df)
    all_dfs.append(event_df)

#pd.concat(all_dfs).to_csv('all_raw_trial_events.csv')

### Find breaks using motor

In [8]:
def find_breaks(nums, timestamps, break_time=5, inactive_value=0):
    """Find long inactive periods in a sampled signal.

    A break is reported at the moment the signal leaves ``inactive_value``:
    it spans from the most recent sample that was not ``inactive_value`` (or
    the first sample when there has been none yet) to the last inactive
    sample, and is kept only when that span is longer than ``break_time``.
    Inactivity that lasts until the end of the recording is not reported,
    because the signal never resumes.

    Args:
        nums: Sequence of signal values, one per sample.
        timestamps: Sequence of sample times, parallel to ``nums``.
        break_time: Minimum duration (in timestamp units) for a break.
        inactive_value: Value of ``nums`` that counts as inactive.

    Returns:
        Three parallel lists ``(start_intervals, end_intervals,
        inactive_duration)``; all empty when ``nums`` is empty.
    """
    start_intervals = []
    end_intervals = []
    inactive_duration = []
    if len(nums) == 0:
        # Nothing sampled: avoid indexing timestamps[0] on empty input.
        return start_intervals, end_intervals, inactive_duration

    last_active_time = timestamps[0]
    for i in range(1, len(nums)):
        if nums[i] != inactive_value:
            # The signal is active at sample i; if it was inactive just before,
            # the inactive stretch ended at the previous sample.
            idle_end = timestamps[i - 1]
            idle_time = idle_end - last_active_time
            if nums[i - 1] == inactive_value and idle_time > break_time:
                start_intervals.append(last_active_time)
                end_intervals.append(idle_end)
                inactive_duration.append(idle_time)
            last_active_time = timestamps[i]
    return start_intervals, end_intervals, inactive_duration
# Build one table of detected throttle breaks per recording in onlyfiles.
all_dfs = []
for each_file in onlyfiles:
    input_path = data_dir + each_file
    print('input_path',input_path)

    # Subject / session numbers sit between fixed markers in the file name.
    sbj_id = each_file[each_file.find('Sbj_') + 4:each_file.find('-Ssn')]
    ssn_no = each_file[each_file.find('Ssn_') + 4:each_file.find('.dats')]

    # Left-pad single-character ids with a zero, e.g. "sbj07" / "ssn01".
    sbj = ("sbj0" if len(sbj_id) < 2 else "sbj") + sbj_id
    ssn = ("ssn0" if len(ssn_no) < 2 else "ssn") + ssn_no

    with open(input_path, 'rb') as handle:
        rns_data = pickle.load(handle)

    ## Add metadata to data
    # Append the matching LSL preset (or None) to every stream's list.
    for key in rns_data:
        rns_data[key].append(return_metadata_from_name(key, metadata_jsons))

    # Element [0] supplies the values, [1] the column labels and [2] the preset
    # with the channel names; after the transpose each row is one sample.
    motor_stream = rns_data['Unity_MotorInput']
    motor_df = pd.DataFrame(
        motor_stream[0],
        columns=motor_stream[1],
        index=motor_stream[2]['ChannelNames'],
    ).T
    motor_df = motor_df.reset_index().rename(columns={'index': 'timestamp'})

    # Re-reference time so that the first motor sample sits at zero.
    session_start_time = motor_df.timestamp.iloc[0]
    motor_df.timestamp -= session_start_time

    # A break is more than 5 time units of zero throttle (see find_breaks).
    throttle_values = motor_df['throttle_input'].tolist()
    sample_times = motor_df.timestamp.tolist()
    start_intervals, end_intervals, inactive_duration = find_breaks(
        throttle_values, sample_times, break_time=5, inactive_value=0)

    sub_df = pd.DataFrame({
        'sbj_ssn': sbj+ssn,
        'start_intervals': start_intervals,
        'end_intervals': end_intervals,
        'inactive_duration': inactive_duration,
    })
    all_dfs.append(sub_df)
# pd.concat(all_dfs).to_csv('all_motor_break_events.csv')

input_path ../data/09_09_2022_14_24_33-Exp_adadrive-Sbj_17-Ssn_01.dats.pkl


### Post-annotation, clean up the motor data and create simulated input for Unity

In [41]:
# Load the annotated per-trial event table; the simulation cell below passes it
# to process_sub_df, which counts its rows per ppid / session / block.
trial_df = pd.read_csv(f"{data_dir}/annotated/all_raw_trial_events.csv")

In [124]:
# Annotated motor-break table; the columns read below are sbj_ssn,
# start_intervals, end_intervals, regular_calibration_intervals and
# calibration_trials.
df = pd.read_excel(f"{data_dir}/annotated/all_motor_break_events.xlsx")


def _simulate_trials(calib_df, trial_count, type_trial, force_last_interval=False):
    """Spread ``trial_count`` simulated trials over the gaps between breaks.

    ``calib_df`` holds consecutive breaks (columns ``start_intervals``,
    ``end_intervals``, ``sbj_ssn``) and must be indexed 0..n-1. The driving
    stretch before each break (from the previous break's end to this break's
    start) receives as many trials as fit at the current average trial
    duration, i.e. remaining driving time / remaining trial count. Individual
    durations are jittered by +/- 0.5 around the average, and the average is
    re-derived after every trial so the remaining trials fill what is left of
    the stretch.

    Args:
        calib_df: Breaks delimiting the driving stretches.
        trial_count: Number of trials to distribute.
        type_trial: Label stored in every event's 'type_trial' field.
        force_last_interval: When True, the stretch before the last break
            receives all trials still outstanding, regardless of its length.

    Returns:
        List of event dicts (keys 'sbj_ssn', 'start_interval',
        'type_interval', 'type_trial') in chronological order. A stretch's
        final event is a 'brake', the others are 'drive'.
    """
    events = []
    # Total driving time = sum of the gaps between one break's end and the next
    # break's start (the trailing NaN from shift(-1) is ignored by nansum).
    total_driving_time = np.nansum(calib_df.start_intervals.shift(-1)-calib_df.end_intervals)
    for index, interval in calib_df.iloc[1:].iterrows():
        average_trial_duration = total_driving_time/trial_count
        start_drive = calib_df.iloc[index-1].end_intervals
        end_drive = interval.start_intervals
        if force_last_interval and index == calib_df.index[-1]:
            # last chance to place trials: make sure every outstanding one is emitted
            num_complete_trials = trial_count
        else:
            num_complete_trials = math.floor((end_drive-start_drive)/average_trial_duration)
        if num_complete_trials == 0:
            d_time = end_drive-start_drive
            total_driving_time -= d_time # we won't be able to complete a trial but make a little progress
            events.append({'sbj_ssn': interval.sbj_ssn, 'start_interval': start_drive+d_time, 'type_interval':'brake', 'type_trial':type_trial})
        else:
            events.append({'sbj_ssn': interval.sbj_ssn, 'start_interval': start_drive, 'type_interval':'drive', 'type_trial':type_trial})
        interval_done = start_drive
        for trial in range(num_complete_trials):
            trial_duration = random.uniform(average_trial_duration-.5,average_trial_duration+.5)
            total_driving_time -= trial_duration
            interval_done = interval_done+trial_duration
            # The last trial of a stretch ends in a brake, the others keep driving.
            type_interval = 'brake' if trial == num_complete_trials-1 else 'drive'
            events.append({'sbj_ssn': interval.sbj_ssn, 'start_interval': interval_done, 'type_interval':type_interval, 'type_trial':type_trial})
            trials_left_in_stretch = num_complete_trials-trial-1
            if trials_left_in_stretch > 0:
                # Redistribute the rest of the stretch over the remaining trials;
                # skipped after the last trial, where it would divide by zero.
                left_over = end_drive-interval_done
                average_trial_duration = left_over/trials_left_in_stretch
            trial_count -= 1
    return events


def process_sub_df(calib_idx, trial_df, sub, ses, between_calib=False, calib_block_to_end=1):
    """Simulate 'voice' trials for a slice of the current subject's breaks.

    Reads the module-level ``sub_df`` (the break table of the subject being
    processed in the loop below).

    Args:
        calib_idx: Row of ``sub_df`` where the slice starts, i.e. where the
            calibration falls among the motor breaks.
        trial_df: Per-trial event table with ppid, session and block columns.
        sub: Participant id matched against ``trial_df.ppid``.
        ses: Session number matched against ``trial_df.session``.
        between_calib: Row of ``sub_df`` where the slice ends (inclusive). When
            truthy, trials of blocks 2..11 are distributed. Because the value
            doubles as a flag, an end row of 0 is treated as "not given".
        calib_block_to_end: Used when ``between_calib`` is falsy: trials of
            all blocks after this one are distributed over the breaks from
            ``calib_idx`` to the end. Use 1 for everything after the first
            calibration block (assuming 2 is the first calibration block),
            11 for everything after the second (assuming 12 is the second).

    Returns:
        ``(events, calib_df)``: the simulated event dicts and the slice of
        ``sub_df`` they were derived from.

    Raises:
        ValueError: If both ``between_calib`` and ``calib_block_to_end`` are falsy.
    """
    session_trials = (trial_df.ppid==sub) & (trial_df.session==ses)
    if between_calib:
        calib_df = sub_df.iloc[calib_idx:between_calib+1].reset_index(drop=True)
        trials_between_calib = len(trial_df.loc[session_trials & (trial_df.block>1) & (trial_df.block<=11)])
    elif calib_block_to_end:
        calib_df = sub_df.iloc[calib_idx:].reset_index(drop=True)
        trials_between_calib = len(trial_df.loc[session_trials & (trial_df.block>calib_block_to_end)])
    else:
        raise ValueError("either between_calib or calib_block_to_end must be set")
    sub_df_intervals = _simulate_trials(calib_df, trials_between_calib, 'voice')
    return sub_df_intervals, calib_df


# NOTE(review): random is never seeded in this cell, so every run produces
# different simulated trial boundaries; seed it (the notebook defines rs = 64
# as a random seed) if reproducible output is wanted -- confirm.
all_dfs = []

for sbj_ssn in df.sbj_ssn.unique():
# for sbj_ssn in ['sbj21ssn03']:
    sub_df = df.loc[df.sbj_ssn==sbj_ssn].reset_index(drop=True)
    # only simulate if we confirmed that the pre-calibration intervals are good
    if not sub_df.regular_calibration_intervals.iloc[0]:
        continue

    # The session opens with a brake at the first detected break.
    start_intervals = [{'sbj_ssn': sbj_ssn, 'start_interval': sub_df.iloc[0].start_intervals, 'type_interval': 'brake', 'type_trial':'practice'}]

    calib_trial_count = 5
    calibration_interval_start = sub_df.loc[sub_df.calibration_trials==1,'start_intervals'] # confirmed first calibration interval
    if len(calibration_interval_start) != 1:
        # Without a unique first calibration row there is no calib_idx; running
        # on would reuse the previous subject's index (or fail on the first one).
        print(f"{sbj_ssn}: expected exactly one row with calibration_trials == 1, "
              f"found {len(calibration_interval_start)}; skipping")
        continue
    calib_idx = calibration_interval_start.index.values[0]
    calib_df = sub_df.iloc[:calib_idx+1]
    # Practice phase: the stretch before the calibration break takes whatever
    # is still outstanding of the 5 practice trials.
    start_intervals.extend(_simulate_trials(calib_df, calib_trial_count, 'practice', force_last_interval=True))

    # "sbjXXssnYY" -> integer participant and session ids.
    id_parts = sbj_ssn.split('sbj')[1].split('ssn')
    sub = int(id_parts[0])
    ses = int(id_parts[1])

    calibration_interval_start = sub_df.loc[sub_df.calibration_trials==2,'start_intervals'] # confirmed second calibration interval
    if len(calibration_interval_start) == 1:
        calib_idx_2 = calibration_interval_start.index.values[0]
        sub_df_intervals, calib_df = process_sub_df(calib_idx, trial_df, sub, ses, between_calib=calib_idx_2)
        start_intervals.extend(sub_df_intervals)
        # now take care of the rest
        calib_idx = calib_idx_2
        sub_df_intervals, calib_df = process_sub_df(calib_idx, trial_df, sub, ses, between_calib=False, calib_block_to_end=11)
        start_intervals.extend(sub_df_intervals)
    else:
        sub_df_intervals, calib_df = process_sub_df(calib_idx, trial_df, sub, ses, between_calib=False, calib_block_to_end=1)
        start_intervals.extend(sub_df_intervals)

    out_df = pd.DataFrame(start_intervals)
    # Every event lasts until the next one starts; the final event ends with
    # the last break of the slice processed last.
    out_df['end_interval'] = out_df.start_interval.shift(-1)
    out_df.loc[out_df.index[-1],'end_interval'] = calib_df.end_intervals.iloc[-1]
    out_df = out_df[['sbj_ssn', 'start_interval', 'end_interval', 'type_interval','type_trial']]
    all_dfs.append(out_df)
all_dfs = pd.concat(all_dfs)
all_dfs.to_csv('motor_events_simulation.csv', index=False)