In [1]:
# Preprocessing & results----------------
from sklearn.model_selection import train_test_split, cross_validate, cross_val_score, GridSearchCV
from sklearn.metrics import classification_report, accuracy_score, f1_score, balanced_accuracy_score
from sklearn.preprocessing import LabelEncoder, StandardScaler

# Models-------------------------
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression, RidgeClassifier, SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.gaussian_process import GaussianProcessClassifier
import sklearn.gaussian_process.kernels as kls
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier, BaggingClassifier, ExtraTreesClassifier

# General purpose
import re
import pandas as pd
import pickle
import numpy as np
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Dataset variant to evaluate. The value is interpolated into the split-file
# directory by load_dataset and into the name of the results CSV.
data_type = "mind-wandering"

In [3]:
def load_dataset(user_split = True, runidx=1, data_type='mind-wandering'):
    """Load one pre-computed train/val/test split from its pickle file.

    Parameters
    ----------
    user_split : bool
        True reads the user-based split file (test/val users are never seen
        during training); False reads the time-based split file (each user's
        chunks are randomly divided into train/val/test).
    runidx : int
        Run number interpolated into the file name (``..._RUN{runidx}.pkl``).
    data_type : str
        Sub-directory of ``../../iconip_data`` that holds the split files.

    Returns
    -------
    tuple
        Twelve items in this order: for 'train', then 'val', then 'test', the
        entries stored under 'x', 'y_med', 'y_subj' and 'y_ts'.

    Raises
    ------
    OSError
        If the split file cannot be opened.
    KeyError
        If the pickled mapping lacks one of the expected split or field keys.
    """
    if user_split:
        path_template = '../../iconip_data/{}/user_based_splits_with_timestamp_RUN{}.pkl'
    else:
        path_template = '../../iconip_data/{}/time_based_splits_with_timestamp_RUN{}.pkl'
    data_file_path = path_template.format(data_type, runidx)

    # NOTE: unpickling can execute arbitrary code; only load split files that
    # come from a trusted source.
    with open(data_file_path, 'rb') as f:
        all_data_splits = pickle.load(f)

    # Flatten split -> field into the positional order callers unpack.
    loaded = []
    for split_name in ('train', 'val', 'test'):
        split = all_data_splits[split_name]
        for field in ('x', 'y_med', 'y_subj', 'y_ts'):
            loaded.append(split[field])
    return tuple(loaded)

