In [1]:
import numpy as np
import pandas as pd
import pickle


# EEGNet-specific imports
from EEGModels import EEGNet
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping

# PyRiemann imports
from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.utils.viz import plot_confusion_matrix

# # sklearn imports
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
# from sklearn.metrics import classification_report, accuracy_score, f1_score, balanced_accuracy_score

# # tools for plotting confusion matrices
# import matplotlib
# from matplotlib import pyplot as plt
# import seaborn as sn

In [2]:
def load_dataset_ctrvsone_and_transfer(user_split = True, runidx=1, data_type='meditation', one_class='HT'):
    X_train = None
    y_train_medid = None
    y_train_subjid = None
    X_val = None
    y_val_medid = None
    y_val_subjid = None
    X_test = None
    y_test_medid = None
    y_test_subjid = None

    dtype_names = ["HT", "SNY", "VIP", "CTR"]

    # user_split: Determines the creation of train/val/test set. In user_split, test/val users are never seen during the train. time_split randomly splits each user's chunks into train/val/test.
    if user_split:
        data_file_path = '../../iconip_data/{}/user_based_splits_with_timestamp_RUN{}.pkl'.format(data_type, runidx)
    else:
        data_file_path = '../../iconip_data/{}/time_based_splits_with_timestamp_RUN{}.pkl'.format(data_type, runidx)

    with open(data_file_path, 'rb') as f:
        all_data_splits = pickle.load(f)

        data_index_label_oneclass = dtype_names.index(one_class)
        data_index_label_ctr = dtype_names.index('CTR')
        transfer_index_labels = [x for x in list(set([0,1,2,3]).difference(set([dtype_names.index(one_class), 3]))) ]
        
        X_train = all_data_splits['train']['x']
        y_train_medid = all_data_splits['train']['y_med']
        y_train_subjid = all_data_splits['train']['y_subj']
        y_train_ts = all_data_splits['train']['y_ts']
        clf_config_subset_indices = np.where((y_train_medid == data_index_label_oneclass) | (y_train_medid == data_index_label_ctr))[0]
        X_train = X_train[(clf_config_subset_indices)]
        y_train_medid = y_train_medid[(clf_config_subset_indices)]
        y_train_subjid = y_train_subjid[clf_config_subset_indices]
        y_train_ts = y_train_ts[(clf_config_subset_indices)]

        X_val = all_data_splits['val']['x']
        y_val_medid = all_data_splits['val']['y_med']
        y_val_subjid = all_data_splits['val']['y_subj']
        y_val_ts = all_data_splits['val']['y_ts']
        clf_config_subset_indices = np.where((y_val_medid == data_index_label_oneclass) | (y_val_medid == data_index_label_ctr))[0]
        X_val = X_val[(clf_config_subset_indices)]
        y_val_medid = y_val_medid[(clf_config_subset_indices)]
        y_val_subjid = y_val_subjid[(clf_config_subset_indices)]
        y_val_ts = y_val_ts[(clf_config_subset_indices)]
        
        # Pretrain test and transfer test set
        X_test = all_data_splits['test']['x']
        y_test_medid = all_data_splits['test']['y_med']
        y_test_subjid = all_data_splits['test']['y_subj']
        y_test_ts = all_data_splits['test']['y_ts']
        
        # transfer test contains CTR from test and other two classes from the test.
        ctr_othertwo_indices = np.where((y_test_medid == data_index_label_ctr) | (y_test_medid == transfer_index_labels[0]) | (y_test_medid == transfer_index_labels[1]))[0]
        X_test_transfer = X_test[(ctr_othertwo_indices)]
        y_test_transfer_medid = y_test_medid[(ctr_othertwo_indices)]
        y_test_transfer_subjid = y_test_subjid[(ctr_othertwo_indices)]
        
        # pretrain test filter CTR and one_class from the test.
        clf_config_subset_indices = np.where((y_test_medid == data_index_label_oneclass) | (y_test_medid == data_index_label_ctr))[0]
        X_test = X_test[(clf_config_subset_indices)]
        y_test_medid = y_test_medid[(clf_config_subset_indices)]
        y_test_subjid = y_test_subjid[(clf_config_subset_indices)]
        y_test_ts = y_test_ts[(clf_config_subset_indices)]

    return X_train, y_train_medid, y_train_subjid, y_train_ts, X_val, y_val_medid, y_val_subjid, y_val_ts, X_test, y_test_medid, y_test_subjid, y_test_ts, X_test_transfer, y_test_transfer_medid, y_test_transfer_subjid

