# 1. Libraries

In [None]:
import sys
 
# setting path
sys.path.append('..')

import pandas as pd
import mne
import numpy as np
import autoreject
import matplotlib.pyplot as plt
import pandas as pd
import os
import scipy
import sklearn

from mna.sessions.eye_session import process_session_eye
from mna.sessions.eeg_session import process_session_eeg
from mna.sessions.motor_session import process_session_motor
from mna.sessions.ecg_session import process_session_ecg
from mna.utils.batch_feature_extraction import clean_up_adadrive_trials

from mne.parallel import parallel_func
from mne_features.univariate import compute_hjorth_mobility,compute_pow_freq_bands
from mne.preprocessing import corrmap

from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier

from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn import metrics
import seaborn as sns

from sklearn.metrics import mean_squared_error, r2_score
from sklearn.ensemble import RandomForestRegressor

from os import listdir
from os.path import isfile, join
from mna.utils.rnapp_data_format import read_all_lslpresets, return_metadata_from_name, event_data_from_data, read_event_data
import pickle

# 2. Read RN App recordings converted to .pkl files

In [None]:
data_dir = "../data/"
lsl_dir = "../mna/LSLPresets/"
output_dir = '../output/'
pickle_dir = f'{output_dir}saved_files/pickle_files/'
cvs_xlsx_dir = f'{output_dir}saved_files/cvs_xlsx_files/'

timestamp_fixer_path = f"{data_dir}annotated/fit_timestamp_adjuster.pkl"

# Create every directory later cells write into (pickles, CSV/XLSX), not only
# the top-level output folder; exist_ok makes re-running the cell harmless.
for _out_dir in (output_dir, pickle_dir, cvs_xlsx_dir):
    os.makedirs(_out_dir, exist_ok=True)

metadata_jsons = read_all_lslpresets(path_to_jsonfiles=lsl_dir)
# Sorted so the file order (and therefore which session's ICA becomes the
# template in the processing cell) does not depend on the filesystem.
onlyfiles = sorted(f for f in listdir(data_dir) if isfile(join(data_dir, f)) and '.pkl' in f)

# features == 'processed_trial_duration', 'processed_trial_duration_1', 'lsl_timestamps'
# Loaded inside a context manager so the file handle is closed afterwards.
with open(timestamp_fixer_path, 'rb') as _ts_handle:
    ts_fixer = pickle.load(_ts_handle)

# all_dfs = pd.read_excel("output/all_results_cleaned.xlsx")
# all_dfs = clean_up_adadrive_trials(all_dfs)

save_data_pkl = True # save data into pickle files
save_ica_plts = True # save ICA components plots
epoch_raw_eeg = True # epoching raw data
motor_events = False
seed = 64 # random state
# Several functions and cells below refer to the random state as `rs`;
# alias it here so they do not raise NameError.
rs = seed

interrupted_sessions = [(13,1), (22,1)]
# remove_sessions = [(13,1),(15,1),(22,1)]
remove_sessions = [(13,1),(15,1),(22,1),(22,102)]

In [None]:
def eeg_features(df, data_type = 'processed', features = 'all', label_source = 'Steering_Wheel_Degree_Encoded',
                 cleaned_up = False):
    """Select the EEG feature columns of a per-trial results DataFrame.

    The EEG features are the contiguous run of columns from the first
    electrode feature column through the last one (inclusive).

    Parameters
    ----------
    df : pandas.DataFrame
        Per-trial results holding the EEG feature columns and, when
        ``cleaned_up`` is True, the autoreject flag column and ``label_source``.
    data_type : {'processed', 'raw'}
        ``'raw'`` selects the ``*_raw`` feature columns and the
        ``autorejected_raw`` flag instead of the processed ones.
    features : 'all' or iterable
        ``'all'`` keeps every EEG column. Otherwise the items are joined with
        ``|`` into a regex alternation, and only columns whose name matches
        are kept.
    label_source : str
        Label column used when ``cleaned_up`` is True.
    cleaned_up : bool
        If True, return ``(X, y)`` numpy arrays restricted to trials that
        have a non-null label and whose autoreject flag equals False.
        Otherwise return the selected feature columns as a DataFrame.

    Raises
    ------
    ValueError
        If ``data_type`` is neither ``'processed'`` nor ``'raw'``.
    """
    column_names = {
        # data_type: (first EEG column, last EEG column, autoreject flag column)
        'processed': ("Fp1_4-8_Hz_Power", "O2_32-55_Hz_Sample_entropy", "autorejected"),
        'raw': ("Fp1_4-8_Hz_Power_raw", "O2_32-55_Hz_Sample_entropy_raw", "autorejected_raw"),
    }
    if data_type not in column_names:
        raise ValueError(f"data_type must be 'processed' or 'raw', got {data_type!r}")
    first_col, last_col, autoreject_col = column_names[data_type]

    first_idx = df.columns.get_loc(first_col)
    last_idx = df.columns.get_loc(last_col)
    all_eeg = df.iloc[:, first_idx:last_idx + 1]

    if features == 'all':
        selected = all_eeg
    else:
        pattern = "|".join(map(str, features))
        selected = all_eeg.loc[:, all_eeg.columns.str.contains(pattern)]

    if not cleaned_up:
        return selected

    # Keep labelled trials that autoreject did not flag. `.eq(False)` behaves
    # like `== False`: a NaN flag counts as "not valid".
    valid_trial = df[label_source].notnull() & df[autoreject_col].eq(False)
    return np.asarray(selected[valid_trial]), np.asarray(df[label_source][valid_trial])

