# Analyze sessions in batch from Phase 1 of AdaDrive (work in progress)

In [None]:
path_to_base_package = '../..'
import sys
# setting path
sys.path.append(f"{path_to_base_package}")
import mne
mne.viz.set_3d_backend('pyvistaqt')
mne.viz.set_3d_options(antialias=False) 

import json
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
from mna.utils.data_access import *
from mna.utils.analysis import *
from mne.datasets import fetch_fsaverage
import pickle
from collections import defaultdict

# Download fsaverage files
fs_dir = fetch_fsaverage(verbose=True)

# Aux functions, read files

In [None]:
# Output locations. Both strings end with "/" because later code appends file names to them directly.
output_dir = f"{path_to_base_package}/output/batch_analysis_motor/"
output_dir_non_baseline_non_average = f"{path_to_base_package}/output/batch_analysis_motor/saved_files_non_baseline_non_average/" # saved files directory (trial or motor)
# (ppid, session) pairs that are filtered out of the trial-level table during clean-up.
remove_sessions = [(15,1),(22,1)]
# Project helpers (star-imported from mna.utils); rel_labels / rel_mappings are used later for
# source-space label extraction. NOTE(review): rel_regions / all_region are not referenced in the
# visible part of this file -- confirm whether they are still needed.
rel_regions, all_region = get_relevant_channels()
rel_labels, rel_mappings = get_relevant_labels_mappings(path_to_base_package)

# Per-participant pupil/opacity fit table; the columns 'sub', 'w_opacities', 'const' and 'p_opacities' are read below.
pupil_df = pd.read_csv(f"{path_to_base_package}/output/pupil_exposure/participant_level_exposure_fits.csv")
# Trial-level results and motor-event-level results.
trial_dfs = pd.read_csv(f"{output_dir}all_results.csv")
motor_dfs = read_motor_csvs(output_dir)
# Convert the 'post_steer_event_raw' column with the project helper str_list_to_list
# (presumably parsing a stringified list back into a list -- verify against the helper).
motor_dfs['post_steer_event_raw'] = motor_dfs['post_steer_event_raw'].apply(str_list_to_list)
# Epoch objects for the motor events; their .metadata frame is edited in place further down.
motor_epochs = get_motor_epochs(output_dir_non_baseline_non_average)
#low_motor_sensor = motor_epochs["Steer_Wheel_Degree_Categorical == 'Low'"]
#high_motor_sensor = motor_epochs["Steer_Wheel_Degree_Categorical == 'High'"]
#low_pupil = motor_epochs["pupil_bin == '{}'".format('low')]
#high_pupil =motor_epochs["pupil_bin == '{}'".format('high')]
# Epochs used later only to estimate the noise covariance for the inverse operator.
exposure_epochs = get_exposure_epochs(f"{path_to_base_package}/output/exposure/exposure_epochs.pickle")
#low_motor_sensor.apply_proj()
#high_motor_sensor.apply_proj()

# Significance threshold applied to each participant's 'p_opacities' before correcting pupil size.
p_val_criteria = 0.05
# Divided by 1000 before being passed to apply_baseline, so presumably milliseconds before the turn -- TODO confirm.
preturn = 1000

# Clean up dfs

In [None]:
# seaborn plotting defaults and the clean-up helper used by clean_up_trials below.
# NOTE(review): `math` is not referenced in the visible part of this file.
import math 
sns.set(font_scale=1.2)
sns.set_palette("tab10")
from mna.utils.batch_feature_extraction import clean_up_adadrive_trials

# Column-name groups. None of these lists is read in the visible part of the file;
# presumably they are consumed by later cells -- verify before removing.
motor_outlier_cols = ['abs_sum_delta_steer_input']
# NOTE(review): identical contents to ecg_cols below.
cols_to_outlier_detect = ['bpm', 'sdnn', 'rmssd', 'pnn50']
experimental_cols = ['spoken_difficulty', 'trial_duration', 'density', 'trial_damage']
eye_cols = ['Left Pupil Diameter', "NSLR_count_Fixation", "NSLR_count_Saccade",
            'NSLR_mean_duration_Fixation', 'NSLR_mean_duration_Saccade',
            'NSLR_first_onset_Fixation', 'NSLR_first_onset_Saccade']