In [4]:
# Registry of candidate classifiers.
# Each entry maps a display name (used as the row label in the results table)
# to {"model": unfitted estimator, "params": grid passed to GridSearchCV}.
# An empty grid means the estimator is evaluated with its constructor settings.
#
# NOTE(review): `penalty='none'` (LogisticR) and `base_estimator` (Bagging*,
# AdaBoost) are the spellings used by older scikit-learn releases; newer
# releases use `penalty=None` and `estimator`. Check the installed version
# before upgrading.
clf_dict = {
    # --- Tree models -------------------------------------------------------
    'DecisionTree': {
        "model": DecisionTreeClassifier(random_state=42),
        "params": {'max_depth': list(range(2, 64, 10))},
    },
    'RandomForest': {
        "model": RandomForestClassifier(random_state=42),
        "params": {
            'n_estimators': list(range(5, 100, 20)),
            'max_depth': list(range(2, 64, 10)),
        },
    },

    # --- Linear models -----------------------------------------------------
    # Each penalty gets its own entry because the set of solvers listed
    # differs per penalty.
    'LogisticR_L1': {
        "model": LogisticRegression(random_state=42, max_iter=500),
        "params": {'penalty': ['l1'], 'solver': ['liblinear', 'saga']},
    },
    'LogisticR_L2': {
        "model": LogisticRegression(random_state=42, max_iter=1000),
        "params": {
            'penalty': ['l2'],
            'solver': ['newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga'],
        },
    },
    'LogisticR': {
        "model": LogisticRegression(random_state=42, max_iter=500),
        "params": {
            'penalty': ['none'],
            'solver': ['newton-cg', 'lbfgs', 'sag', 'saga'],
        },
    },
    'RidgeClf': {"model": RidgeClassifier(max_iter=1000), "params": {}},

    # --- Support vector machines --------------------------------------------
    # Split by kernel so that kernel-specific options (degree, gamma) are only
    # searched where they are listed.
    'SVC_linear': {
        "model": SVC(random_state=42),
        "params": {'kernel': ['linear'], 'C': [0.5, 1.5, 2.5]},
    },
    'SVC_poly': {
        "model": SVC(random_state=42),
        "params": {
            'kernel': ['poly'],
            'degree': [3, 4, 5],
            'gamma': ['scale', 'auto'],
            'C': [0.5, 1.5, 2.5],
        },
    },
    'SVC_others': {
        "model": SVC(random_state=42),
        "params": {
            'kernel': ['rbf', 'sigmoid'],
            'gamma': ['scale', 'auto'],
            'C': [0.5, 1.5, 2.5],
        },
    },

    # --- Probabilistic / instance-based --------------------------------------
    # The key keeps its historical spelling: it is the row label written to the
    # results table and CSV.
    'GussianNB': {"model": GaussianNB(), "params": {}},
    'KNN': {
        "model": KNeighborsClassifier(),
        "params": {'n_neighbors': list(range(3, 30))},
    },
    # Disabled:
    # 'GaussianProcessClf': {"model": GaussianProcessClassifier(random_state=42, kernel=kls.RBF()), "params": {}},

    # --- Ensembles -----------------------------------------------------------
    'Bagging_SVC': {
        "model": BaggingClassifier(random_state=42),
        "params": {
            'n_estimators': list(range(5, 100, 20)),
            'base_estimator': [
                SVC(kernel='linear'),
                SVC(kernel='poly', degree=3, gamma='scale'),
            ],
        },
    },
    'BaggingDT': {
        "model": BaggingClassifier(random_state=42),
        "params": {
            'n_estimators': list(range(5, 100, 20)),
            'base_estimator': [
                DecisionTreeClassifier(random_state=42, max_depth=2),
                DecisionTreeClassifier(random_state=42, max_depth=5),
                DecisionTreeClassifier(random_state=42, max_depth=10),
            ],
        },
    },
    'AdaBoost': {
        "model": AdaBoostClassifier(random_state=42),
        "params": {
            'n_estimators': list(range(10, 100, 20)),
            # Each candidate depth appears once: a repeated, identically
            # seeded tree would only re-run the same fits without changing
            # which parameters are selected.
            'base_estimator': [
                DecisionTreeClassifier(random_state=42, max_depth=2),
                DecisionTreeClassifier(random_state=42, max_depth=12),
            ],
        },
    },
    'ExtraTrees': {
        "model": ExtraTreesClassifier(random_state=42),
        "params": {
            'n_estimators': list(range(5, 105, 30)),
            'max_depth': [2, 5, 10, 15],
        },
    },

    # --- Multi-layer perceptrons ---------------------------------------------
    # One hidden layer of 50, 250 or 450 units.
    'MLP_l1': {
        "model": MLPClassifier(random_state=42),
        "params": {
            'hidden_layer_sizes': [(x,) for x in range(50, 600, 200)],
            'activation': ['logistic', 'tanh', 'relu'],
            'solver': ['adam', 'sgd'],
            'early_stopping': [True],
        },
    },
    # Two hidden layers: first of 50/250/450 units, second of 50/250 units.
    'MLP_l2': {
        "model": MLPClassifier(random_state=42),
        "params": {
            'hidden_layer_sizes': [(x, y)
                                   for x in range(50, 600, 200)
                                   for y in range(50, 360, 200)],
            'activation': ['logistic', 'tanh', 'relu'],
            'solver': ['adam', 'sgd'],
            'early_stopping': [True],
        },
    },
    # Disabled (three hidden layers):
    # 'MLP_l3': {
    #     "model": MLPClassifier(random_state=42),
    #     "params": {
    #         'hidden_layer_sizes': [(x, y, z)
    #                                for x in range(50, 600, 100)
    #                                for y in range(50, 600, 100)
    #                                for z in range(50, 360, 100)],
    #         'activation': ['logistic', 'tanh', 'relu'],
    #         'solver': ['adam', 'sgd'],
    #         'early_stopping': [True],
    #     },
    # },
}