In [None]:
def eye_features(df, features = "pupil", label_source = 'Steering_Wheel_Degree_Encoded', cleaned_up = False):
    """Select eye-tracking feature columns from a per-trial results DataFrame.

    ``features == "pupil"`` selects the left/right pupil diameter columns;
    any other value is used directly as the column selector for ``df``.

    With ``cleaned_up`` True, the label column is joined on, rows with any
    NaN are dropped, and ``(X, y)`` numpy arrays are returned. Otherwise the
    selected columns are returned unchanged.
    """
    if features == 'pupil':
        column_selector = ['Left Pupil Diameter', 'Right Pupil Diameter']
    else:
        column_selector = features
    selected = df[column_selector]

    if not cleaned_up:
        return selected

    # The label is appended as the last column, so it can be split off after dropna.
    labelled = selected.join(df[label_source]).dropna()
    x_values = labelled.iloc[:, :-1]
    y_values = labelled.iloc[:, -1]
    return np.asarray(x_values), np.asarray(y_values)

In [None]:
def ecg_features(df, features = "all", label_source = 'Steering_Wheel_Degree_Encoded', cleaned_up = False):
    """Select ECG feature columns from a per-trial results DataFrame.

    With ``features == "all"``, takes columns by position from ``bpm`` up to,
    but excluding, the column two places before ``breathingrate``. So
    ``breathingrate`` and the two columns preceding it are not included.
    Otherwise ``features`` is used directly as the column selector.
    Both anchor columns must exist in either case.

    With ``cleaned_up`` True, returns ``(X, y)`` numpy arrays after joining the
    label column and dropping rows with any NaN; otherwise returns the columns.
    """
    start_pos = df.columns.get_loc("bpm")
    # NOTE(review): the -2 offset's intent is not visible here; kept as-is.
    stop_pos = df.columns.get_loc("breathingrate") - 2

    selected = df.iloc[:, start_pos:stop_pos] if features == 'all' else df[features]

    if not cleaned_up:
        return selected

    labelled = selected.join(df[label_source]).dropna()
    return np.asarray(labelled.iloc[:, :-1]), np.asarray(labelled.iloc[:, -1])

In [None]:
def multimodal_features(df, label_source = 'Steering_Wheel_Degree_Encoded'):
    """Combine the default EEG, eye and ECG feature sets into ``(X, y)`` arrays.

    Concatenates the three feature DataFrames and the label column side by
    side, drops every row that has a NaN in any of them, then splits off the
    label (the last column).
    """
    combined = pd.concat(
        [eeg_features(df), eye_features(df), ecg_features(df), df[label_source]],
        axis=1,
    ).dropna()
    feature_block = combined.iloc[:, :-1]
    label_block = combined.iloc[:, -1]
    return np.asarray(feature_block), np.asarray(label_block)

In [None]:
def norm_features(x_train, x_test):
    """Min-max scale train and test features using only training-set statistics.

    Returns ``(x_train_scaled, x_test_scaled)``; the scaler is fit on
    ``x_train`` so no information from ``x_test`` leaks into the scaling.
    """
    fitted_scaler = MinMaxScaler().fit(x_train)
    return fitted_scaler.transform(x_train), fitted_scaler.transform(x_test)