ecg_cols = ['bpm', 'sdnn', 'rmssd', 'pnn50']  # rmssd = parasympathetic
motor_cols = ['abs_sum_delta_steer_input', 'abs_sum_delta_brake_input', 'abs_sum_delta_throttle_input']
def remove_motor_overlaps(test_df):
    '''
    Drop one row from every pair of consecutive rows whose time windows overlap.

    The frame is re-indexed to 0..n-1 first. Two neighbouring rows count as
    overlapping when the later row's ``trial_end_time`` lies in the half-open
    window ``(previous start, previous end]``. Of the two, the row with the
    smaller ``Abs_Steer_Wheel_Degree`` is dropped; on a tie the later row is dropped.

    Returns the re-indexed frame without the dropped rows.
    Raises AssertionError when an overlapping pair has (end - previous start) == 1.

    NOTE(review): only the later row's end time is tested, so a later row that
    starts inside the previous window but ends after it is not flagged --
    confirm this is intended.
    '''
    test_df = test_df.reset_index(drop=True)
    starts = list(test_df.trial_start_time)
    ends = list(test_df.trial_end_time)

    # Pass 1: collect (previous position, current position) for every overlapping neighbour pair.
    overlapping_pairs = []
    for cur in range(1, len(ends)):
        prev = cur - 1
        cur_end = ends[cur]
        if not (starts[prev] < cur_end <= ends[prev]):
            continue
        overlapping_pairs.append((prev, cur))
        assert cur_end - starts[prev] != 1, 'Major issue, repeating trials found. Double check'

    # Pass 2: within each pair keep the row with the larger steering magnitude.
    rows_to_drop = []
    for prev, cur in overlapping_pairs:
        prev_steer = test_df.iloc[prev].Abs_Steer_Wheel_Degree
        cur_steer = test_df.iloc[cur].Abs_Steer_Wheel_Degree
        # argmax returns the first maximum, so ties resolve in favour of the earlier row.
        earlier_row_wins = np.argmax([prev_steer, cur_steer]) == 0
        rows_to_drop.append(cur if earlier_row_wins else prev)

    return test_df.drop(rows_to_drop, axis=0)

def clean_up_trials(input_df):
    '''
    Run the shared AdaDrive clean-up on a copy of ``input_df`` and add derived columns.

    After ``clean_up_adadrive_trials`` the rows are sorted by participant, session,
    block and trial; the two first-onset eye-event columns are made relative to
    ``trial_start_time``; and ``throttle_over_brake`` is added as the ratio of
    summed throttle change to summed brake change.
    '''
    cleaned = clean_up_adadrive_trials(input_df.copy())
    cleaned = cleaned.sort_values(by=['ppid', 'session', 'block', 'trial'])

    # Express first fixation / saccade onsets relative to the start of their trial.
    for onset_col in ('NSLR_first_onset_Fixation', 'NSLR_first_onset_Saccade'):
        cleaned[onset_col] = cleaned[onset_col] - cleaned['trial_start_time']

    throttle_change = cleaned['abs_sum_delta_throttle_input']
    brake_change = cleaned['abs_sum_delta_brake_input']
    cleaned['throttle_over_brake'] = throttle_change / brake_change
    return cleaned