In [8]:
def _run_columns(runidx):
    """Return the four result-column names for one run, in storage order."""
    return ['{}_Train_Acc'.format(runidx), '{}_Val_Acc'.format(runidx),
            '{}_Test_Acc'.format(runidx), '{}_best_params'.format(runidx)]


def _channel_mean_frame(X):
    """Average every sample over its last axis and return the result as a DataFrame.

    X is treated as a 3-D array (samples, rows, columns); the mean over the
    last axis leaves one value per row of each sample. Columns are labelled
    '0', '1', ... according to the number of resulting features.
    """
    features = np.mean(np.asarray(X), axis=2)
    return pd.DataFrame(features, columns=[str(c) for c in range(features.shape[1])])


def predict_med_control_vs_expert(data_type, user_split=True, n_jobs=20):
    """Grid-search every classifier in ``clf_dict`` on runs 1-5 and tabulate the scores.

    For each run the splits are loaded with ``load_dataset``, the meditation
    labels are integer-divided by 3, each sample is reduced to its per-row mean,
    and every entry of ``clf_dict`` is tuned with ``GridSearchCV`` on the
    training split. The refit best estimator is then scored on the validation
    and test splits.

    Parameters
    ----------
    data_type : str
        Dataset variant; forwarded to ``load_dataset`` and embedded in the
        results file name.
    user_split : bool, default True
        Forwarded to ``load_dataset``; also selects which results CSV is written.
    n_jobs : int, default 20
        Number of parallel jobs handed to ``GridSearchCV``.

    Returns
    -------
    pandas.DataFrame
        One row per classifier name and, for every run ``r``, the columns
        ``r_Train_Acc`` (``GridSearchCV.best_score_``, i.e. the cross-validated
        score on the training split, not a plain training accuracy),
        ``r_Val_Acc`` and ``r_Test_Acc`` (balanced accuracy) and
        ``r_best_params``.

    Side effects
    ------------
    Prints one line per run and per classifier, and rewrites the results CSV in
    the working directory after every classifier so that partial results
    survive an interrupted run.
    """
    run_ids = range(1, 6)

    # Pre-create all columns so the table keeps a stable run-by-run column order.
    model_results = pd.DataFrame()
    for runidx in run_ids:
        for column in _run_columns(runidx):
            model_results[column] = None

    if user_split:
        results_path = './RESULTS_T2{}_MED_BINARYPRED_{}.csv'.format('USER', data_type)
    else:
        results_path = './RESULTS_T2_MED_BINARYPRED_{}.csv'.format(data_type)

    for runidx in run_ids:
        print("RUN ID: ", runidx)

        # Subject ids and timestamps are returned by the loader but not used here.
        (X_train, y_train_medid, _, _,
         X_val, y_val_medid, _, _,
         X_test, y_test_medid, _, _) = load_dataset(user_split=user_split, runidx=runidx, data_type=data_type)

        # Convert data to binary classification task, i.e. meditation expert vs control group.
        # NOTE(review): this presumes label ids below 3 form one group and ids 3-5
        # the other - confirm against the label encoding of the split files.
        y_train_medid = y_train_medid // 3
        y_val_medid = y_val_medid // 3
        y_test_medid = y_test_medid // 3

        # Collapse the last axis of every sample so each one becomes a flat feature vector.
        X_train_df = _channel_mean_frame(X_train)
        X_val_df = _channel_mean_frame(X_val)
        X_test_df = _channel_mean_frame(X_test)

        for clf_name, clf in clf_dict.items():
            # NOTE(review): the default CV splitter is used, so folds are not grouped
            # by subject; consider a group-aware splitter if model selection should
            # also be subject-independent.
            classifier = GridSearchCV(clf['model'], clf['params'], n_jobs=n_jobs, verbose=0)
            classifier.fit(X_train_df, y_train_medid)
            best_model = classifier.best_estimator_

            val_predictions = best_model.predict(X_val_df)
            val_acc = accuracy_score(y_val_medid, val_predictions)
            bal_val_acc = balanced_accuracy_score(y_val_medid, val_predictions)

            test_predictions = best_model.predict(X_test_df)
            test_acc = balanced_accuracy_score(y_test_medid, test_predictions)

            print(clf_name, classifier.best_score_, classifier.best_params_, val_acc, bal_val_acc, test_acc)

            model_results.loc[clf_name, _run_columns(runidx)] = [classifier.best_score_, bal_val_acc, test_acc, classifier.best_params_]
            # Checkpoint after every classifier.
            model_results.to_csv(results_path)

    return model_results