In [None]:
def feature_normalization(x_data, y_label, train_percentage=0.8, random_state=None):
    """Drop NaN rows, split into train/test, and min-max scale the features.

    Parameters
    ----------
    x_data : 2-D array-like
        Feature matrix; any row containing a NaN (e.g. an invalid pupil
        diameter sample) is removed along with its label.
    y_label : array-like
        Labels aligned with the rows of ``x_data``.
    train_percentage : float
        Fraction of samples used for training (``train_test_split`` ``train_size``).
    random_state : int, optional
        Seed for the split; defaults to the notebook-level ``seed``.

    Returns
    -------
    x_train_norm, x_test_norm, y_train, y_test
        Scaled features (scaler fit on the training split only) and labels.
    """
    if random_state is None:
        random_state = seed  # notebook-level default random state

    # Boolean row mask instead of sum(sum(...)) + argwhere/delete.
    nan_rows = np.isnan(np.asarray(x_data)).any(axis=1)
    if nan_rows.any():
        keep = ~nan_rows
        x_data = np.asarray(x_data)[keep]
        y_label = np.asarray(y_label)[keep]

    x_train, x_test, y_train, y_test = train_test_split(x_data, y_label,
                                                        train_size=train_percentage,
                                                        random_state=random_state)

    scaler = MinMaxScaler().fit(x_train)
    return scaler.transform(x_train), scaler.transform(x_test), y_train, y_test


In [None]:
def modality_cv(x_modality, y_modality, n_folds = 10, classifier = 'logistic', random_state=None):
    """Stratified k-fold cross-validation of a classifier on one modality.

    Each fold is min-max scaled with ``norm_features`` (fit on the training
    fold) and scored with ``trial_classification``.

    Parameters
    ----------
    x_modality, y_modality : numpy arrays
        Features and labels; must support integer-array indexing.
    n_folds : int
        Number of stratified folds.
    classifier : str
        Forwarded unchanged to ``trial_classification``.
    random_state : int, optional
        Shuffle seed for ``StratifiedKFold``; defaults to the notebook-level ``seed``.

    Returns
    -------
    numpy.ndarray of shape (2,)
        Mean train AUC and mean test AUC across folds.
    """
    if random_state is None:
        random_state = seed  # notebook-level default random state

    # Row 0 holds train AUCs, row 1 holds test AUCs; one column per fold.
    fold_aucs = np.empty((2, n_folds))

    skf = StratifiedKFold(n_splits=n_folds, random_state=random_state, shuffle=True)

    for fold, (train_index, test_index) in enumerate(skf.split(x_modality, y_modality)):
        x_train_norm, x_test_norm = norm_features(x_modality[train_index], x_modality[test_index])
        # NOTE(review): trial_classification is defined outside this view; it is
        # unpacked here as (train_auc, test_auc, coefs).
        train_auc, test_auc, _coefs = trial_classification(x_train_norm, x_test_norm,
                                                           y_modality[train_index], y_modality[test_index],
                                                           classifier, plot_fig=False)
        fold_aucs[0, fold] = train_auc
        fold_aucs[1, fold] = test_auc

    return np.mean(fold_aucs, axis=1)