# Clean the trial-level table, then drop the excluded sessions by matching the
# 'ppid_session' column against "<ppid>_<session>" strings built from remove_sessions.
trial_dfs = clean_up_trials(trial_dfs)
trial_dfs = trial_dfs.loc[~trial_dfs.ppid_session.isin([f"{es[0]}_{es[1]}" for es in remove_sessions])]
# NOTE(review): motor_dfs is cleaned but NOT filtered by remove_sessions here -- confirm that is intended.
motor_dfs = clean_up_trials(motor_dfs)
print(f"removing ovlerlapping motor trials, starting epoch count {len(motor_dfs)}")
# Also resets the motor_dfs index to 0..n-1 before dropping rows.
motor_dfs = remove_motor_overlaps(motor_dfs)
print(f"post removal epoch count {len(motor_dfs)}")
# luminance effect removal from pupil diameter
# Keep the uncorrected values in a separate column before adjusting 'Left Pupil Diameter'.
trial_dfs['Raw Left Pupil Diameter'] = trial_dfs['Left Pupil Diameter']
motor_dfs['Raw Left Pupil Diameter'] = motor_dfs['Left Pupil Diameter']
# A 0..n-1 index is required: the loop below uses the iterrows label as a position for .iloc.
trial_dfs = trial_dfs.reset_index(drop=True)
# NOTE(review): `adjustments` is never appended to or read in the visible code.
adjustments=[]
# Walk consecutive trial rows; for each trial that directly follows the previous one (same participant,
# same session, trial number + 1), shift the pupil diameter by (change in density) * (participant weight).
for index, row in trial_dfs.iloc[1:].iterrows():
    last_ppid = trial_dfs.iloc[index - 1].ppid
    last_session = trial_dfs.iloc[index - 1].session
    last_trial = trial_dfs.iloc[index - 1].trial
    last_opacity = trial_dfs.iloc[index - 1].density
    if ((row.ppid == last_ppid) & (row.session == last_session) & (row.trial == last_trial + 1)):  # if continuous
        # if there is a significant effect of opacity on pupil
        # NOTE(review): `.values < p_val_criteria` is an array; the `if` relies on exactly one
        # matching 'sub' row in pupil_df (zero or several matches make the truth value ambiguous).
        if pupil_df.loc[pupil_df['sub'] == last_ppid, 'p_opacities'].values < p_val_criteria:
            this_opacity = row.density
            # NOTE(review): this_pupil_diameter is assigned but not used afterwards.
            this_pupil_diameter = row['Left Pupil Diameter']
            weight = pupil_df.loc[pupil_df['sub'] == last_ppid, 'w_opacities']
            adjustment = ((this_opacity - last_opacity) * weight).values[0]
            # NOTE(review): the trial table ADDS the adjustment here, while the motor table and the
            # epoch metadata below SUBTRACT it -- the sign is inconsistent; confirm which is intended.
            trial_dfs.iloc[index, trial_dfs.columns.get_loc('Left Pupil Diameter')] += adjustment
            # this needs to be converted to array b/c of pandas issues
            old_pupil_value = np.array(motor_dfs.loc[(motor_dfs.ppid == last_ppid) & (motor_dfs.session == last_session) & (
                        motor_dfs.trial == last_trial + 1), 'Left Pupil Diameter']) 
            motor_dfs.loc[(motor_dfs.ppid == last_ppid) & (motor_dfs.session == last_session) & (
                        motor_dfs.trial == last_trial + 1), 'Left Pupil Diameter'] = (old_pupil_value-adjustment).T  # update motor df too
            # do also for motor_epochs
            # Same selection (participant, session, current trial) applied to the epochs' metadata frame.
            old_pupil_value = motor_epochs.metadata.loc[(motor_epochs.metadata.ppid == last_ppid) &
                                      (motor_epochs.metadata.session == last_session) &
                                      (motor_epochs.metadata.trial == last_trial + 1), 'Left Pupil Diameter']
            motor_epochs.metadata.loc[(motor_epochs.metadata.ppid == last_ppid) &
                                      (motor_epochs.metadata.session == last_session) &
                                      (motor_epochs.metadata.trial == last_trial + 1), 'Left Pupil Diameter'] = (old_pupil_value-adjustment).T
# Pupil bins: split 'Left Pupil Diameter' into two equal-count quantile bins within each
# participant, once with string labels ('pupil_bin') and once with integer labels
# ('pupil_bin_encoded'). The same split is written to the motor table, the trial table
# and the metadata frame of the motor epochs.
for _bin_column, _bin_labels in (('pupil_bin', ['low', 'high']), ('pupil_bin_encoded', [0, 1])):
    for _table in (motor_dfs, trial_dfs, motor_epochs.metadata):
        _table[_bin_column] = _table.groupby(['ppid'])['Left Pupil Diameter'].transform(
            lambda x: pd.qcut(x, 2, labels=_bin_labels))

# participant-level binning of motor data, replaces the session-level info already there
# Applied to both the motor table and the epochs' metadata so the two stay consistent.
# Presumably this is what (re)writes 'Steer_Wheel_Degree_Categorical', which the source-space
# code later compares against 'Low' / 'High' -- verify against get_motor_intensity_info.
motor_dfs = get_motor_intensity_info(motor_dfs)
motor_epochs.metadata = get_motor_intensity_info(motor_epochs.metadata)