In [9]:
# Run the full five-run evaluation for the configured dataset variant.
# This fits a grid search for every classifier in clf_dict and rewrites the
# results CSV as it goes.
model_results = predict_med_control_vs_expert(data_type=data_type)

RUN ID:  1
DecisionTree 0.9896733276043621 {'max_depth': 12} 0.5884194053208138 0.3906793589919189 0.49789688739334215
RandomForest 0.9990171990171991 {'max_depth': 12, 'n_estimators': 25} 0.7323943661971831 0.4775510204081633 0.4596602972399151




LogisticR_L1 0.9985233778337227 {'penalty': 'l1', 'solver': 'liblinear'} 0.4475743348982786 0.29183673469387755 0.43177903296879383
LogisticR_L2 0.9985233778337227 {'penalty': 'l2', 'solver': 'liblinear'} 0.5446009389671361 0.39480208190658816 0.5521371630012418




LogisticR 0.9990159886711611 {'penalty': 'none', 'solver': 'newton-cg'} 0.539906103286385 0.3940761539515135 0.6052157192645115
RidgeClf 0.9990147783251231 {} 0.3270735524256651 0.2202711957266128 0.516183952249329
SVC_linear 0.9985245881797606 {'C': 0.5, 'kernel': 'linear'} 0.568075117370892 0.4451376523763868 0.8545447261947683
SVC_poly 0.9990147783251231 {'C': 0.5, 'degree': 3, 'gamma': 'auto', 'kernel': 'poly'} 0.6369327073552425 0.41997671551842214 0.6220406201177743
SVC_others 0.9985245881797606 {'C': 0.5, 'gamma': 'auto', 'kernel': 'rbf'} 0.7652582159624414 0.5176619641145048 0.5198894363658214
GussianNB 0.736977281804868 {} 0.24100156494522693 0.2669017942747569 0.5193286063373793
KNN 1.0 {'n_neighbors': 5} 0.6494522691705791 0.5378989179564443 0.7708208148059128
Bagging_SVC 0.9985245881797606 {'base_estimator': SVC(kernel='linear'), 'n_estimators': 25} 0.568075117370892 0.4451376523763868 0.8597324039578577
BaggingDT 0.9950847847399572 {'base_estimator': DecisionTreeClassifier

























MLP_l1 0.9970491763595213 {'activation': 'tanh', 'early_stopping': True, 'hidden_layer_sizes': (450,), 'solver': 'adam'} 0.5242566510172144 0.3511779208327626 0.5718663622160798








































































MLP_l2 0.9985245881797606 {'activation': 'tanh', 'early_stopping': True, 'hidden_layer_sizes': (50, 250), 'solver': 'adam'} 0.48982785602503914 0.32172305163676207 0.49995994071225414
RUN ID:  2
DecisionTree 0.9906403940886699 {'max_depth': 12} 0.6925515055467512 0.7458452367871662 0.4924215084806929
RandomForest 0.9995073891625615 {'max_depth': 12, 'n_estimators': 5} 0.6719492868462758 0.6806257554114932 0.4638117005493404
LogisticR_L1 0.9990147783251231 {'penalty': 'l1', 'solver': 'saga'} 0.6117274167987322 0.4658004614877486 0.6848710854484943
LogisticR_L2 0.9995073891625615 {'penalty': 'l2', 'solver': 'liblinear'} 0.6307448494453248 0.48955471926161964 0.6658446609727736
LogisticR 0.9995073891625615 {'penalty': 'none', 'solver': 'sag'} 0.6259904912836767 0.4751950335128008 0.671979630297927
RidgeClf 0.9995073891625615 {} 0.6434231378763867 0.43278211185584 0.44213881871767113
SVC_linear 0.9990147783251231 {'C': 0.5, 'kernel': 'linear'} 0.5404120443740095 0.42107323371058125 0.47690





























MLP_l1 0.9985221674876847 {'activation': 'tanh', 'early_stopping': True, 'hidden_layer_sizes': (450,), 'solver': 'adam'} 0.606973058637084 0.43347571695418086 0.5588836761698545






































































MLP_l2 0.9980295566502463 {'activation': 'tanh', 'early_stopping': True, 'hidden_layer_sizes': (50, 50), 'solver': 'adam'} 0.7353407290015848 0.6392910119767059 0.4801515698303862
RUN ID:  3
DecisionTree 0.9960505990391049 {'max_depth': 12} 0.5142405063291139 0.3456959706959707 0.47492041307554933
RandomForest 1.0 {'max_depth': 12, 'n_estimators': 25} 0.6882911392405063 0.46339689722042665 0.5524303129124932
LogisticR_L1 1.0 {'penalty': 'l1', 'solver': 'liblinear'} 0.4699367088607595 0.4240196078431372 0.5882444289152885
LogisticR_L2 1.0 {'penalty': 'l2', 'solver': 'liblinear'} 0.5949367088607594 0.4962292609351433 0.4642052954421927
LogisticR 1.0 {'penalty': 'none', 'solver': 'newton-cg'} 0.5664556962025317 0.5031781943546649 0.5019993788337603
RidgeClf 0.9990147783251231 {} 0.4936708860759494 0.39237233354880413 0.2736625514403292
SVC_linear 1.0 {'C': 0.5, 'kernel': 'linear'} 0.5063291139240507 0.43094160741219567 0.5102298315086575
SVC_poly 1.0 {'C': 0.5, 'degree': 3, 'gamma': 'auto





























MLP_l1 0.9985209511646292 {'activation': 'tanh', 'early_stopping': True, 'hidden_layer_sizes': (450,), 'solver': 'adam'} 0.6455696202531646 0.45873734109028225 0.3374485596707819












































































MLP_l2 0.9995061728395062 {'activation': 'relu', 'early_stopping': True, 'hidden_layer_sizes': (50, 250), 'solver': 'adam'} 0.7420886075949367 0.6197748330101271 0.5410940290395216
RUN ID:  4
DecisionTree 0.9933174224343675 {'max_depth': 12} 0.706081081081081 0.5852872261798053 0.5361965681795107
RandomForest 0.9990453460620525 {'max_depth': 12, 'n_estimators': 25} 0.6908783783783784 0.4534368070953437 0.4234440044674586
LogisticR_L1 0.9990453460620525 {'penalty': 'l1', 'solver': 'saga'} 0.49493243243243246 0.32483370288248337 0.480759467966291
LogisticR_L2 0.9990453460620525 {'penalty': 'l2', 'solver': 'newton-cg'} 0.5844594594594594 0.3835920177383592 0.42461163569905574
LogisticR 0.9995226730310263 {'penalty': 'none', 'solver': 'newton-cg'} 0.46790540540540543 0.3070953436807095 0.24301959589806071
RidgeClf 0.9995226730310263 {} 0.5844594594594594 0.3835920177383592 0.33089653771956545
SVC_linear 0.9995226730310263 {'C': 0.5, 'kernel': 'linear'} 0.5726351351351351 0.3758314855875831























MLP_l1 0.9985680190930788 {'activation': 'relu', 'early_stopping': True, 'hidden_layer_sizes': (450,), 'solver': 'adam'} 0.49324324324324326 0.3261625072730418 0.5001522997258605






































































MLP_l2 0.9985680190930788 {'activation': 'tanh', 'early_stopping': True, 'hidden_layer_sizes': (50, 50), 'solver': 'adam'} 0.5827702702702703 0.38248337028824836 0.5679764443090669
RUN ID:  5
DecisionTree 0.9888136914463892 {'max_depth': 12} 0.6993569131832797 0.4640331266433262 0.677522306108442
RandomForest 1.0 {'max_depth': 12, 'n_estimators': 25} 0.752411575562701 0.49906497377708414 0.4746739876458476
LogisticR_L1 0.9980558902038599 {'penalty': 'l1', 'solver': 'saga'} 0.6543408360128617 0.5400444313212694 0.49152367879203845
LogisticR_L2 0.9995145631067961 {'penalty': 'l2', 'solver': 'liblinear'} 0.6543408360128617 0.5422941184741497 0.48939601921757037
LogisticR 0.9995145631067961 {'penalty': 'none', 'solver': 'newton-cg'} 0.6591639871382636 0.5522278933085868 0.4904598490048044
RidgeClf 0.9985413270970638 {} 0.7491961414790996 0.6116758763234488 0.4573781743308168
SVC_linear 0.9995145631067961 {'C': 0.5, 'kernel': 'linear'} 0.7556270096463023 0.6159221608245103 0.492587508579272























MLP_l1 0.9990279451019299 {'activation': 'relu', 'early_stopping': True, 'hidden_layer_sizes': (250,), 'solver': 'adam'} 0.7106109324758842 0.538954739106593 0.41595744680851066
































































MLP_l2 0.9995133819951338 {'activation': 'relu', 'early_stopping': True, 'hidden_layer_sizes': (450, 250), 'solver': 'adam'} 0.6736334405144695 0.4447983014861996 0.5


In [10]:
# Show the per-run score table returned by predict_med_control_vs_expert
# (rows: classifier names; columns: CV / validation / test scores and best params per run).
model_results

Unnamed: 0,1_Train_Acc,1_Val_Acc,1_Test_Acc,1_best_params,2_Train_Acc,2_Val_Acc,2_Test_Acc,2_best_params,3_Train_Acc,3_Val_Acc,3_Test_Acc,3_best_params,4_Train_Acc,4_Val_Acc,4_Test_Acc,4_best_params,5_Train_Acc,5_Val_Acc,5_Test_Acc,5_best_params
DecisionTree,0.989673,0.390679,0.497897,{'max_depth': 12},0.99064,0.745845,0.492422,{'max_depth': 12},0.996051,0.345696,0.47492,{'max_depth': 12},0.993317,0.585287,0.536197,{'max_depth': 12},0.988814,0.464033,0.677522,{'max_depth': 12}
RandomForest,0.999017,0.477551,0.45966,"{'max_depth': 12, 'n_estimators': 25}",0.999507,0.680626,0.463812,"{'max_depth': 12, 'n_estimators': 5}",1.0,0.463397,0.55243,"{'max_depth': 12, 'n_estimators': 25}",0.999045,0.453437,0.423444,"{'max_depth': 12, 'n_estimators': 25}",1.0,0.499065,0.474674,"{'max_depth': 12, 'n_estimators': 25}"
LogisticR_L1,0.998523,0.291837,0.431779,"{'penalty': 'l1', 'solver': 'liblinear'}",0.999015,0.4658,0.684871,"{'penalty': 'l1', 'solver': 'saga'}",1.0,0.42402,0.588244,"{'penalty': 'l1', 'solver': 'liblinear'}",0.999045,0.324834,0.480759,"{'penalty': 'l1', 'solver': 'saga'}",0.998056,0.540044,0.491524,"{'penalty': 'l1', 'solver': 'saga'}"
LogisticR_L2,0.998523,0.394802,0.552137,"{'penalty': 'l2', 'solver': 'liblinear'}",0.999507,0.489555,0.665845,"{'penalty': 'l2', 'solver': 'liblinear'}",1.0,0.496229,0.464205,"{'penalty': 'l2', 'solver': 'liblinear'}",0.999045,0.383592,0.424612,"{'penalty': 'l2', 'solver': 'newton-cg'}",0.999515,0.542294,0.489396,"{'penalty': 'l2', 'solver': 'liblinear'}"
LogisticR,0.999016,0.394076,0.605216,"{'penalty': 'none', 'solver': 'newton-cg'}",0.999507,0.475195,0.67198,"{'penalty': 'none', 'solver': 'sag'}",1.0,0.503178,0.501999,"{'penalty': 'none', 'solver': 'newton-cg'}",0.999523,0.307095,0.24302,"{'penalty': 'none', 'solver': 'newton-cg'}",0.999515,0.552228,0.49046,"{'penalty': 'none', 'solver': 'newton-cg'}"
RidgeClf,0.999015,0.220271,0.516184,{},0.999507,0.432782,0.442139,{},0.999015,0.392372,0.273663,{},0.999523,0.383592,0.330897,{},0.998541,0.611676,0.457378,{}
SVC_linear,0.998525,0.445138,0.854545,"{'C': 0.5, 'kernel': 'linear'}",0.999015,0.421073,0.476904,"{'C': 0.5, 'kernel': 'linear'}",1.0,0.430942,0.51023,"{'C': 0.5, 'kernel': 'linear'}",0.999523,0.375831,0.28145,"{'C': 0.5, 'kernel': 'linear'}",0.999515,0.615922,0.492588,"{'C': 0.5, 'kernel': 'linear'}"
SVC_poly,0.999015,0.419977,0.622041,"{'C': 0.5, 'degree': 3, 'gamma': 'auto', 'kern...",0.999507,0.599062,0.661554,"{'C': 1.5, 'degree': 3, 'gamma': 'auto', 'kern...",1.0,0.464205,0.451646,"{'C': 0.5, 'degree': 3, 'gamma': 'auto', 'kern...",0.998568,0.52454,0.332623,"{'C': 0.5, 'degree': 4, 'gamma': 'auto', 'kern...",0.999513,0.677191,0.593583,"{'C': 0.5, 'degree': 3, 'gamma': 'auto', 'kern..."
SVC_others,0.998525,0.517662,0.519889,"{'C': 0.5, 'gamma': 'auto', 'kernel': 'rbf'}",0.998522,0.710526,0.45092,"{'C': 0.5, 'gamma': 'auto', 'kernel': 'rbf'}",1.0,0.436786,0.440387,"{'C': 0.5, 'gamma': 'scale', 'kernel': 'rbf'}",0.999045,0.427938,0.473703,"{'C': 0.5, 'gamma': 'scale', 'kernel': 'rbf'}",0.999027,0.461783,0.531332,"{'C': 0.5, 'gamma': 'auto', 'kernel': 'rbf'}"
GussianNB,0.736977,0.266902,0.519329,{},0.737931,0.731932,0.188139,{},0.794167,0.34144,0.585624,{},0.763246,0.416592,0.398924,{},0.817606,0.717032,0.393651,{}