In [None]:
def calculate_rmse(df, modality, true_val_col = 'Steering_Wheel_Degree', features_list = 'all',
                   estimator=None, test_size=0.3, random_state=None):
    """Fit a regressor on one modality's features and report its test RMSE.

    Parameters
    ----------
    df : pandas.DataFrame
        Per-trial results.
    modality : {'EEG', 'Eye', 'ECG', 'All'}
        Which feature extractor to use.
    true_val_col : str
        Regression target column.
    features_list : 'all' or list
        Feature selection forwarded to the extractor. For 'Eye', 'all' means
        the four (evoked) pupil diameter columns. Ignored for 'All'.
    estimator : regressor with ``fit``/``predict``, optional
        Defaults to a notebook-level ``regr`` if one is defined. Otherwise a
        ``RandomForestRegressor`` seeded with ``random_state`` is used.
    test_size : float
        Fraction held out for testing.
    random_state : int, optional
        Seed for the split (and the fallback forest); defaults to ``seed``.

    Returns
    -------
    y_test, y_pred, modality_rmse

    Raises
    ------
    ValueError
        If ``modality`` is not one of the supported values.
    """
    if random_state is None:
        random_state = seed  # notebook-level default random state
    if estimator is None:
        # Keep using a notebook-level `regr` when one has been defined;
        # otherwise fall back to a seeded random forest instead of NameError.
        estimator = globals().get('regr')
        if estimator is None:
            estimator = RandomForestRegressor(random_state=random_state)

    if modality == "EEG":
        x_modality, y_modality = eeg_features(df, features = features_list, label_source = true_val_col,
                                              cleaned_up = True)
    elif modality == "Eye":
        if features_list == 'all':
            features_list = ["Left Pupil Diameter", "Right Pupil Diameter",
                            "Left Evoked Pupil Diameter", "Right Evoked Pupil Diameter"]
        x_modality, y_modality = eye_features(df, features = features_list, label_source = true_val_col,
                                              cleaned_up = True)
    elif modality == "ECG":
        x_modality, y_modality = ecg_features(df, features = features_list, label_source = true_val_col,
                                              cleaned_up = True)
    elif modality == "All":
        x_modality, y_modality = multimodal_features(df, label_source = true_val_col)
    else:
        raise ValueError(f"modality must be 'EEG', 'Eye', 'ECG' or 'All', got {modality!r}")

    X_train, X_test, y_train, y_test = train_test_split(x_modality, y_modality, test_size=test_size,
                                                        random_state=random_state)
    estimator.fit(X_train, y_train)

    y_pred = estimator.predict(X_test)
    # `mean_squared_error(..., squared=False)` was deprecated in scikit-learn 1.4
    # and removed in 1.6; taking the square root works on every version.
    modality_rmse = np.sqrt(mean_squared_error(y_test, y_pred))

    return y_test, y_pred, modality_rmse

# 3. Process, Save, and Load Data

## Process file function

In [None]:
def _session_ids_from_filename(each_file):
    """Parse ('sbjNN', 'ssnNN') labels from a session file name.

    The subject id is the text between 'Sbj_' and '-Ssn'. The session number
    is the text between 'Ssn_' and '.dats'. Ids shorter than two characters
    get a single leading '0'.
    """
    sbj_id = each_file[each_file.find('Sbj_') + 4:each_file.find('-Ssn')]
    ssn_no = each_file[each_file.find('Ssn_') + 4:each_file.find('.dats')]
    sbj = "sbj" + (sbj_id if len(sbj_id) >= 2 else "0" + sbj_id)
    ssn = "ssn" + (ssn_no if len(ssn_no) >= 2 else "0" + ssn_no)
    return sbj, ssn


def _prepare_voice_events(event_df):
    """Keep voice-condition trials and add the derived per-trial columns."""
    # .copy() so the column assignments below modify an independent frame
    # rather than a slice of event_df (avoids SettingWithCopyWarning).
    voice_df = event_df[event_df.block_condition == 'voice'].copy()
    voice_df['trial_damage'] = voice_df.damage.diff().fillna(0)
    voice_df['trial_duration'] = voice_df.trial_end_time - voice_df.trial_start_time
    voice_df['spoken_difficulty'] = voice_df['spoken_difficulty'].fillna("unknown")
    voice_df['spoken_difficulty_encoded'] = voice_df.spoken_difficulty.replace(
        to_replace=['easy', 'hard', 'unknown'], value=[1, 2, 0])
    return voice_df


def _merge_into_pickle(path, key, value):
    """Store ``value`` under ``key`` in the dict pickled at ``path``.

    Existing entries for other sessions are kept, so repeated calls build
    up one dict keyed by session. The loading cells read it back that way,
    e.g. ``all_ica[sbj_ssn]``.
    NOTE: read-modify-write; not safe if several processes save concurrently.
    """
    stored = {}
    if os.path.isfile(path):
        with open(path, 'rb') as handle:
            stored = pickle.load(handle)
        if not isinstance(stored, dict):  # stale file in the old single-session format
            stored = {}
    stored[key] = value
    with open(path, 'wb') as handle:
        pickle.dump(stored, handle, protocol=pickle.HIGHEST_PROTOCOL)


