In [1]:
import numpy as np
import pickle

# mne imports
import mne
from mne import io
from mne.datasets import sample

# PyRiemann imports
from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.utils.viz import plot_confusion_matrix
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression

# tools for plotting confusion matrices
from matplotlib import pyplot as plt

In [2]:
# Preprocessing & results----------------
from sklearn.model_selection import train_test_split, cross_validate, cross_val_score, GridSearchCV
from sklearn.metrics import classification_report, accuracy_score, f1_score
from sklearn.preprocessing import LabelEncoder, StandardScaler

# Models-------------------------
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression, RidgeClassifier, SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.gaussian_process import GaussianProcessClassifier
import sklearn.gaussian_process.kernels as kls
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier, BaggingClassifier, ExtraTreesClassifier

# General purpose
import re
import pandas as pd
import pickle
import numpy as np
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

In [3]:
clf_dict = {
    'DecisionTree': {"model": DecisionTreeClassifier(random_state=42), "params": {'max_depth': list(range(2, 16, 3))}},
    'RandomForest': {"model": RandomForestClassifier(random_state=42),
                     "params": {'n_estimators': list(range(5, 100, 5)), 'max_depth': list(range(2, 16, 3))}},
    'LogisticR_L1': {"model": LogisticRegression(random_state=42, max_iter=1000),
                     "params": {'penalty': ['l1'], 'solver': ['liblinear', 'saga']}},
    'LogisticR_L2': {"model": LogisticRegression(random_state=42, max_iter=1000),
                     "params": {'penalty': ['l2'], 'solver': ['newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga']}},
    'LogisticR': {"model": LogisticRegression(random_state=42, max_iter=1000),
                  "params": {'penalty': ['none'], 'solver': ['newton-cg', 'lbfgs', 'sag', 'saga']}},
    'RidgeClf': {"model": RidgeClassifier(max_iter=1000), "params": {}},
    'SVC_linear': {"model": SVC(random_state=42), "params": {'kernel': ['linear'], 
                                                             'C': [0.5, 1.0, 1.5, 2.0, 2.5]}},
    'SVC_poly': {"model": SVC(random_state=42),
                 "params": {'kernel': ['poly'], 'degree': [3, 4, 5], 'gamma': ['scale', 'auto'], 
                            'C': [0.5, 1.0, 1.5, 2.0, 2.5]}},
    'SVC_others': {"model": SVC(random_state=42), "params": {'kernel': ['rbf', 'sigmoid'], 
                                                             'gamma': ['scale', 'auto'], 
                                                             'C': [0.5, 1.0, 1.5, 2.0, 2.5]}},
    'GussianNB': {"model": GaussianNB(), "params": {}},
    'KNN': {"model": KNeighborsClassifier(), "params": {'n_neighbors': list(range(3, 30))}},
    # 'GaussianProcessClf': {"model": GaussianProcessClassifier(random_state=42, kernel=kls.RBF()), "params": {}},
    'Bagging_SVC': {"model": BaggingClassifier(random_state=42), "params": {'n_estimators': list(range(5, 100, 20)),
                                                                            'base_estimator': [SVC(kernel='linear'),
                                                                                               SVC(kernel='poly',
                                                                                                   degree=3,
                                                                                                   gamma='scale')]}},
    'BaggingDT': {"model": BaggingClassifier(random_state=42), "params": {'n_estimators': list(range(5, 100, 20)),
                                                                          'base_estimator': [
                                                                              DecisionTreeClassifier(random_state=42,
                                                                                                     max_depth=2),
                                                                              DecisionTreeClassifier(random_state=42,
                                                                                                     max_depth=5),
                                                                              DecisionTreeClassifier(random_state=42,
                                                                                                     max_depth=10)]}},
    'AdaBoost': {"model": AdaBoostClassifier(random_state=42), "params": {'n_estimators': list(range(5, 100, 20)),
                                                                          'base_estimator': [DecisionTreeClassifier(
                                                                                                 random_state=42,
                                                                                                 max_depth=2),
                                                                                             DecisionTreeClassifier(
                                                                                                 random_state=42,
                                                                                                 max_depth=5),
                                                                                             DecisionTreeClassifier(
                                                                                                 random_state=42,
                                                                                                 max_depth=10)]}},
    'ExtraTrees': {"model": ExtraTreesClassifier(random_state=42), "params": {'n_estimators': list(range(5, 105, 30)), 
                                                                              'max_depth': [2,5,10,15]}},
    'MLP_l1': {"model": MLPClassifier(random_state=42), "params": {'hidden_layer_sizes': [(x,) for x in 
                                                                                          range(50, 600, 200)], 
                                                                  'activation': ['logistic', 'tanh', 'relu'],
                                                                  'solver': ['adam', 'sgd'], 'early_stopping': 
                                                                   [True]}},
    'MLP_l2': {"model": MLPClassifier(random_state=42), "params": {'hidden_layer_sizes': [(x, y) for x in 
                                                                                          range(50, 600, 200) 
                                                                                          for y in range(50, 360, 200)], 
                                                                  'activation': ['logistic', 'tanh', 'relu'],
                                                                  'solver': ['adam', 'sgd'], 'early_stopping': 
                                                                                               [True]}},
    'MLP_l5': {"model": MLPClassifier(random_state=42), "params": {'hidden_layer_sizes': [(x, y, z, a) for x in range(300, 600, 200) 
                                                                                          for y in range(300, 600, 200)
                                                                                          for z in range(100, 300, 100)
                                                                                          for a in [20, 40]], 
                                                                  'activation': ['logistic', 'tanh', 'relu'],
                                                                  'solver': ['adam', 'sgd'], 'early_stopping': 
                                                                                               [True]}},
    }