# Ensure that epochs removed from the motor dataframe are also removed from the EEG epochs (we don't do the reverse, since rows without EEG still carry usable non-EEG data)
def get_merged_motor_epochs(input_mepochs, input_motor_dfs):
    '''
    Keep only the epochs whose (ppid, session, block, trial) key still exists in
    ``input_motor_dfs``, then baseline-correct them.

    The baseline window is the 250 ms that end ``preturn`` ms before time zero
    (``preturn`` is read from module scope). ``apply_baseline`` is called on the
    selected epochs, which are returned.
    '''
    key_cols = ['ppid', 'session', 'block', 'trial']
    epoch_keys = input_mepochs.metadata[key_cols]
    surviving_keys = input_motor_dfs[key_cols].drop_duplicates()

    # Keys present on both sides; epochs whose key is missing from the motor table are discarded.
    shared_keys = epoch_keys.merge(surviving_keys, on=key_cols, how='inner', indicator=True)
    keep_mask = epoch_keys.set_index(key_cols).index.isin(shared_keys.set_index(key_cols).index)
    input_mepochs = input_mepochs[keep_mask]

    # No average re-reference here: a custom EEG reference is not allowed for MNE source modeling.
    baseline_start = -((preturn + 250) / 1000)
    baseline_end = -((preturn) / 1000)
    input_mepochs.apply_baseline((baseline_start, baseline_end))
    return input_mepochs
# Restrict the EEG epochs to rows that survived motor-table clean-up and apply the pre-turn baseline.
motor_epochs = get_merged_motor_epochs(motor_epochs, motor_dfs)
# Print the per-participant pupil/opacity fit table as LaTeX, floats formatted to two decimals.
print(pupil_df[['sub','w_opacities','const','p_opacities']].to_latex(index=False,float_format="{:0.2f}".format))

# eLORETA

## Load forward model, create inverse operator

In [None]:

# Parent directory of the downloaded fsaverage folder.
subjects_dir = os.path.dirname(fs_dir)

# The files live in:
subject = 'fsaverage'
trans = 'fsaverage'  # MNE has a built-in fsaverage transformation
# The bare triple-quoted string below is an expression statement used as a block comment (it has no
# runtime effect). NOTE(review): its wording ("downsampled by a factor of 5") does not obviously match
# the ico-5 source space / 5120-triangle BEM file names used below -- confirm against the MNE docs it links.
'''
select the boundary element model, note that the source data has been downsampled by a factor of 5 
(i.e. ico == 5, https://mne.tools/stable/generated/mne.setup_source_space.html#mne.setup_source_space)
and the BEM has been downsampled by a factor of 5 (i.e. ico == 4, see here: https://mne.tools/stable/generated/mne.make_bem_model.html)
implications here: https://brainder.org/2016/05/31/downsampling-decimating-a-brain-surface/
'''
# Source space and BEM solution shipped with the fsaverage download.
src_fname = os.path.join(fs_dir, 'bem', 'fsaverage-ico-5-src.fif')
bem = os.path.join(fs_dir, 'bem', 'fsaverage-5120-5120-5120-bem-sol.fif')

eeg_montage='biosemi64'
# Measurement info (channels, positions) is taken from the motor epochs as loaded.
info = motor_epochs.info

# Build the standard montage named by eeg_montage ('biosemi64').
# NOTE(review): the montage object is created but never applied to `info` or the epochs in the
# visible code, so the forward model uses whatever channel positions `info` already carries.
# (The original note here referred to standard_1020, which does not match eeg_montage.)
montage = mne.channels.make_standard_montage(eeg_montage)

# Check that the locations of EEG electrodes is correct with respect to MRI
# mne.viz.plot_alignment(
#    info, src=src_fname, eeg=['original', 'projected'], trans=trans,
#    show_axes=False, mri_fiducials=True, dig='fiducials')
# EEG-only forward solution on the fsaverage template.
fwd = mne.make_forward_solution(info, trans=trans, src=src_fname,
                                bem=bem, eeg=True, n_jobs=None)

In [None]:
# Inverse-solution settings. NOTE(review): get_time_courses below hard-codes method="eLORETA" and
# recomputes 1.0 / snr ** 2 itself, so `method` and `lambda2` are not read by it; only `snr` is.
method = "eLORETA"
snr = 3.
lambda2 = 1. / snr ** 2
# Noise covariance estimated from the exposure epochs (estimator chosen automatically by MNE).
cov = mne.compute_covariance(exposure_epochs, method='auto') # note this is not average referenced
cov.plot(exposure_epochs.info)
# Inverse operator built from the motor-epoch info, the fsaverage forward model and the covariance above.
inverse_operator = mne.minimum_norm.make_inverse_operator(
    info, fwd, cov)