def process_files(template_ica = None, each_file = None):
    """Run the ECG, eye, EEG and motor processing steps on one session pickle.

    Parameters
    ----------
    template_ica : optional
        Passed to ``process_session_eeg`` as ``template_ica`` (the notebook uses
        the first session's ICA); ``None`` while building that template.
    each_file : str
        Name of the session pickle inside ``data_dir``. It has a default only
        because a non-default parameter may not follow ``template_ica``. It is
        required, and a ValueError is raised when it is missing.

    Returns
    -------
    tuple or None
        ``(post_processed_event_df, events, epochs, ica, eog_idx)``, or ``None``
        when the session has no events or no voice-condition events.
    """
    if each_file is None:
        raise ValueError("each_file must be the name of a session file in data_dir")

    input_path = data_dir + each_file
    sbj, ssn = _session_ids_from_filename(each_file)
    ref_ica = template_ica if template_ica else None

    with open(input_path, 'rb') as handle:
        rns_data = pickle.load(handle)

    ## Add metadata to data
    for key, stream in rns_data.items():
        stream.append(return_metadata_from_name(key, metadata_jsons))

    event_df = read_event_data(rns_data) # typically only 15_1 and 22_1 will be used here, change below too
    if event_df.empty:
        return None

    event_df = _prepare_voice_events(event_df)
    if event_df.empty:  # no voice-condition trials: nothing to process
        return None

    # ecg
    post_processed_event_df = process_session_ecg(rns_data, event_df, plot_frequency=20, plot_ecg_snippet=40)

    # eye
    if 'Unity_ViveSREyeTracking' in rns_data:
        post_processed_event_df = process_session_eye(rns_data, post_processed_event_df, detect_blink=True,
                                                      pretrial_period=0, posttrial_period=0, plot_frequency=20,
                                                      plot_eye_snippet=40, classifiers=['NSLR'])

    # eeg
    post_processed_event_df, epochs, events, info, reject_log, ica, eog_idx = process_session_eeg(
        rns_data, post_processed_event_df,
        run_autoreject=True, run_ica=True, save_raw_eeg=True, sbj_session=sbj + ssn,
        template_ica=ref_ica, analyze_pre_ica=True)

    # motor
    post_processed_event_df, turns_df = process_session_motor(rns_data, post_processed_event_df,
                                                              motor_channel='Unity_MotorInput',
                                                              plot_motor_result=True, plot_motor_snippet=30,
                                                              plot_frequency=10)

    # Save per-session objects keyed by "sbjNNssnNN". Previously every session
    # overwrote the same files, whereas the loading cells expect dicts.
    if save_data_pkl:
        os.makedirs(pickle_dir, exist_ok=True)
        session_key = sbj + ssn
        for file_name, value in (('all_events.pickle', events),
                                 ('ica_epochs.pickle', epochs),
                                 ('ica.pickle', ica),
                                 ('eog_comp.pickle', eog_idx)):
            _merge_into_pickle(f'{pickle_dir}{file_name}', session_key, value)

    return post_processed_event_df, events, epochs, ica, eog_idx


# The processing cell of this notebook calls the function by its singular name.
process_file = process_files


## Process all files (optionally with multiprocessing)

In [None]:
# Process the first session on its own; its ICA becomes the template used to
# identify the components to remove in every other session.
from functools import partial
from multiprocessing import Pool

results = []
first_result = process_files(each_file=onlyfiles[0])
if first_result is None:
    raise RuntimeError(f"No voice events in template session {onlyfiles[0]!r}; cannot build the template ICA")
results.append(first_result)  # keep the template session's results too
template_ica = first_result[3]

# Multiprocessing
cpu_no = 4
multi_process_files = False

remaining_files = onlyfiles[1:]  # the template session is already processed
if multi_process_files:
    # NOTE: with save_data_pkl=True, parallel workers update the same pickle
    # files concurrently; functions defined in a notebook may also fail to
    # pickle on platforms that use the "spawn" start method.
    with Pool(cpu_no) as p:
        other_results = p.map(partial(process_files, template_ica), remaining_files)
else:
    other_results = [process_files(template_ica, each_file) for each_file in remaining_files]

# Sessions without voice events return None and are skipped.
results.extend(r for r in other_results if r is not None)

all_dfs = pd.concat([r[0] for r in results], ignore_index=True)

In [None]:
# Save the combined per-trial results as CSV and Excel.
# Create the target directory first; nothing else guarantees it exists.
os.makedirs(cvs_xlsx_dir, exist_ok=True)
all_dfs.to_csv(f"{cvs_xlsx_dir}all_results.csv")
all_dfs.to_excel(f"{cvs_xlsx_dir}all_results.xlsx")

## Processing Raw EEG Data (Optional)

In [None]:
# Load each session's saved filtered raw EEG (.fif), epoch it around the
# easy/hard events, and optionally clean the epochs with autoreject.