In [3]:
def train_model(user_split, med_tech_clf, data_type, one_class):
    
    kernels, chans, samples = 1, 64, 2560
    
    split_strat = 'user' if user_split else 'time'
    model_results = pd.DataFrame()

    for runidx in range(1, 6):
        model_results['{}_Train_Acc'.format(runidx)] = None
        model_results['{}_Val_Acc'.format(runidx)] = None
        model_results['{}_Test_Acc'.format(runidx)] = None
        model_results['{}_Test_Transfer_Acc'.format(runidx)] = None
        model_results['{}_best_params'.format(runidx)] = None


    for runidx in range(1, 6):
        print("RUN ID: ", runidx)
        med_tech_clf = True
        
        X_train, y_train_medid, y_train_subjid, y_train_ts, X_val, y_val_medid, y_val_subjid, y_val_ts, X_test, y_test_medid, y_test_subjid, y_test_ts, X_transfer, y_test_transfer_medid, y_test_transfer_subjid = load_dataset_ctrvsone_and_transfer(user_split=user_split, runidx=runidx, data_type=data_type, one_class=one_class)
        ## Convert data to binary classification task, i.e. meditation expert as class 0 vs control group as class 1
        y_train_medid = y_train_medid//3
        y_val_medid = y_val_medid//3
        y_test_medid = y_test_medid//3
        y_test_transfer_medid = y_test_transfer_medid//3

        best_clf_ours = None
        best_clf_val = 0

        #X_test_transfer, y_test_transfer_medid, y_test_transfer_subjid
        if med_tech_clf:
            Y_train      = y_train_medid
            Y_validate   = y_val_medid
            Y_test       = y_test_medid
            Y_transfer   = y_test_transfer_medid
            ckpt_file_suffix = '{}_med_clf'.format(split_strat)
        else:
            Y_train      = y_train_subjid
            Y_validate   = y_val_subjid
            Y_test       = y_test_subjid
            Y_transfer   = y_test_transfer_subjid
            ckpt_file_suffix = '{}_subj_clf'.format(split_strat)
        nb_classes = np.unique(Y_train).shape[0]
        print(np.unique(Y_train), ": unique class labels.")
    
        ############################# EEGNet portion ##################################
        # convert labels to one-hot encodings.
        Y_train      = np_utils.to_categorical(Y_train)
        Y_validate   = np_utils.to_categorical(Y_validate)
        Y_test       = np_utils.to_categorical(Y_test)
        Y_transfer   = np_utils.to_categorical(Y_transfer)

        # convert data to NHWC (trials, channels, samples, kernels) format. Data 
        # contains 60 channels and 151 time-points. Set the number of kernels to 1.
        X_train      = X_train.reshape(X_train.shape[0], chans, samples, kernels)
        X_val        = X_val.reshape(X_val.shape[0], chans, samples, kernels)
        X_test       = X_test.reshape(X_test.shape[0], chans, samples, kernels)
        X_transfer   = X_transfer.reshape(X_transfer.shape[0], chans, samples, kernels)

        print('X_train shape:', X_train.shape)
        print(X_train.shape[0], 'train samples')
        print(X_test.shape[0], 'test samples')
        print(X_transfer.shape[0], 'transfer samples')
        print(nb_classes, " number of classes")
        
        ############################# PyRiemann Portion ##############################
        # code is taken from PyRiemann's ERP sample script, which is decoding in 
        # the tangent space with a logistic regression
        n_components = 2  # pick some components

        # set up sklearn pipeline
        clf = make_pipeline(XdawnCovariances(n_components),
                            TangentSpace(metric='riemann'),
                            LogisticRegression())
        preds_rg     = np.zeros(len(Y_test))

        # reshape back to (trials, channels, samples)
        X_train      = X_train.reshape(X_train.shape[0], chans, samples)
        X_val        = X_val.reshape(X_val.shape[0], chans, samples)
        X_test       = X_test.reshape(X_test.shape[0], chans, samples)
        X_transfer   = X_transfer.reshape(X_transfer.shape[0], chans, samples)

        # train a classifier with xDAWN spatial filtering + Riemannian Geometry (RG)
        # labels need to be back in single-column format
        clf.fit(X_train, Y_train.argmax(axis = -1))
        preds_rg_test      = clf.predict(X_test)
        preds_rg_transfer  = clf.predict(X_transfer)

        # save preds on train and val set for error analysis
        preds_train_rg = clf.predict(X_train)
        preds_train_rg = preds_train_rg.argmax(axis = -1)
        preds_val_rg = clf.predict(X_val)
        preds_val_rg = preds_val_rg.argmax(axis = -1)

        # Printing the results
        acc2_test         = np.mean(preds_rg_test == Y_test.argmax(axis = -1))
        acc2_transfer     = np.mean(preds_rg_transfer == Y_transfer.argmax(axis = -1))
