
 Sample script using EEGNet to classify Event-Related Potential (ERP) EEG data
 from a four-class classification task, using the sample dataset provided in
 the MNE [1, 2] package:
     https://martinos.org/mne/stable/manual/sample_dataset.html#ch-sample-data
   
 The four classes used from this dataset are:
     LA: Left-ear auditory stimulation
     RA: Right-ear auditory stimulation
     LV: Left visual field stimulation
     RV: Right visual field stimulation

 The code to process, filter and epoch the data was originally taken from
 Alexandre Barachant's PyRiemann [3] package, released under the BSD 3-clause
 license. A copy of the BSD 3-clause license is provided together with this
 software to comply with software licensing requirements.
 
 When you first run this script, MNE will download the dataset and prompt you
 to confirm the download location (defaults to ~/mne_data). Follow the prompts
 to continue. The dataset download is approximately 1.5 GB.
 
 For comparison, you can also evaluate Riemannian geometric approaches with
 xDAWN spatial filtering [4-8] against EEGNet using PyRiemann (code provided
 below).

 [1] A. Gramfort, M. Luessi, E. Larson, D. Engemann, D. Strohmeier, C. Brodbeck,
     L. Parkkonen, M. Hämäläinen, MNE software for processing MEG and EEG data, 
     NeuroImage, Volume 86, 1 February 2014, Pages 446-460, ISSN 1053-8119.

 [2] A. Gramfort, M. Luessi, E. Larson, D. Engemann, D. Strohmeier, C. Brodbeck, 
     R. Goj, M. Jas, T. Brooks, L. Parkkonen, M. Hämäläinen, MEG and EEG data 
     analysis with MNE-Python, Frontiers in Neuroscience, Volume 7, 2013.

 [3] https://github.com/alexandrebarachant/pyRiemann

 [4] A. Barachant, M. Congedo, "A Plug&Play P300 BCI Using Information Geometry",
     arXiv:1409.0107.

 [5] M. Congedo, A. Barachant, A. Andreev, "A New generation of Brain-Computer 
     Interface Based on Riemannian Geometry", arXiv:1310.8115.

 [6] A. Barachant and S. Bonnet, "Channel selection procedure using riemannian 
     distance for BCI applications", in 2011 5th International IEEE/EMBS 
     Conference on Neural Engineering (NER), 2011, pp. 348-351.

 [7] A. Barachant, S. Bonnet, M. Congedo and C. Jutten, "Multiclass 
     Brain-Computer Interface Classification by Riemannian Geometry", IEEE 
     Transactions on Biomedical Engineering, vol. 59, no. 4, pp. 920-928, 2012.

 [8] A. Barachant, S. Bonnet, M. Congedo and C. Jutten, "Classification of 
     covariance matrices using a Riemannian-based kernel for BCI applications", 
     Neurocomputing, vol. 112, pp. 172-178, 2013.


 Portions of this project are works of the United States Government and are not
 subject to domestic copyright protection under 17 USC Sec. 105.  Those 
 portions are released world-wide under the terms of the Creative Commons Zero 
 1.0 (CC0) license.  
 
 Other portions of this project are subject to domestic copyright protection 
 under 17 USC Sec. 105.  Those portions are licensed under the Apache 2.0 
 license.  The complete text of the license governing this material is in 
 the file labeled LICENSE.TXT that is a part of this project's official 
 distribution. 


In [1]:
import numpy as np
import pickle

# mne imports
import mne
from mne import io
from mne.datasets import sample

# EEGNet-specific imports
from EEGModels import EEGNet
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping

# PyRiemann imports
from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.utils.viz import plot_confusion_matrix
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression

# tools for plotting confusion matrices
from matplotlib import pyplot as plt

In [2]:
# While the default TensorFlow ordering is 'channels_last', we set it here
# explicitly in case the user has changed the default ordering.
K.set_image_data_format('channels_last')
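With 'channels_last', Keras convolution layers expect the feature-map axis to come last. This sketch assumes the EEGNet model imported from EEGModels takes single trials shaped (channels, samples, 1) under this setting, as the public arl-eegmodels implementation does. Under that assumption, the (chunks, channels, samples) arrays loaded later would need a trailing singleton axis before being fed to it. A minimal NumPy sketch on a small hypothetical array:

```python
import numpy as np


def add_kernel_axis(X):
    """Append a trailing singleton axis: (trials, chans, samples) -> (trials, chans, samples, 1).

    Assumes the 'channels_last' data format set above.
    """
    return X[..., np.newaxis]


# Hypothetical toy batch: 5 trials, 4 channels, 10 time samples.
X_toy = np.zeros((5, 4, 10), dtype=np.float32)
X_4d = add_kernel_axis(X_toy)
print(X_4d.shape)  # (5, 4, 10, 1)
```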

In [5]:
# Preprocessing & results----------------
from sklearn.model_selection import train_test_split, cross_validate, cross_val_score, GridSearchCV
from sklearn.metrics import classification_report, accuracy_score, f1_score
from sklearn.preprocessing import LabelEncoder, StandardScaler

# Models-------------------------
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression, RidgeClassifier, SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.gaussian_process import GaussianProcessClassifier
import sklearn.gaussian_process.kernels as kls
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier, BaggingClassifier, ExtraTreesClassifier

# General purpose
import re
import pandas as pd
import pickle
import numpy as np
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

In [8]:
def load_dataset(user_split=True):
    """Load pre-computed train/val/test splits of the Meditation dataset.

    user_split: if True, load the user-based splits, in which validation and test
    users are never seen during training. If False, load the time-based splits,
    in which each user's chunks are randomly divided among train/val/test.
    """
    if user_split:
        data_file_path = '../../data/Meditation/user_based_splits.pkl'
    else:
        data_file_path = '../../data/Meditation/time_based_splits.pkl'

    with open(data_file_path, 'rb') as f:
        all_data_splits = pickle.load(f)

    X_train = all_data_splits['train']['x']
    y_train_medid = all_data_splits['train']['y_med']
    y_train_subjid = all_data_splits['train']['y_subj']

    X_val = all_data_splits['val']['x']
    y_val_medid = all_data_splits['val']['y_med']
    y_val_subjid = all_data_splits['val']['y_subj']

    X_test = all_data_splits['test']['x']
    y_test_medid = all_data_splits['test']['y_med']
    y_test_subjid = all_data_splits['test']['y_subj']

    return X_train, y_train_medid, y_train_subjid, X_val, y_val_medid, y_val_subjid, X_test, y_test_medid, y_test_subjid
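load_dataset assumes each pickle holds a dict with the keys 'train', 'val' and 'test'. Each of those maps to 'x', 'y_med' and 'y_subj'. The sketch below writes a tiny dummy file with that layout to a temporary directory and reads it back the same way. The helper write_dummy_splits, the shapes and the label values are hypothetical placeholders, not the real Meditation data.

```python
import os
import pickle
import tempfile

import numpy as np


def write_dummy_splits(path, n_chans=4, n_samples=8, seed=0):
    """Hypothetical helper: write a split file with the layout load_dataset expects."""
    rng = np.random.default_rng(seed)
    splits = {}
    for name, n in [("train", 6), ("val", 2), ("test", 2)]:
        splits[name] = {
            "x": rng.standard_normal((n, n_chans, n_samples)),  # (chunks, channels, time)
            "y_med": rng.integers(0, 3, size=n),   # placeholder meditation labels
            "y_subj": rng.integers(0, 2, size=n),  # placeholder subject IDs
        }
    with open(path, "wb") as f:
        pickle.dump(splits, f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "dummy_splits.pkl")
    write_dummy_splits(path)
    with open(path, "rb") as f:
        loaded = pickle.load(f)

shapes = {name: loaded[name]["x"].shape for name in ("train", "val", "test")}
print(shapes)  # {'train': (6, 4, 8), 'val': (2, 4, 8), 'test': (2, 4, 8)}
```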

In [6]:
clf_dict = {
    'DecisionTree': {"model": DecisionTreeClassifier(random_state=42), "params": {'max_depth': list(range(2, 16, 3))}},
    'RandomForest': {"model": RandomForestClassifier(random_state=42),
                     "params": {'n_estimators': list(range(5, 100, 5)), 'max_depth': list(range(2, 16))}},
    'LogisticR_L1': {"model": LogisticRegression(random_state=42, max_iter=500),
                     "params": {'penalty': ['l1'], 'solver': ['liblinear', 'saga']}},
    'LogisticR_L2': {"model": LogisticRegression(random_state=42, max_iter=1000),
                     "params": {'penalty': ['l2'], 'solver': ['newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga']}},
    'LogisticR': {"model": LogisticRegression(random_state=42, max_iter=500),
                  "params": {'penalty': ['none'], 'solver': ['newton-cg', 'lbfgs', 'sag', 'saga']}},
    'RidgeClf': {"model": RidgeClassifier(max_iter=1000), "params": {}},
    'SVC_linear': {"model": SVC(random_state=42), "params": {'kernel': ['linear'], 
                                                             'C': [0.5, 1.0, 1.5, 2.0, 2.5]}},
    'SVC_poly': {"model": SVC(random_state=42),
                 "params": {'kernel': ['poly'], 'degree': [3, 4, 5], 'gamma': ['scale', 'auto'], 
                            'C': [0.5, 1.0, 1.5, 2.0, 2.5]}},
    'SVC_others': {"model": SVC(random_state=42), "params": {'kernel': ['rbf', 'sigmoid'], 
                                                             'gamma': ['scale', 'auto'], 
                                                             'C': [0.5, 1.0, 1.5, 2.0, 2.5]}},
    'GussianNB': {"model": GaussianNB(), "params": {}},
    'KNN': {"model": KNeighborsClassifier(), "params": {'n_neighbors': list(range(3, 30))}},
    'GaussianProcessClf': {"model": GaussianProcessClassifier(random_state=42, kernel=kls.RBF()), "params": {}},
    'Bagging_SVC': {"model": BaggingClassifier(random_state=42), "params": {'n_estimators': list(range(5, 100, 20)),
                                                                            'base_estimator': [SVC(kernel='linear'),
                                                                                               SVC(kernel='poly',
                                                                                                   degree=3,
                                                                                                   gamma='scale')]}},
    'BaggingDT': {"model": BaggingClassifier(random_state=42), "params": {'n_estimators': list(range(5, 100, 20)),
                                                                          'base_estimator': [
                                                                              DecisionTreeClassifier(random_state=42,
                                                                                                     max_depth=2),
                                                                              DecisionTreeClassifier(random_state=42,
                                                                                                     max_depth=5),
                                                                              DecisionTreeClassifier(random_state=42,
                                                                                                     max_depth=10)]}},
    'AdaBoost': {"model": AdaBoostClassifier(random_state=42), "params": {'n_estimators': list(range(5, 100, 20)),
                                                                          'base_estimator': [DecisionTreeClassifier(
                                                                                                 random_state=42,
                                                                                                 max_depth=2),
                                                                                             DecisionTreeClassifier(
                                                                                                 random_state=42,
                                                                                                 max_depth=5),
                                                                                             DecisionTreeClassifier(
                                                                                                 random_state=42,
                                                                                                 max_depth=10)]}},
    'ExtraTrees': {"model": ExtraTreesClassifier(random_state=42), "params": {'n_estimators': list(range(5, 105, 30)), 
                                                                              'max_depth': [2,5,10,15]}},
    'MLP_l1': {"model": MLPClassifier(random_state=42), "params": {'hidden_layer_sizes': [(x,) for x in 
                                                                                          range(50, 600, 200)], 
                                                                  'activation': ['logistic', 'tanh', 'relu'],
                                                                  'solver': ['adam', 'sgd'], 'early_stopping': 
                                                                   [True]}},
    'MLP_l2': {"model": MLPClassifier(random_state=42), "params": {'hidden_layer_sizes': [(x, y) for x in 
                                                                                          range(50, 600, 200) 
                                                                                          for y in range(50, 360, 200)], 
                                                                  'activation': ['logistic', 'tanh', 'relu'],
                                                                  'solver': ['adam', 'sgd'], 'early_stopping': 
                                                                                               [True]}},
#     'MLP_l3': {"model": MLPClassifier(random_state=42), "params": {'hidden_layer_sizes': [(x, y, z) for x in 
#                                                                                           range(50, 600, 100) 
#                                                                                           for y in range(50, 600, 100)
#                                                                                           for z in range(50, 360, 100)], 
#                                                                   'activation': ['logistic', 'tanh', 'relu'],
#                                                                   'solver': ['adam', 'sgd'], 'early_stopping': 
#                                                                                                [True]}},
    }
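The grid above uses parameter names from older scikit-learn releases. BaggingClassifier and AdaBoostClassifier renamed base_estimator to estimator in version 1.2, and the old name was removed in 1.4. LogisticRegression now expects penalty=None rather than the string 'none'. A small helper like the one below could translate the grids before running GridSearchCV on a newer release. It is a hypothetical sketch, not part of the original notebook, and it is shown on a toy dict with placeholder strings instead of model objects.

```python
def modernize_grid(params):
    """Return a copy of a GridSearchCV param grid using scikit-learn >= 1.2 names.

    - 'base_estimator' -> 'estimator' (BaggingClassifier, AdaBoostClassifier)
    - penalty 'none' -> None (LogisticRegression)
    """
    new = {}
    for key, values in params.items():
        key = "estimator" if key == "base_estimator" else key
        if key == "penalty":
            values = [None if v == "none" else v for v in values]
        new[key] = values
    return new


def modernize_clf_dict(clf_dict):
    """Apply modernize_grid to every entry, leaving the model objects untouched."""
    return {
        name: {"model": entry["model"], "params": modernize_grid(entry["params"])}
        for name, entry in clf_dict.items()
    }


# Toy dict shaped like clf_dict; models replaced by placeholder strings.
toy = {
    "BaggingDT": {"model": "bagging", "params": {"n_estimators": [5, 25], "base_estimator": ["tree"]}},
    "LogisticR": {"model": "logreg", "params": {"penalty": ["none"], "solver": ["lbfgs"]}},
}
modern = modernize_clf_dict(toy)
print(modern)
```

In the notebook this would be applied once, for example `clf_dict = modernize_clf_dict(clf_dict)`, only when running on scikit-learn 1.2 or later.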


## 1. User Split

In [26]:
user_split   = True
med_tech_clf = True

X_train, y_train_medid, y_train_subjid, X_val, y_val_medid, y_val_subjid, X_test, y_test_medid, y_test_subjid = load_dataset(user_split=user_split)
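The comment in load_dataset distinguishes two splitting schemes. In the user split, validation and test users never appear in training. The time split randomly divides each user's chunks among train/val/test. The split pickles are loaded as-is here, so the NumPy sketch below only illustrates the difference on hypothetical subject IDs. The helper names and the 60/20/20 proportions are assumptions.

```python
import numpy as np


def split_by_user(subj_ids, frac_train=0.6, frac_val=0.2, seed=0):
    """Assign whole subjects to train/val/test so no subject is shared between sets."""
    rng = np.random.default_rng(seed)
    subjects = rng.permutation(np.unique(subj_ids))
    n_tr = int(round(frac_train * len(subjects)))
    n_va = int(round(frac_val * len(subjects)))
    parts = (subjects[:n_tr], subjects[n_tr:n_tr + n_va], subjects[n_tr + n_va:])
    return tuple(np.flatnonzero(np.isin(subj_ids, p)) for p in parts)


def split_by_time(subj_ids, frac_train=0.6, frac_val=0.2, seed=0):
    """Shuffle each subject's chunks and divide them, so every subject appears in every set."""
    rng = np.random.default_rng(seed)
    train, val, test = [], [], []
    for s in np.unique(subj_ids):
        idx = rng.permutation(np.flatnonzero(subj_ids == s))
        n_tr = int(round(frac_train * len(idx)))
        n_va = int(round(frac_val * len(idx)))
        train.extend(idx[:n_tr])
        val.extend(idx[n_tr:n_tr + n_va])
        test.extend(idx[n_tr + n_va:])
    return np.array(train), np.array(val), np.array(test)


# Hypothetical subject IDs: 5 subjects with 10 chunks each.
subj = np.repeat(np.arange(5), 10)
tr_u, va_u, te_u = split_by_user(subj)
tr_t, va_t, te_t = split_by_time(subj)

print(set(subj[tr_u].tolist()) & set(subj[te_u].tolist()))  # set()
print(set(subj[tr_t].tolist()) & set(subj[te_t].tolist()))  # {0, 1, 2, 3, 4}
```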

In [29]:
# Alternative features (disabled): flatten each (64 x 2560) chunk into one row

# X_train_list = []
# X_val_list = []
# X_test_list = []

# for i in range(X_train.shape[0]):
#     X_train_list.append(X_train[i,:,:].flatten())
# X_train_list = np.array(X_train_list)
# print(X_train.shape, X_train_list.shape)

# for i in range(X_val.shape[0]):
#     X_val_list.append(X_val[i,:,:].flatten())
# X_val_list = np.array(X_val_list)
# print(X_val.shape, X_val_list.shape)

# for i in range(X_test.shape[0]):
#     X_test_list.append(X_test[i,:,:].flatten())
# X_test_list = np.array(X_test_list)
# print(X_test.shape, X_test_list.shape)

In [30]:
# Features: mean of each channel over time, (n_chunks, 64, 2560) -> (n_chunks, 64)
X_train_list = []
X_val_list = []
X_test_list = []

for i in range(X_train.shape[0]):
    X_train_list.append(np.mean(X_train[i,:,:], axis=1))
X_train_list = np.array(X_train_list)
print(X_train.shape, X_train_list.shape)

for i in range(X_val.shape[0]):
    X_val_list.append(np.mean(X_val[i,:,:], axis=1))
X_val_list = np.array(X_val_list)
print(X_val.shape, X_val_list.shape)

for i in range(X_test.shape[0]):
    X_test_list.append(np.mean(X_test[i,:,:], axis=1))
X_test_list = np.array(X_test_list)
print(X_test.shape, X_test_list.shape)

(2057, 64, 2560) (2057, 64)
(631, 64, 2560) (631, 64)
(615, 64, 2560) (615, 64)
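The per-chunk loops in the two cells above can each be written as a single vectorized NumPy call. One is the disabled flatten and the other is the active channel mean. The sketch below checks both equivalences on a small random array standing in for X_train.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((3, 4, 5))  # toy stand-in for (chunks, channels, time samples)

# Channel-mean features, as in the active cell: one value per channel per chunk.
mean_loop = np.array([np.mean(X[i, :, :], axis=1) for i in range(X.shape[0])])
mean_vec = X.mean(axis=2)  # average over the time axis in one call

# Flattened features, as in the disabled cell: channels x time in row-major order.
flat_loop = np.array([X[i, :, :].flatten() for i in range(X.shape[0])])
flat_vec = X.reshape(X.shape[0], -1)

print(mean_vec.shape, np.allclose(mean_loop, mean_vec))    # (3, 4) True
print(flat_vec.shape, np.array_equal(flat_loop, flat_vec))  # (3, 20) True
```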


In [31]:
# Use these instead with the flattened features (64 * 2560 = 163840 columns)
# X_train_df = pd.DataFrame(X_train_list, columns=[str(x) for x in range(0, 163840)])
# X_val_df = pd.DataFrame(X_val_list, columns=[str(x) for x in range(0, 163840)])
# X_test_df = pd.DataFrame(X_test_list, columns=[str(x) for x in range(0, 163840)])

# One column per channel mean (64 features)
X_train_df = pd.DataFrame(X_train_list, columns=[str(x) for x in range(0, 64)])
X_val_df = pd.DataFrame(X_val_list, columns=[str(x) for x in range(0, 64)])
X_test_df = pd.DataFrame(X_test_list, columns=[str(x) for x in range(0, 64)])

In [33]:
# Grid-search each classifier on the training set, score it on val and test,
# and keep the model with the highest validation accuracy.
# Note: 'Train_Accuracy' stores GridSearchCV.best_score_, i.e. the mean
# cross-validated accuracy on the training set for the best parameters.
model_results = pd.DataFrame()
model_results['Train_Accuracy'] = None
model_results['Val_Accuracy'] = None
model_results['Test_Accuracy'] = None
model_results['best_params'] = None

best_clf_ours = None
best_clf_val = 0

for clf_name, clf in clf_dict.items():
    classifier = GridSearchCV(clf['model'], clf['params'], n_jobs=10)
    classifier.fit(X_train_df, y_train_medid)
    best_model = classifier.best_estimator_

    y_predicted = best_model.predict(X_val_df)
    val_acc = accuracy_score(y_val_medid, y_predicted)

    print(clf_name, classifier.best_score_, classifier.best_params_, val_acc)

    # Model selection uses validation accuracy only
    if val_acc > best_clf_val:
        best_clf_val = val_acc
        best_clf_ours = best_model

    y_predicted = best_model.predict(X_test_df)
    test_acc = accuracy_score(y_test_medid, y_predicted)

    model_results.loc[clf_name, ['Train_Accuracy', 'Val_Accuracy', 'Test_Accuracy', 'best_params']] = [
        classifier.best_score_, val_acc, test_acc, classifier.best_params_]

print("================================================================================")
print(best_clf_ours)
best_y_hat = best_clf_ours.predict(X_test_df)
clsr = classification_report(y_test_medid, best_y_hat)
print(clsr)
test_acc = accuracy_score(y_test_medid, best_y_hat)
print("Test acc:", test_acc)
print("Weighted F1 score: ", f1_score(y_test_medid, best_y_hat, average='weighted'))

DecisionTree 0.9800722840337326 {'max_depth': 11} 0.3090332805071315
RandomForest 1.0 {'max_depth': 9, 'n_estimators': 65} 0.26465927099841524
LogisticR_L1 0.9995145631067961 {'penalty': 'l1', 'solver': 'liblinear'} 0.4183835182250396
LogisticR_L2 0.9995145631067961 {'penalty': 'l2', 'solver': 'liblinear'} 0.41362916006339145
LogisticR 0.9995145631067961 {'penalty': 'none', 'solver': 'newton-cg'} 0.3486529318541997
RidgeClf 1.0 {} 0.30269413629160064
SVC_linear 0.9995145631067961 {'C': 0.5, 'kernel': 'linear'} 0.4215530903328051
SVC_poly 0.998542508208726 {'C': 0.5, 'degree': 3, 'gamma': 'auto', 'kernel': 'poly'} 0.3090332805071315
SVC_others 1.0 {'C': 1.0, 'gamma': 'auto', 'kernel': 'rbf'} 0.15213946117274169
GussianNB 0.6698875581697493 {} 0.24722662440570523
KNN 1.0 {'n_neighbors': 3} 0.05705229793977813
GaussianProcessClf 1.0 {} 0.19968304278922344
Bagging_SVC 1.0 {'base_estimator': SVC(kernel='linear'), 'n_estimators': 25} 0.41996830427892234
BaggingDT 0.9985436893203883 {'base_es
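The cell above also reports `f1_score(..., average='weighted')`. This averages the per-class F1 scores, weighting each by its class's support (the number of true samples of that class). A pure-Python sketch with a hypothetical helper `weighted_f1` and a six-sample toy example makes the arithmetic explicit. It is meant to mirror scikit-learn's definition, not replace it.

```python
def weighted_f1(y_true, y_pred):
    """Support-weighted mean of per-class F1, as in f1_score(average='weighted')."""
    total, n = 0.0, len(y_true)
    for c in sorted(set(y_true)):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        support = tp + fn  # number of true samples of class c
        total += support * f1
    return total / n


y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 2, 2]
# Per-class F1: class 0 -> 0.8, class 1 -> 0.5, class 2 -> 2/3; supports 3, 2, 1.
# Weighted F1 = (3 * 0.8 + 2 * 0.5 + 1 * 2/3) / 6 = 61/90
print(round(weighted_f1(y_true, y_pred), 4))  # 0.6778
```

Classes that appear only in the predictions have zero support, so they contribute nothing to the weighted average. That is why iterating over the true labels is enough here.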

In [34]:
model_results

| Classifier | Train_Accuracy | Val_Accuracy | Test_Accuracy | best_params |
|---|---|---|---|---|
| DecisionTree | 0.980072 | 0.309033 | 0.341463 | {'max_depth': 11} |
| RandomForest | 1.0 | 0.264659 | 0.35122 | {'max_depth': 9, 'n_estimators': 65} |
| LogisticR_L1 | 0.999515 | 0.418384 | 0.239024 | {'penalty': 'l1', 'solver': 'liblinear'} |
| LogisticR_L2 | 0.999515 | 0.413629 | 0.346341 | {'penalty': 'l2', 'solver': 'liblinear'} |
| LogisticR | 0.999515 | 0.348653 | 0.343089 | {'penalty': 'none', 'solver': 'newton-cg'} |
| RidgeClf | 1.0 | 0.302694 | 0.182114 | {} |
| SVC_linear | 0.999515 | 0.421553 | 0.343089 | {'C': 0.5, 'kernel': 'linear'} |
| SVC_poly | 0.998543 | 0.309033 | 0.186992 | {'C': 0.5, 'degree': 3, 'gamma': 'auto', 'kern... |
| SVC_others | 1.0 | 0.152139 | 0.162602 | {'C': 1.0, 'gamma': 'auto', 'kernel': 'rbf'} |
| GussianNB | 0.669888 | 0.247227 | 0.198374 | {} |
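GridSearchCV above uses its default cross-validation, which for classifiers is 5-fold StratifiedKFold without shuffling. Nothing in that scheme keeps a subject's chunks on one side of a fold. Under the user split, best_score_ (stored as Train_Accuracy) may therefore largely reflect within-subject accuracy. The near-perfect Train_Accuracy values next to validation accuracies of roughly 0.15 to 0.42 in the table above are consistent with that. One alternative passes the subject IDs as `groups` to GroupKFold, so each fold holds out whole subjects. The sketch below assumes y_train_subjid holds one subject ID per training chunk. It runs on synthetic data with a single classifier. With the real data, X, y and groups would be X_train_df, y_train_medid and y_train_subjid.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, GroupKFold

rng = np.random.default_rng(0)
n_subjects, chunks_per_subject, n_features = 6, 20, 8

# Synthetic stand-ins for X_train_df, y_train_medid and y_train_subjid.
groups = np.repeat(np.arange(n_subjects), chunks_per_subject)
y = rng.integers(0, 3, size=groups.size)
X = rng.standard_normal((groups.size, n_features)) + groups[:, None]  # subject-specific offset

# Each fold holds out whole subjects, mirroring the user split.
cv = GroupKFold(n_splits=3)
search = GridSearchCV(LogisticRegression(max_iter=500), {"C": [0.1, 1.0, 10.0]}, cv=cv)
search.fit(X, y, groups=groups)

# Confirm that no subject appears on both sides of any fold.
for train_idx, test_idx in cv.split(X, y, groups):
    assert not set(groups[train_idx].tolist()) & set(groups[test_idx].tolist())

print(search.best_params_, round(search.best_score_, 3))
```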


In [35]:
print("Done!!!")

Done!!!


## 2. Time Split

In [37]:
user_split   = False
med_tech_clf = True

X_train, y_train_medid, y_train_subjid, X_val, y_val_medid, y_val_subjid, X_test, y_test_medid, y_test_subjid = load_dataset(user_split=user_split)

In [38]:
# Features: mean of each channel over time, (n_chunks, 64, 2560) -> (n_chunks, 64)
X_train_list = []
X_val_list = []
X_test_list = []

for i in range(X_train.shape[0]):
    X_train_list.append(np.mean(X_train[i,:,:], axis=1))
X_train_list = np.array(X_train_list)
print(X_train.shape, X_train_list.shape)

for i in range(X_val.shape[0]):
    X_val_list.append(np.mean(X_val[i,:,:], axis=1))
X_val_list = np.array(X_val_list)
print(X_val.shape, X_val_list.shape)

for i in range(X_test.shape[0]):
    X_test_list.append(np.mean(X_test[i,:,:], axis=1))
X_test_list = np.array(X_test_list)
print(X_test.shape, X_test_list.shape)

(1954, 64, 2560) (1954, 64)
(661, 64, 2560) (661, 64)
(688, 64, 2560) (688, 64)


In [39]:
# One column per channel mean (64 features)
X_train_df = pd.DataFrame(X_train_list, columns=[str(x) for x in range(0, 64)])
X_val_df = pd.DataFrame(X_val_list, columns=[str(x) for x in range(0, 64)])
X_test_df = pd.DataFrame(X_test_list, columns=[str(x) for x in range(0, 64)])

In [40]:
# Grid-search each classifier on the training set, score it on val and test,
# and keep the model with the highest validation accuracy.
# Note: 'Train_Accuracy' stores GridSearchCV.best_score_, i.e. the mean
# cross-validated accuracy on the training set for the best parameters.
model_results_timesplit = pd.DataFrame()
model_results_timesplit['Train_Accuracy'] = None
model_results_timesplit['Val_Accuracy'] = None
model_results_timesplit['Test_Accuracy'] = None
model_results_timesplit['best_params'] = None

best_clf_ours = None
best_clf_val = 0

for clf_name, clf in clf_dict.items():
    classifier = GridSearchCV(clf['model'], clf['params'], n_jobs=10)
    classifier.fit(X_train_df, y_train_medid)
    best_model = classifier.best_estimator_

    y_predicted = best_model.predict(X_val_df)
    val_acc = accuracy_score(y_val_medid, y_predicted)

    print(clf_name, classifier.best_score_, classifier.best_params_, val_acc)

    # Model selection uses validation accuracy only
    if val_acc > best_clf_val:
        best_clf_val = val_acc
        best_clf_ours = best_model

    y_predicted = best_model.predict(X_test_df)
    test_acc = accuracy_score(y_test_medid, y_predicted)

    model_results_timesplit.loc[clf_name, ['Train_Accuracy', 'Val_Accuracy', 'Test_Accuracy', 'best_params']] = [
        classifier.best_score_, val_acc, test_acc, classifier.best_params_]

print("================================================================================")
print(best_clf_ours)
best_y_hat = best_clf_ours.predict(X_test_df)
clsr = classification_report(y_test_medid, best_y_hat)
print(clsr)
test_acc = accuracy_score(y_test_medid, best_y_hat)
print("Test acc:", test_acc)
print("Weighted F1 score: ", f1_score(y_test_medid, best_y_hat, average='weighted'))

DecisionTree 0.9703206767656896 {'max_depth': 14} 0.972768532526475
RandomForest 1.0 {'max_depth': 8, 'n_estimators': 65} 0.9969742813918305
LogisticR_L1 0.9984641615843662 {'penalty': 'l1', 'solver': 'saga'} 0.9954614220877458
LogisticR_L2 0.9989756705357727 {'penalty': 'l2', 'solver': 'newton-cg'} 0.9954614220877458
LogisticR 0.9989756705357727 {'penalty': 'none', 'solver': 'sag'} 0.9954614220877458
RidgeClf 0.9943734015345269 {} 0.9954614220877458
SVC_linear 0.9989756705357727 {'C': 0.5, 'kernel': 'linear'} 0.9969742813918305
SVC_poly 0.9969309462915602 {'C': 1.0, 'degree': 3, 'gamma': 'auto', 'kernel': 'poly'} 0.9954614220877458
SVC_others 0.9994884910485933 {'C': 0.5, 'gamma': 'scale', 'kernel': 'rbf'} 0.9984871406959153
GussianNB 0.5066614204210111 {} 0.5279878971255674
KNN 1.0 {'n_neighbors': 3} 1.0
GaussianProcessClf 0.9989769820971868 {} 1.0
Bagging_SVC 0.9989756705357727 {'base_estimator': SVC(kernel='linear'), 'n_estimators': 5} 0.9969742813918305
BaggingDT 0.993859269460292

In [41]:
model_results_timesplit

| Classifier | Train_Accuracy | Val_Accuracy | Test_Accuracy | best_params |
|---|---|---|---|---|
| DecisionTree | 0.970321 | 0.972769 | 0.984012 | {'max_depth': 14} |
| RandomForest | 1.0 | 0.996974 | 0.998547 | {'max_depth': 8, 'n_estimators': 65} |
| LogisticR_L1 | 0.998464 | 0.995461 | 0.99564 | {'penalty': 'l1', 'solver': 'saga'} |
| LogisticR_L2 | 0.998976 | 0.995461 | 0.997093 | {'penalty': 'l2', 'solver': 'newton-cg'} |
| LogisticR | 0.998976 | 0.995461 | 0.997093 | {'penalty': 'none', 'solver': 'sag'} |
| RidgeClf | 0.994373 | 0.995461 | 0.994186 | {} |
| SVC_linear | 0.998976 | 0.996974 | 0.99564 | {'C': 0.5, 'kernel': 'linear'} |
| SVC_poly | 0.996931 | 0.995461 | 0.991279 | {'C': 1.0, 'degree': 3, 'gamma': 'auto', 'kern... |
| SVC_others | 0.999488 | 0.998487 | 0.997093 | {'C': 0.5, 'gamma': 'scale', 'kernel': 'rbf'} |
| GussianNB | 0.506661 | 0.527988 | 0.530523 | {} |