if epoch_raw_eeg:

    with open(f'{pickle_dir}all_events.pickle', 'rb') as handle:
        # Iterated below as {session key: events}; NOTE(review): confirm it was
        # saved in that keyed form.
        all_events = pickle.load(handle)

    raw_eeg_dir = f'{output_dir}saved_files/raw_eeg/'
    event_dict = dict(easy=1, hard=2)

    autoreject_epochs = 20       # epochs used to fit autoreject (a subset, to save time)
    min_autoreject_epochs = 10   # autoreject's cross-validation needs at least this many
    run_autoreject = True

    raw_eeg_dict = {}
    raw_epochs_dict = {}

    for sbj_ssn, sbj_events in all_events.items():

        raw_eeg_path = f'{raw_eeg_dir}{sbj_ssn}_eeg_filt_raw.fif'
        raw_eeg = mne.io.read_raw_fif(raw_eeg_path, preload=True)
        raw_eeg_dict[sbj_ssn] = raw_eeg

        epochs_raw = mne.Epochs(raw_eeg, sbj_events, event_id=event_dict, baseline=(None, 0),
                                tmin=-.2, tmax=3, preload=True, on_missing='warn')

        # Sessions with too few epochs are kept unfiltered.
        if run_autoreject and len(epochs_raw) >= min_autoreject_epochs:
            # `seed` is the notebook's random state (the undefined `rs` was used before).
            ar_raw = autoreject.AutoReject(random_state=seed, n_jobs=1, verbose=False)
            ar_raw.fit(epochs_raw[:autoreject_epochs])
            epochs_raw, _reject_log = ar_raw.transform(epochs_raw, return_log=True)

        raw_epochs_dict[sbj_ssn] = epochs_raw

    with open(f'{pickle_dir}raw_epochs.pickle', 'wb') as handle_raw_eps:
        pickle.dump(raw_epochs_dict, handle_raw_eps, protocol=pickle.HIGHEST_PROTOCOL)
    with open(f'{pickle_dir}raw_eeg.pickle', 'wb') as handle_raw_eeg:
        pickle.dump(raw_eeg_dict, handle_raw_eeg, protocol=pickle.HIGHEST_PROTOCOL)


## Load data

In [None]:
all_dfs = pd.read_csv("../output/saved_files/corrected_voice_timestamp/all_results.csv")

# Reload the per-session objects saved by the processing cells.
with open(f'{pickle_dir}ica_epochs.pickle', 'rb') as handle:
    all_proc_epochs = pickle.load(handle)
with open(f'{pickle_dir}ica.pickle', 'rb') as handle:
    all_ica = pickle.load(handle)
with open(f'{pickle_dir}eog_comp.pickle', 'rb') as handle:
    all_eog_comps = pickle.load(handle)
with open(f'{pickle_dir}raw_epochs.pickle', 'rb') as handle:
    all_raw_epochs = pickle.load(handle)

# Save a plot of (up to) the first 20 ICA components of every session.
if save_ica_plts:
    ica_comp_dir = "../output/plots/ica_comps/"
    os.makedirs(ica_comp_dir, exist_ok=True)

    max_plotted_comps = 20
    for sbj_ssn, sbj_ica in all_ica.items():
        # A fitted ICA exposes n_components_. Never request picks beyond the
        # number of components it actually has.
        n_comps = min(max_plotted_comps, getattr(sbj_ica, 'n_components_', max_plotted_comps))
        sbj_ica.plot_components(picks=list(range(n_comps)), title=sbj_ssn+"_ICA_Components", show=False)

        plt.savefig(f"{ica_comp_dir}{sbj_ssn}_ica_comps.png")
        plt.close()

In [None]:
# Plot, for every session, the ICA components that were removed (eog_idx).

show_removed_comp = False

if show_removed_comp:
    removed_comp_dir = f"{output_dir}plots/Removed_Components_Corrmap/"
    os.makedirs(removed_comp_dir, exist_ok=True)

    for sbj, sbj_ica in all_ica.items():
        removed = all_eog_comps.get(sbj)
        # len() rather than `!= []`: comparing a numpy index array to [] is
        # elementwise/ambiguous. Sessions with nothing removed are skipped.
        if removed is None or len(removed) == 0:
            continue

        sbj_ica.plot_components(picks=removed, title=sbj, show=False)

        plt.savefig(f"{removed_comp_dir}{sbj}_removed_components.png")
        plt.close()