# Get time courses

In [None]:
def get_time_courses(input_epochs, input_labels):
    '''
    Source-localise every epoch with eLORETA and reduce it to one time course per label.

    Uses the module-level ``inverse_operator`` and ``snr`` (regularisation is
    ``1 / snr ** 2``) and keeps only the surface-normal orientation. Each label's
    sources are averaged with sign flipping (``mode='mean_flip'``); labels with no
    sources are allowed. Because ``return_generator=True`` is passed, the result is
    a generator yielding one array per epoch with one row per entry of ``input_labels``.
    '''
    source_space = inverse_operator['src']
    per_epoch_estimates = mne.minimum_norm.apply_inverse_epochs(
        input_epochs,
        inverse_operator,
        lambda2=1.0 / snr ** 2,
        method="eLORETA",
        pick_ori="normal",
        verbose=False,
    )
    return mne.extract_label_time_course(
        per_epoch_estimates,
        input_labels,
        source_space,
        mode='mean_flip',
        allow_empty=True,
        return_generator=True,
        verbose=False,
    )

def _collect_condition_tcs(this_pid, steer_category, cond_name):
    '''
    Build the per-sample label time courses for one participant and one steering category.

    Reads the module-level ``motor_epochs`` and ``rel_labels``. Returns a DataFrame with
    one row per (epoch, sample), one column per label name, plus bookkeeping columns
    (``motor_event_trial``, ``trial``, ``trial_start_time``, ``session``, ``sample``,
    ``pid``, ``cond``), or ``None`` when the participant has no epochs in that category.
    '''
    meta = motor_epochs.metadata
    keep = (meta['ppid'] == this_pid) & (meta['Steer_Wheel_Degree_Categorical'] == steer_category)
    # Epochs are indexed by position, not by the metadata index labels, so turn the row mask
    # into positions. Selecting on participant AND category directly (instead of matching
    # trial_start_time values) avoids picking up other rows that share a start time.
    positions = [pos for pos, selected in enumerate(keep) if selected]
    if not positions:
        return None

    cond_epochs = motor_epochs[positions]
    cond_meta = cond_epochs.metadata
    trials = list(cond_meta.trial)
    trial_starts = list(cond_meta.trial_start_time)
    sessions = list(cond_meta.session)
    label_names = [label.name for label in rel_labels]

    per_epoch_frames = []
    for indx, label_tc in enumerate(get_time_courses(cond_epochs, rel_labels)):
        # label_tc has one row per label; transpose so rows are samples and columns are labels.
        trial_df = pd.DataFrame(label_tc).T
        trial_df.columns = label_names
        trial_df['motor_event_trial'] = indx  # running index within this participant/condition
        trial_df['trial'] = trials[indx]
        trial_df['trial_start_time'] = trial_starts[indx]
        trial_df['session'] = sessions[indx]
        trial_df['sample'] = trial_df.index
        trial_df['pid'] = this_pid
        per_epoch_frames.append(trial_df)
    if not per_epoch_frames:
        return None

    cond_df = pd.concat(per_epoch_frames)
    cond_df['cond'] = cond_name
    return cond_df