In [4]:
X_train_df = pd.read_csv('./data/train_resnet18emb.csv', index_col=False)
X_test_df = pd.read_csv('./data/test_resnet18emb.csv', index_col=False)
X_train_df.set_index('Unnamed: 0', inplace=True)
X_train_df.reset_index(drop=True, inplace=True)
X_test_df.set_index('Unnamed: 0', inplace=True)
X_test_df.reset_index(drop=True, inplace=True)

In [5]:
X_train_df.shape, X_test_df.shape

((2057, 513), (615, 513))

In [6]:
X_train_df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,503,504,505,506,507,508,509,510,511,y
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2


In [7]:
X_test_df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,503,504,505,506,507,508,509,510,511,y
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0


In [8]:
y_train_medid = X_train_df['y']
y_test_medid = X_test_df['y']
y_train_medid.head()

0    3
1    1
2    0
3    0
4    2
Name: y, dtype: int64

In [9]:
y_train_medid.unique(), y_test_medid.unique()

(array([3, 1, 0, 2]), array([1, 3, 2, 0]))

In [10]:
X_train_df = X_train_df.drop('y', 1)
X_test_df = X_test_df.drop('y', 1)

In [12]:
# X_train, y_train_medid, y_train_subjid, X_val, y_val_medid, y_val_subjid, X_test, y_test_medid, y_test_subjid
model_results = pd.DataFrame()
model_results['Train_Accuracy'] = None
model_results['Test_Accuracy'] = None
model_results['best_params'] = None

all_models = {}

best_clf_ours = None
best_clf_val = 0

for clf_name, clf in clf_dict.items():
    classifier = GridSearchCV(clf['model'], clf['params'], n_jobs=10)
    classifier.fit(X_train_df, y_train_medid)
    best_model = classifier.best_estimator_

    y_predicted = best_model.predict(X_test_df)
    test_acc = accuracy_score(y_test_medid, y_predicted)
    
    if test_acc > best_clf_val:
        best_clf_val = test_acc
        best_clf_ours = best_model
        
    
    print(clf_name, classifier.best_score_, classifier.best_params_, test_acc)
    all_models[clf_name] = best_model

    model_results.loc[clf_name, ['Train_Accuracy', 'Test_Accuracy', 'best_params']] = [classifier.best_score_, test_acc, classifier.best_params_]
    clsr = classification_report(y_test_medid, y_predicted)

print("================================================================================")
print(best_clf_ours)
best_y_hat = best_clf_ours.predict(X_test_df)
clsr = classification_report(y_test_medid, best_y_hat)
print(clsr)
test_acc = accuracy_score(y_test_medid, best_y_hat)
print("Test acc:", test_acc )
print("Weighted F1 score: ", f1_score(y_test_medid, best_y_hat, average='weighted'))

DecisionTree 0.26932889235348306 {'max_depth': 2} 0.23902439024390243
RandomForest 0.28683414830038034 {'max_depth': 11, 'n_estimators': 45} 0.24715447154471545




LogisticR_L1 0.2654359483145537 {'penalty': 'l1', 'solver': 'liblinear'} 0.24065040650406505
LogisticR_L2 0.2761356388632981 {'penalty': 'l2', 'solver': 'lbfgs'} 0.25365853658536586


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