#         print(preds_rg_test.shape, Y_test.argmax(axis = -1).shape)
#         print(preds_rg_transfer.shape, Y_transfer.argmax(axis = -1).shape)
        print("PyRiemann Classification accuracy: %f %f" % (acc2_test, acc2_transfer))
    
        model_results.loc["PyRiemann", ['{}_Train_Acc'.format(runidx), '{}_Val_Acc'.format(runidx), '{}_Test_Acc'.format(runidx), '{}_Test_Transfer_Acc'.format(runidx)]] = [preds_train_rg, preds_val_rg, acc2_test, acc2_transfer]
    if user_split:
        model_results.to_csv('./RESULTS_T5{}_MED_TransferPRED_{}_CTRvs{}_PyRiemann.csv'.format('USER', data_type, one_class))
    else:
        model_results.to_csv('./RESULTS_T5_MED_TransferPRED_{}_CTRvs{}_PyRiemann.csv'.format(data_type, one_class))

    return model_results

In [4]:
# time split
for one_class in ['HT', 'SNY', 'VIP']:
    model_results = train_model(user_split=False, med_tech_clf=True, data_type='meditation', one_class=one_class)

RUN ID:  1
[0 1] : unique class labels.
X_train shape: (990, 64, 2560, 1)
990 train samples
344 test samples
523 transfer samples
2  number of classes
PyRiemann Classification accuracy: 0.741279 0.527725
RUN ID:  2
[0 1] : unique class labels.
X_train shape: (990, 64, 2560, 1)
990 train samples
344 test samples
523 transfer samples
2  number of classes
PyRiemann Classification accuracy: 0.715116 0.558317
RUN ID:  3
[0 1] : unique class labels.
X_train shape: (990, 64, 2560, 1)
990 train samples
344 test samples
523 transfer samples
2  number of classes
PyRiemann Classification accuracy: 0.726744 0.554493
RUN ID:  4
[0 1] : unique class labels.
X_train shape: (990, 64, 2560, 1)
990 train samples
344 test samples
523 transfer samples
2  number of classes
PyRiemann Classification accuracy: 0.738372 0.590822
RUN ID:  5
[0 1] : unique class labels.
X_train shape: (990, 64, 2560, 1)
990 train samples
344 test samples
523 transfer samples
2  number of classes
PyRiemann Classification accuracy

In [22]:
# User split
for one_class in ['HT', 'SNY', 'VIP']:
    model_results = train_model(user_split=True, med_tech_clf=True, data_type='meditation', one_class=one_class)

RUN ID:  1
[0 1] : unique class labels.
X_train shape: (1038, 64, 2560, 1)
1038 train samples
308 test samples
447 transfer samples
2  number of classes
(308,) (308,)
(447,) (447,)
PyRiemann Classification accuracy: 0.126623 0.384787
RUN ID:  2
[0 1] : unique class labels.
X_train shape: (1031, 64, 2560, 1)
1031 train samples
325 test samples
460 transfer samples
2  number of classes
(325,) (325,)
(460,) (460,)
PyRiemann Classification accuracy: 0.643077 0.671739
RUN ID:  3
[0 1] : unique class labels.
X_train shape: (1058, 64, 2560, 1)
1058 train samples
311 test samples
485 transfer samples
2  number of classes
(311,) (311,)
(485,) (485,)
PyRiemann Classification accuracy: 0.305466 0.455670
RUN ID:  4
[0 1] : unique class labels.
X_train shape: (1057, 64, 2560, 1)
1057 train samples
301 test samples
475 transfer samples
2  number of classes
(301,) (301,)
(475,) (475,)
PyRiemann Classification accuracy: 0.388704 0.557895
RUN ID:  5
[0 1] : unique class labels.
X_train shape: (1023, 64