def get_all_tcs(output_dir, overwrite=False):
    '''
    Return the long-format source time courses for all participants, using a pickle cache.

    Args:
        output_dir: directory prefix (including the trailing separator) that holds
            ``source_time_courses.pickle``.
        overwrite: when False and the cache file exists, it is loaded and returned
            as-is; otherwise the time courses are recomputed from the module-level
            ``motor_epochs`` and the cache file is (re)written.

    Returns:
        DataFrame with id columns ``motor_event_trial``, ``trial``, ``trial_start_time``,
        ``session``, ``sample``, ``pid``, ``cond`` plus ``source_region`` (label name) and
        ``activation``.

    Raises:
        ValueError: if no participant has any 'Low' or 'High' steering epochs.
    '''
    cache_path = f"{output_dir}source_time_courses.pickle"
    if not overwrite and os.path.isfile(cache_path):
        # NOTE: unpickling can execute arbitrary code; only point this at cache files you produced.
        with open(cache_path, 'rb') as cache_file:
            return pickle.load(cache_file)

    all_cond_tcs = []
    for this_pid in motor_epochs.metadata['ppid'].unique():
        print('this_pid', this_pid)
        # Low first, then high, so the row order matches one participant block after another.
        for steer_category, cond_name in (('Low', 'low'), ('High', 'high')):
            cond_df = _collect_condition_tcs(this_pid, steer_category, cond_name)
            if cond_df is not None:
                all_cond_tcs.append(cond_df)
    if not all_cond_tcs:
        raise ValueError("No 'Low' or 'High' steering epochs found; nothing to source-localise")

    all_cond_tcs_df = pd.concat(all_cond_tcs)
    all_cond_tcs_df = pd.melt(
        all_cond_tcs_df,
        id_vars=['motor_event_trial', 'trial', 'trial_start_time', 'session', 'sample', 'pid', 'cond'],
        value_name='activation',
        var_name='source_region',
    )
    with open(cache_path, 'wb') as cache_file:
        pickle.dump(all_cond_tcs_df, cache_file, protocol=pickle.HIGHEST_PROTOCOL)
    return all_cond_tcs_df


# Load (or compute and cache) the long-format source time courses.
all_tcs = get_all_tcs(output_dir_non_baseline_non_average,overwrite=False)
# At this point 'source_region' still holds the raw label name; the hemisphere is the part after the first '-'.
all_tcs['hemi'] = all_tcs.source_region.apply(lambda x: x.split('-')[1])
# Replace the raw label name by its entry in rel_mappings (must run after 'hemi' is derived).
all_tcs['source_region'] = all_tcs.source_region.apply(lambda x: rel_mappings[x])

In [None]:
def get_epoched_tcs(all_tcs_tcs,relvant_labels,relvant_mappings,
                    sfreq=128, baseline_duration=0.25, pre_event_duration=1.25):
    '''
    Baseline-correct every epoch of the long-format time courses and also return
    them as an array. The input may be a filtered subset of the full table.

    Args:
        all_tcs_tcs: long-format frame with columns ``pid``, ``motor_event_trial``,
            ``cond``, ``trial``, ``source_region``, ``hemi``, ``sample``, ``activation``.
        relvant_labels: objects with a ``name`` attribute of the form ``<region>-<hemi>``;
            their order fixes the label axis of the returned array.
        relvant_mappings: maps each label name to the ``source_region`` value used
            in ``all_tcs_tcs``.
        sfreq: sampling rate in Hz used to convert samples to seconds (default 128).
        baseline_duration: length in seconds of the leading baseline window whose
            mean is subtracted (default 0.25).
        pre_event_duration: seconds between the first sample and time zero (default 1.25).

    Returns:
        (ordered_input_tcs, long_df): the raw activations stacked along
        (epoch, label, sample), and the long frame (rows ordered by group key) with
        ``baseline_corr_activation`` and ``time`` columns added.

    Raises:
        ValueError: from ``pd.concat`` when ``all_tcs_tcs`` has no rows.
    '''
    n_baseline_samples = int(baseline_duration * sfreq)
    epoch_keys = ['pid', 'motor_event_trial', 'cond', 'trial', 'source_region', 'hemi']

    corrected_groups = []
    tcs_by_source = defaultdict(list)
    for _, group in all_tcs_tcs.groupby(epoch_keys):  # one group per epoch and source
        activation = group.activation.values
        baseline = np.mean(activation[:n_baseline_samples])
        # assign() returns a new frame, so the groupby slice is never written to in place.
        corrected_groups.append(group.assign(baseline_corr_activation=group.activation - baseline))
        tcs_by_source[(group.source_region.iloc[0], group.hemi.iloc[0])].append(activation)

    long_df = pd.concat(corrected_groups)
    long_df['time'] = (long_df['sample'] - (pre_event_duration * sfreq)) / sfreq  # sample to time

    # One list of epochs per label, in label order, then swap to (epoch, label, sample).
    ordered_input_tcs = [
        tcs_by_source[(relvant_mappings[label.name], label.name.split('-')[1])]
        for label in relvant_labels
    ]
    ordered_input_tcs = np.swapaxes(np.array(ordered_input_tcs), 0, 1)
    return ordered_input_tcs, long_df