LogisticR 0.2702938605815794 {'penalty': 'none', 'solver': 'sag'} 0.25040650406504067
RidgeClf 0.2727269506059103 {} 0.25365853658536586
SVC_linear 0.265433586091229 {'C': 0.5, 'kernel': 'linear'} 0.25203252032520324
SVC_poly 0.263979637634942 {'C': 2.0, 'degree': 5, 'gamma': 'scale', 'kernel': 'poly'} 0.25203252032520324
SVC_others 0.2649493303096875 {'C': 2.0, 'gamma': 'auto', 'kernel': 'rbf'} 0.24878048780487805
GussianNB 0.2668946212174899 {} 0.25040650406504067
KNN 0.2615595398388964 {'n_neighbors': 3} 0.27479674796747966
Bagging_SVC 0.2649469680863629 {'base_estimator': SVC(kernel='linear'), 'n_estimators': 25} 0.24878048780487805
BaggingDT 0.2795289726690761 {'base_estimator': DecisionTreeClassifier(max_depth=2, random_state=42), 'n_estimators': 45} 0.24715447154471545
AdaBoost 0.26738832589232986 {'base_estimator': DecisionTreeClassifier(max_depth=10, random_state=42), 'n_estimators': 5} 0.25203252032520324
ExtraTrees 0.2717608012661517 {'max_depth': 10, 'n_estimators': 5} 0.23



MLP_l1 0.2688434554602792 {'activation': 'tanh', 'early_stopping': True, 'hidden_layer_sizes': (450,), 'solver': 'adam'} 0.25040650406504067




MLP_l2 0.2693300734651454 {'activation': 'relu', 'early_stopping': True, 'hidden_layer_sizes': (50, 50), 'solver': 'adam'} 0.25203252032520324




MLP_l5 0.272243875936031 {'activation': 'relu', 'early_stopping': True, 'hidden_layer_sizes': (300, 300, 200, 20), 'solver': 'adam'} 0.25203252032520324
KNeighborsClassifier(n_neighbors=3)
              precision    recall  f1-score   support

           0       0.28      0.49      0.35       153
           1       0.25      0.21      0.23       156
           2       0.28      0.12      0.17       156
           3       0.30      0.28      0.29       150

    accuracy                           0.27       615
   macro avg       0.27      0.28      0.26       615
weighted avg       0.27      0.27      0.26       615

Test acc: 0.27479674796747966
Weighted F1 score:  0.25895303906174266


In [13]:
model_results

Unnamed: 0,Train_Accuracy,Test_Accuracy,best_params
DecisionTree,0.269329,0.239024,{'max_depth': 2}
RandomForest,0.286834,0.247154,"{'max_depth': 11, 'n_estimators': 45}"
LogisticR_L1,0.265436,0.24065,"{'penalty': 'l1', 'solver': 'liblinear'}"
LogisticR_L2,0.276136,0.253659,"{'penalty': 'l2', 'solver': 'lbfgs'}"
LogisticR,0.270294,0.250407,"{'penalty': 'none', 'solver': 'sag'}"
RidgeClf,0.272727,0.253659,{}
SVC_linear,0.265434,0.252033,"{'C': 0.5, 'kernel': 'linear'}"
SVC_poly,0.26398,0.252033,"{'C': 2.0, 'degree': 5, 'gamma': 'scale', 'ker..."
SVC_others,0.264949,0.24878,"{'C': 2.0, 'gamma': 'auto', 'kernel': 'rbf'}"
GussianNB,0.266895,0.250407,{}


In [34]:
model_results

Unnamed: 0,Train_Accuracy,Test_Accuracy,best_params
DecisionTree,1.0,1.0,{'max_depth': 5}
RandomForest,0.956255,0.977236,"{'max_depth': 14, 'n_estimators': 95}"
LogisticR_L1,1.0,1.0,"{'penalty': 'l1', 'solver': 'saga'}"
LogisticR_L2,1.0,1.0,"{'penalty': 'l2', 'solver': 'newton-cg'}"
LogisticR,1.0,1.0,"{'penalty': 'none', 'solver': 'newton-cg'}"
RidgeClf,0.510936,0.495935,{}
SVC_linear,1.0,1.0,"{'C': 0.5, 'kernel': 'linear'}"
SVC_poly,1.0,1.0,"{'C': 0.5, 'degree': 3, 'gamma': 'scale', 'ker..."
SVC_others,1.0,1.0,"{'C': 1.5, 'gamma': 'auto', 'kernel': 'rbf'}"
GussianNB,1.0,1.0,{}


In [None]:
print("Done!!!")

Done!!!
