> In this notebook a number of models are trained on the UCI data. Required libraries and the UCI dataset are first imported.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import LinearSVC, SVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve, auc, confusion_matrix, f1_score, classification_report
from sklearn.metrics import make_scorer, precision_score, recall_score
from scipy.special import expit, logit

In [2]:
df = pd.read_csv('EEG_UCI_dataset_powers.csv')
df

Unnamed: 0.1,Unnamed: 0,Fp1a delta,Fp1a theta,Fp1a alpha,Fp1a beta,Fp1a gamma,Fp2a delta,Fp2a theta,Fp2a alpha,Fp2a beta,...,P3/P4 delta,P3/P4 theta,P3/P4 alpha,P3/P4 beta,P3/P4 gamma,O1/O2 delta,O1/O2 theta,O1/O2 alpha,O1/O2 beta,O1/O2 gamma
0,co2a0000364,34.951542,6.329759,1.577445,5.040983,2.695547,30.629374,6.105156,2.187954,13.458793,...,0.459534,0.554567,0.547984,0.579646,0.637105,0.518184,0.540152,0.546512,0.580773,0.634882
1,co2a0000365,9.299793,2.568554,8.272992,10.244713,4.960425,7.650179,2.816241,7.805378,6.147001,...,0.481977,0.456597,0.477621,0.494127,0.489962,0.465665,0.490152,0.462667,0.486200,0.528533
2,co2a0000368,3.515035,1.541957,5.433650,2.411241,0.941431,3.768631,1.358769,5.551373,2.415200,...,0.458276,0.450038,0.563601,0.481220,0.480997,0.438477,0.458350,0.504987,0.474034,0.360506
3,co2a0000369,4.795953,4.360560,18.934703,3.818618,0.956625,4.472783,4.256171,18.158179,3.582341,...,0.563004,0.503810,0.528841,0.496117,0.521641,0.530270,0.528711,0.517796,0.546580,0.520409
4,co2a0000370,3.736512,1.392974,5.378064,4.374204,1.306564,4.511004,1.353615,5.034374,5.871832,...,0.414734,0.428938,0.371896,0.451796,0.560019,0.467281,0.505975,0.513976,0.594723,0.724742
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
117,co3a0000458,5.034276,1.297364,2.776220,7.503791,1.389794,6.732382,1.791624,3.072971,8.030404,...,0.488738,0.554300,0.502656,0.517526,0.541235,0.461837,0.480084,0.494039,0.503561,0.577265
118,co3a0000459,4.925952,1.284041,3.818269,4.602541,1.256179,4.730081,1.297648,3.812826,8.845018,...,0.521518,0.476234,0.598763,0.568120,0.557745,0.508269,0.549875,0.526056,0.553384,0.539971
119,co3a0000460,5.413555,2.577622,2.853991,3.192140,0.737959,5.674537,2.723768,2.885980,3.208463,...,0.522837,0.545288,0.595220,0.539650,0.551718,0.521833,0.565657,0.516360,0.473023,0.439638
120,co3a0000461,7.399629,1.507018,1.856139,2.528717,0.567038,5.505387,1.312715,2.250694,2.524480,...,0.473300,0.605074,0.520692,0.566241,0.594854,0.525616,0.526378,0.529584,0.537405,0.588164


In [3]:
df = df.rename(columns={'Unnamed: 0': 'subject'})
df['status'] = (df['subject'].str.slice(start=3, stop=4) == "a").astype(int)
df.tail()

Unnamed: 0,subject,Fp1a delta,Fp1a theta,Fp1a alpha,Fp1a beta,Fp1a gamma,Fp2a delta,Fp2a theta,Fp2a alpha,Fp2a beta,...,P3/P4 theta,P3/P4 alpha,P3/P4 beta,P3/P4 gamma,O1/O2 delta,O1/O2 theta,O1/O2 alpha,O1/O2 beta,O1/O2 gamma,status
117,co3a0000458,5.034276,1.297364,2.77622,7.503791,1.389794,6.732382,1.791624,3.072971,8.030404,...,0.5543,0.502656,0.517526,0.541235,0.461837,0.480084,0.494039,0.503561,0.577265,1
118,co3a0000459,4.925952,1.284041,3.818269,4.602541,1.256179,4.730081,1.297648,3.812826,8.845018,...,0.476234,0.598763,0.56812,0.557745,0.508269,0.549875,0.526056,0.553384,0.539971,1
119,co3a0000460,5.413555,2.577622,2.853991,3.19214,0.737959,5.674537,2.723768,2.88598,3.208463,...,0.545288,0.59522,0.53965,0.551718,0.521833,0.565657,0.51636,0.473023,0.439638,1
120,co3a0000461,7.399629,1.507018,1.856139,2.528717,0.567038,5.505387,1.312715,2.250694,2.52448,...,0.605074,0.520692,0.566241,0.594854,0.525616,0.526378,0.529584,0.537405,0.588164,1
121,co3c0000402,5.284542,1.946691,1.135617,2.015422,0.492179,5.092513,1.986647,1.3372,2.166617,...,0.497312,0.492201,0.49275,0.493309,0.498716,0.516931,0.544294,0.517307,0.497229,0


In [4]:
print("Alcoholic subjects:", len(df.status.loc[df.status ==1]))
print("Control subjects:", len(df.status.loc[df.status == 0]))
print("Proportion of alcoholic subjects:", round(len(df.status.loc[df.status == 1]) / (len(df.status.loc[df.status == 1]) + len(df.status.loc[df.status == 0])), 3))

Alcoholic subjects: 77
Control subjects: 45
Proportion of alcoholic subjects: 0.631


> The dataset is then split into training and test sets and pipelines are prepared.

In [5]:
y = df.status
X = df.drop(['subject', 'status'], axis=1)

In [6]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=1234, stratify=df.status)

In [7]:
print(len(X_train), len(X_test), len(y_train), len(y_test))

91 31 91 31


In [8]:
pipelines = {
    'l1': make_pipeline(StandardScaler(), LogisticRegression(penalty='l1', solver='liblinear', random_state=123)),
    'l2': make_pipeline(StandardScaler(), LogisticRegression(penalty='l2', random_state=123)),
    'rf': make_pipeline(StandardScaler(), RandomForestClassifier(random_state=123)),
    'gb': make_pipeline(StandardScaler(), GradientBoostingClassifier(random_state=123)),
    'svm_linear': make_pipeline(StandardScaler(), LinearSVC(random_state=123)),
    'svm_rbf': make_pipeline(StandardScaler(), SVC(kernel='rbf', random_state=123))
}

In [9]:
l1_hyperparameters = {'logisticregression__C': np.logspace(-3, 3, 13)}
l2_hyperparameters = {'logisticregression__C': np.logspace(-3, 3, 13)}
rf_hyperparameters = {
    'randomforestclassifier__n_estimators': [25, 50, 75, 100, 150, 200],
    'randomforestclassifier__max_features': [None, 'sqrt',0.2, 0.33]
}
gb_hyperparameters = {
    'gradientboostingclassifier__n_estimators': [25, 50, 75, 100, 150, 200],
    'gradientboostingclassifier__learning_rate': [0.05, 0.1, 0.2],
    'gradientboostingclassifier__max_depth': [1, 3, 5]
}
svm_linear_hyperparameters = {'linearsvc__C': [0.1, 1, 10, 100]}
svm_rbf_hyperparameters = {
    'svc__C': [0.1, 1, 10, 100],
    'svc__gamma': [0.0001, 0.001, 0.01, 0.1, 1, 10]
}

In [10]:
hyperparameters = {
    'l1': l1_hyperparameters,
    'l2': l2_hyperparameters,
    'rf': rf_hyperparameters,
    'gb': gb_hyperparameters,
    'svm_linear': svm_linear_hyperparameters,
    'svm_rbf': svm_rbf_hyperparameters
}

> The models are then trained and fitted using 3-fold cross validation to find optimal hyperparameters. Accuracy scores are shown, obtained from the best cross-validated models using the training data.

In [11]:
fitted_models = {}
for i in ['l1', 'l2', 'rf', 'gb', 'svm_linear', 'svm_rbf']:
    model = GridSearchCV(pipelines[i], hyperparameters[i], cv=3, scoring='f1', n_jobs=-1)
    model.fit(X_train, y_train)
    fitted_models[i] = model
    print(i, 'has been fitted')

l1 has been fitted
l2 has been fitted
rf has been fitted
gb has been fitted
svm_linear has been fitted
svm_rbf has been fitted


In [12]:
for name, model in fitted_models.items():
    print(name, model.best_score_)

l1 0.7722945617682461
l2 0.8184637068357999
rf 0.8032520325203253
gb 0.7457264957264957
svm_linear 0.7676623992413466
svm_rbf 0.8092352092352092


> The models are then used to make predictions on the test dataset, with metrics displayed in detail below. The confusion matrix in each case is formatted as follows:  
> 
> &nbsp; &nbsp; \[[ true negatives | false positives ]  
> &nbsp; &nbsp; &nbsp;[ false negatives | true positives ]]  
> 
> **L2-regularised logistic regression** is the winning model for both accuracy and (class 1) F1 score.

In [13]:
print("Scores for each model:\n")
for name, model in fitted_models.items():
    if name in ['l1', 'l2', 'rf', 'gb']:
        pred = model.predict_proba(X_test)
        pred = pred[:, 1]
        y_pred = [int(p > 0.5) for p in pred]
    else:
        pred = model.decision_function(X_test)
        y_pred = [int(p > 0) for p in pred]
    fpr, tpr, thresholds = roc_curve(y_test, pred)
    print(name, "- AUROC:", auc(fpr, tpr), " F1 score:", f1_score(y_test, y_pred))
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print()

Scores for each model:

l1 - AUROC: 0.7727272727272727  F1 score: 0.7058823529411764
[[ 9  2]
 [ 8 12]]
              precision    recall  f1-score   support

           0       0.53      0.82      0.64        11
           1       0.86      0.60      0.71        20

    accuracy                           0.68        31
   macro avg       0.69      0.71      0.67        31
weighted avg       0.74      0.68      0.68        31


l2 - AUROC: 0.7727272727272728  F1 score: 0.8292682926829269
[[ 7  4]
 [ 3 17]]
              precision    recall  f1-score   support

           0       0.70      0.64      0.67        11
           1       0.81      0.85      0.83        20

    accuracy                           0.77        31
   macro avg       0.75      0.74      0.75        31
weighted avg       0.77      0.77      0.77        31


rf - AUROC: 0.7840909090909091  F1 score: 0.7567567567567567
[[ 8  3]
 [ 6 14]]
              precision    recall  f1-score   support

           0       0.57  

> The next cells use **nested cross validation**. This was tried because of the relatively small size of the dataset. In this process, there are 4 outer loops of validation. Within each outer loop, an inner loop of 3-fold cross validation is performed as before, and the best model for each algorithm is tested on a hold-out set. Because nested cross validation does not produce a single best fitted model, the aim is to find the best *process* from amongst the various algorithms tested, based on average metrics across the 4 outer loops.
> 
> Additionally here, **threshold optimisation** is applied. In an attempt to improve upon the default threshold of 0.5 for the prediction probability (or 0 for decision function, which becomes 0.5 when the logistic function is applied), thresholds from 0.01 to 0.99 are tested on the models to find the best average F1 score across the 3 inner loops of cross validation. Note that a different best threshold might be produced for each outer loop of validation, so the average of these 4 results, obtained by testing the best model on the hold-out set in each case, is shown in the summary at the end of this cell's output.

In [14]:
summary = "Summary of best models:\n\n"
for i in ['l1', 'l2', 'rf', 'gb', 'svm_linear', 'svm_rbf']:
    print("****", i.upper(), "****\n")
    cv_outer = StratifiedKFold(n_splits=4, shuffle=True, random_state=1)
    outer_results_def_thr = list()
    outer_results_best_thr = list()
    for train_ix, test_ix in cv_outer.split(X, y):
        X_train, X_test = X.iloc[train_ix, :], X.iloc[test_ix, :]
        y_train, y_test = y.iloc[train_ix], y.iloc[test_ix]
        cv_inner = StratifiedKFold(n_splits=3, shuffle=True, random_state=1)
        clf = GridSearchCV(pipelines[i], hyperparameters[i], cv=cv_inner, scoring='f1', n_jobs=-1)
        clf.fit(X_train, y_train)
        f1_array = np.zeros(99)
        for itrain_ix, itest_ix in cv_inner.split(X_train, y_train):
            iX_train, iX_test = X_train.iloc[itrain_ix, :], X_train.iloc[itest_ix, :]
            iy_train, iy_test = y_train.iloc[itrain_ix], y_train.iloc[itest_ix]
            model = pipelines[i]
            model.set_params(**clf.best_params_)
            model.fit(iX_train, iy_train)
            if i in ['l1', 'l2', 'rf', 'gb']:
                pred = model.predict_proba(iX_test)
                pred = pred[:, 1]
            else:
                pred = model.decision_function(iX_test)
            for thr in range(1, 100):
                if i in ['l1', 'l2', 'rf', 'gb']:
                    y_pred = [int(p > thr / 100) for p in pred]
                else:
                    y_pred = [int(expit(p) > thr / 100) for p in pred]
                f1_array[int(thr - 1)] += f1_score(iy_test, y_pred)
        best_threshold = (np.argmax(f1_array) + 1) / 100
        print(f1_array, "\n")
        if i in ['svm_linear', 'svm_rbf']:
            best_threshold = logit(best_threshold)
        if i in ['l1', 'l2', 'rf', 'gb']:
            pred = clf.predict_proba(X_test)
            pred = pred[:, 1]
            y_pred = [int(p > 0.5) for p in pred]
        else:
            pred = clf.decision_function(X_test)
            y_pred = [int(p > 0) for p in pred]
        print("Default threshold")
        fpr, tpr, thresholds = roc_curve(y_test, pred)
        print(" AUROC:", auc(fpr, tpr), " F1 score:", f1_score(y_test, y_pred))
        print(confusion_matrix(y_test, y_pred))
        print(classification_report(y_test, y_pred))
        print()
        outer_results_def_thr.append({'f1': f1_score(y_test, y_pred), 'AUROC': auc(fpr, tpr),
                                      'sensitivity': recall_score(y_test, y_pred), 'specificity': recall_score(y_test, y_pred, pos_label=0),
                                      'PPV': precision_score(y_test, y_pred), 'NPV': precision_score(y_test, y_pred, pos_label=0)})
        print("Best threshold:", best_threshold)
        if i in ['l1', 'l2', 'rf', 'gb']:
            y_pred = [int(p > best_threshold) for p in pred]
        else:
            y_pred = [int(expit(p) > best_threshold) for p in pred]
        fpr, tpr, thresholds = roc_curve(y_test, pred)
        print(" AUROC:", auc(fpr, tpr), " F1 score:", f1_score(y_test, y_pred))
        print(confusion_matrix(y_test, y_pred))
        print(classification_report(y_test, y_pred))
        print()
        outer_results_best_thr.append({'f1': f1_score(y_test, y_pred), 'AUROC': auc(fpr, tpr),
                                      'sensitivity': recall_score(y_test, y_pred), 'specificity': recall_score(y_test, y_pred, pos_label=0),
                                      'PPV': precision_score(y_test, y_pred), 'NPV': precision_score(y_test, y_pred, pos_label=0)})
    print("Mean scores with default threshold:")
    for score in ['f1', 'AUROC', 'sensitivity', 'specificity', 'PPV', 'NPV']:
        print(" ", score, ":", np.array([dict[score] for dict in outer_results_def_thr]).mean())
    print("\nMean scores with best thresholds:")
    for score in ['f1', 'AUROC', 'sensitivity', 'specificity', 'PPV', 'NPV']:
        print(" ", score, ":", np.array([dict[score] for dict in outer_results_best_thr]).mean())
    print("\n")
    summary = summary + i + "\n  Default threshold\n    AUROC: " + str(np.array([dict['AUROC'] for dict in outer_results_def_thr]).mean()) + "  F1 score: " + str(np.array([dict['f1'] for dict in outer_results_def_thr]).mean())
    summary = summary + "\n    Sensitivity: " + str(np.array([dict['sensitivity'] for dict in outer_results_def_thr]).mean())[:6] + "  Specificity: " + str(np.array([dict['specificity'] for dict in outer_results_def_thr]).mean())[:6]
    summary = summary + "\n  Best thresholds\n    AUROC: " + str(np.array([dict['AUROC'] for dict in outer_results_best_thr]).mean()) + "  F1 score: " + str(np.array([dict['f1'] for dict in outer_results_best_thr]).mean())
    summary = summary + "\n    Sensitivity: " + str(np.array([dict['sensitivity'] for dict in outer_results_best_thr]).mean())[:6] + "  Specificity: " + str(np.array([dict['specificity'] for dict in outer_results_best_thr]).mean())[:6] + "\n"
print(summary) 

**** L1 ****

[2.37574955 2.39404223 2.41563467 2.41563467 2.41563467 2.41563467
 2.41563467 2.41563467 2.41563467 2.41563467 2.41563467 2.38358339
 2.38358339 2.38358339 2.38358339 2.38358339 2.31428332 2.31428332
 2.31428332 2.31428332 2.28298887 2.28298887 2.28298887 2.28298887
 2.28298887 2.30551139 2.30551139 2.30551139 2.32932092 2.32932092
 2.34733894 2.34733894 2.34733894 2.34733894 2.34733894 2.31372549
 2.31372549 2.31372549 2.31372549 2.31372549 2.31372549 2.31372549
 2.31372549 2.31372549 2.31372549 2.31372549 2.31372549 2.31372549
 2.31372549 2.31372549 2.31372549 2.31372549 2.31372549 2.31372549
 2.31372549 2.31372549 2.31372549 2.31372549 2.31372549 2.31372549
 2.31372549 2.31372549 2.31372549 2.31372549 2.31372549 2.31372549
 2.31372549 2.31372549 2.23997963 2.19964349 2.19964349 2.19964349
 2.19964349 2.19964349 2.19964349 2.19964349 2.19964349 2.19964349
 2.19964349 2.21746881 2.21746881 2.21746881 2.21746881 2.17959002
 2.13602941 2.13602941 2.09570683 2.11385199 2.1

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.2955102  2.2955102  2.31102041
 2.31102041 2.32717687 2.32717687 2.35884354 2.39184397 2.42626426
 2.44462175 2.46146572 2.43556378 2.43556378 2.43556378 2.43556378
 2.43556378 2.43556378 2.48972332 2.48322981 2.47608696 2.47608696
 2.47608696 2.4647816  2.50216657 2.50216657 2.52135849 2.52135849
 2.52135849 2.52135849 2.56219796 2.56219796 2.56219796 2.55570446
 2.49661654 2.49661654 2.49661654 2.49661654 2.46712936 2.4099865
 2.4099865  2.33679834 2.33679834 2.28986569 2.28986569 2.03921569
 1.97254902 1.97254902 1.97254902 1.97254902 1.99159664 1.99159664
 1.78864097 1.78864097 1.62122016 1.42122016 1.02122016 1.05698006
 0.51851852 0.51851852 0.48       0.48       0.         0.
 0.         0.         0.         0.         0.         0.
 0.         

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.30982393 2.32598039 2.32598039 2.34166667 2.34166667 2.35762411
 2.35762411 2.34978557 2.34978557 2.34978557 2.34978557 2.34978557
 2.34978557 2.34978557 2.32407635 2.29702321 2.29702321 2.33216533
 2.27982811 2.2510352  2.2510352  2.2510352  2.2510352  2.26699264
 2.28341776 2.28341776 2.25513494 2.25513494 2.25513494 2.25513494
 2.25513494 2.24689097 2.26331609 2.22204625 2.22204625 2.23946785
 2.23946785 2.23946785 2.23946785 2.20898004 2.17938174 2.17938174
 2.19599303 2.13292683 2.13292683 2.13292683 2.11794872 2.13589744
 2.08636977 2.04938542 2.01524602 1.9619883  1.90681858 1.90681858
 1.86928105 1.84705882 1.84705882 1.84705882 1.825      1.80147059
 1.75868984 1.62206745 1.57771261 1.57771261 1.52609971 1.48064516
 1.49677419 1.44731183 1.41149425 1.1968254  1.1968254  1.16296296
 1.16296296 1.17905492 1.17905492 1.12486772 1.13626374 1.1529304
 0.89551839 0.90782609 0.83782609 0.84869565 0.71541502 0.63636364
 0.46320346 0.46796537 0.38571429 0.38571429 0.38571429 0.38571

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.31982393 2.31982393 2.31982393 2.3355102  2.35166667 2.35166667
 2.34146765 2.31595745 2.31595745 2.33191489 2.34824142 2.36489267
 2.36489267 2.36489267 2.36489267 2.36489267 2.36489267 2.36489267
 2.36489267 2.34023281 2.34023281 2.35688406 2.35688406 2.35688406
 2.35688406 2.37372803 2.37372803 2.39111933 2.40869565 2.40869565
 2.38142292 2.39881423 2.37272727 2.37272727 2.37272727 2.37272727
 2.39090909 2.39090909 2.40909091 2.40909091 2.3255814  2.3255814
 2.3145072  2.3145072  2.2253194  2.26219512 2.24903723 2.21489782
 2.20392954 2.15643275 2.05128205 2.01159951 1.98562549 1.98505639
 2.00320155 2.00320155 1.96906215 1.98923022 1.98923022 1.86666667
 1.81609195 1.81609195 1.66560847 1.60376695 1.54108889 1.54108889
 1.50823529 1.38995943 1.29545455 1.25       1.25       1.26612903
 1.10703812 1.05757576 1.05757576 1.05757576 1.05757576 1.07609428
 0.96624066 0.96624066 0.92281792 0.84615385 0.84615385 0.84615385
 0.84615385 0.84615385 0.84615385 0.84615385 0.65       0.65
 0

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.32717687 2.34333333 2.34333333 2.3601773  2.3601773  2.37702128
 2.39459759 2.39459759 2.39459759 2.39459759 2.41217391 2.38608696
 2.38608696 2.38608696 2.38608696 2.33350811 2.33350811 2.33350811
 2.32332437 2.29706174 2.28972868 2.3074474  2.32447217 2.32447217
 2.37790698 2.39713775 2.38357522 2.38357522 2.35453921 2.37445386
 2.39396606 2.39396606 2.41447888 2.41447888 2.41447888 2.41447888
 2.41447888 2.37994435 2.36336996 2.36336996 2.34432234 2.34432234
 2.34432234 2.31328321 2.31328321 2.28540864 2.30614035 2.27665317
 2.24561404 2.24561404 2.20394737 2.20394737 2.1268756  2.1268756
 2.14789662 2.1151797  2.08064516 2.04572453 1.9984127  1.96968673
 1.96968673 1.96968673 1.96968673 1.89452333 1.84033613 1.74113686
 1.66393098 1.62159227 1.62159227 1.59016129 1.49849462 1.49849462
 1.51688543 1.51688543 1.51688543 1.51688543 1.46516129 1.46516129
 1.46516129 1.36444444 1.30461538 1.30461538 1.25633952 1.27604396
 1.2231339  1.2231339  1.2231339  1.16615385 0.89548495 0.89548

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.31102041 2.31102041 2.31102041 2.31102041 2.31102041 2.31102041
 2.31102041 2.31102041 2.31102041 2.31102041 2.31102041 2.32717687
 2.32717687 2.32717687 2.32717687 2.32717687 2.34333333 2.3601773
 2.39459759 2.39459759 2.42846529 2.42846529 2.40220266 2.40220266
 2.42123014 2.42123014 2.42123014 2.42123014 2.4411637  2.46206962
 2.46206962 2.43655942 2.43655942 2.43655942 2.43655942 2.43655942
 2.41065748 2.42661492 2.39978791 2.41111111 2.41111111 2.41724942
 2.44019139 2.45710471 2.47482342 2.46397893 2.46397893 2.46397893
 2.45238095 2.45238095 2.470964   2.44047619 2.39107951 2.27153492
 2.22773573 2.24846743 2.2135468  2.14708483 2.14708483 2.03516334
 2.02346743 2.04563492 2.01759259 2.01759259 1.96346154 1.9191067
 1.90698549 1.90698549 1.85244003 1.74666667 1.72137931 1.72137931
 1.60632184 1.53748126 1.45222169 1.27272727 1.06868132 0.95496894
 0.95496894 0.81102955 0.72877847 0.63830228 0.56320346 0.48095238
 0.38095238 0.29047619 0.2        0.1        0.1        0.1
 0.1

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.31102041 2.31102041 2.32653061 2.32653061 2.30102041
 2.30102041 2.30102041 2.31717687 2.33333333 2.3501773  2.36702128
 2.38459759 2.38459759 2.37669246 2.36767303 2.36767303 2.36767303
 2.36767303 2.35894843 2.39276486 2.39276486 2.36347193 2.3961039
 2.40499149 2.3757232  2.33020638 2.32240051 2.25643275 2.18412698
 1.88045541 1.87432796 1.80766129 1.7625     1.71422414 1.68409344
 1.6973545  1.59386973 1.19397993 0.86956522 0.71221532 0.55134576
 0.28181818 0.28181818 0.28181818 0.1952381  0.1952381  0.1952381
 0.         0.         0.         0.         0.         0.
 0.  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.3355102  2.35183673 2.35183673
 2.36734694 2.36734694 2.36734694 2.3835034  2.3835034  2.37568751
 2.37568751 2.39326383 2.40942029 2.38371107 2.34722222 2.31994949
 2.31994949 2.31994949 2.33679347 2.33679347 2.35396518 2.35396518
 2.37193559 2.3895119  2.3895119  2.3895119  2.40833804 2.42686203
 2.42686203 2.46495726 2.46495726 2.48414918 2.38685189 2.38685189
 2.3589924  2.32627547 2.27260982 2.25198413 2.21785714 2.17155646
 2.10877193 2.03670635 2.03670635 1.96481092 1.83034648 1.72820513
 1.59211823 1.45698006 1.00606061 0.76086957 0.51515152 0.44268775
 0.18181818 0.0952381  0.0952381  0.0952381  0.         0.
 0.         0.         0.         0.         0.         0.
 0.        

> Nested cross validation is used again, this time using area under receiver operating curve (AUROC) as the optimising variable for hyperparameters, and geometric mean to optimise thresholds.

In [15]:
def g_mean(y_test, y_pred):
    g_mean = np.sqrt(recall_score(y_test, y_pred) * recall_score(y_test, y_pred, pos_label=0))
    return g_mean

summary = "Summary of best models:\n\n"

for i in ['l1', 'l2', 'rf', 'gb', 'svm_linear', 'svm_rbf']:
    print("****", i.upper(), "****\n")
    cv_outer = StratifiedKFold(n_splits=4, shuffle=True, random_state=1)
    outer_results_def_thr = list()
    outer_results_best_thr = list()
    for train_ix, test_ix in cv_outer.split(X, y):
        X_train, X_test = X.iloc[train_ix, :], X.iloc[test_ix, :]
        y_train, y_test = y.iloc[train_ix], y.iloc[test_ix]
        cv_inner = StratifiedKFold(n_splits=3, shuffle=True, random_state=1)
        clf = GridSearchCV(pipelines[i], hyperparameters[i], cv=cv_inner, scoring='roc_auc', n_jobs=-1)
        clf.fit(X_train, y_train)
        gm_array = np.zeros(99)
        for itrain_ix, itest_ix in cv_inner.split(X_train, y_train):
            iX_train, iX_test = X_train.iloc[itrain_ix, :], X_train.iloc[itest_ix, :]
            iy_train, iy_test = y_train.iloc[itrain_ix], y_train.iloc[itest_ix]
            model = pipelines[i]
            model.set_params(**clf.best_params_)
            model.fit(iX_train, iy_train)
            if i in ['l1', 'l2', 'rf', 'gb']:
                pred = model.predict_proba(iX_test)
                pred = pred[:, 1]
            else:
                pred = model.decision_function(iX_test)
            for thr in range(1, 100):
                if i in ['l1', 'l2', 'rf', 'gb']:
                    y_pred = [int(p > thr / 100) for p in pred]
                else:
                    y_pred = [int(expit(p) > thr / 100) for p in pred]
                gm_array[int(thr - 1)] += g_mean(iy_test, y_pred)
        best_threshold = (np.argmax(gm_array) + 1) / 100
        print(gm_array, "\n")
        if i in ['svm_linear', 'svm_rbf']:
            best_threshold = logit(best_threshold)
        if i in ['l1', 'l2', 'rf', 'gb']:
            pred = clf.predict_proba(X_test)
            pred = pred[:, 1]
            y_pred = [int(p > 0.5) for p in pred]
        else:
            pred = clf.decision_function(X_test)
            y_pred = [int(p > 0) for p in pred]
        print("Default threshold")
        fpr, tpr, thresholds = roc_curve(y_test, pred)
        print(" AUROC:", auc(fpr, tpr), " F1 score:", f1_score(y_test, y_pred))
        print(" G-mean:", g_mean(y_test, y_pred))
        print(confusion_matrix(y_test, y_pred))
        print(classification_report(y_test, y_pred))
        print()
        outer_results_def_thr.append({'f1': f1_score(y_test, y_pred), 'AUROC': auc(fpr, tpr), 'g-mean': g_mean(y_test, y_pred),
                                      'sensitivity': recall_score(y_test, y_pred), 'specificity': recall_score(y_test, y_pred, pos_label=0),
                                      'PPV': precision_score(y_test, y_pred), 'NPV': precision_score(y_test, y_pred, pos_label=0)})
        print("Best threshold:", best_threshold)
        if i in ['l1', 'l2', 'rf', 'gb']:
            y_pred = [int(p > best_threshold) for p in pred]
        else:
            y_pred = [int(expit(p) > best_threshold) for p in pred]
        fpr, tpr, thresholds = roc_curve(y_test, pred)
        print(" AUROC:", auc(fpr, tpr), " F1 score:", f1_score(y_test, y_pred))
        print(" G-mean:", g_mean(y_test, y_pred))
        print(confusion_matrix(y_test, y_pred))
        print(classification_report(y_test, y_pred))
        print()
        outer_results_best_thr.append({'f1': f1_score(y_test, y_pred), 'AUROC': auc(fpr, tpr), 'g-mean': g_mean(y_test, y_pred),
                                      'sensitivity': recall_score(y_test, y_pred), 'specificity': recall_score(y_test, y_pred, pos_label=0),
                                      'PPV': precision_score(y_test, y_pred), 'NPV': precision_score(y_test, y_pred, pos_label=0)})
    print("Mean scores with default threshold:")
    for score in ['f1', 'AUROC', 'g-mean', 'sensitivity', 'specificity', 'PPV', 'NPV']:
        print(" ", score, ":", np.array([dict[score] for dict in outer_results_def_thr]).mean())
    print("\nMean scores with best thresholds:")
    for score in ['f1', 'AUROC', 'g-mean', 'sensitivity', 'specificity', 'PPV', 'NPV']:
        print(" ", score, ":", np.array([dict[score] for dict in outer_results_best_thr]).mean())
    print("\n")
    summary = summary + i + "\n  Default threshold\n    AUROC: " + str(np.array([dict['AUROC'] for dict in outer_results_def_thr]).mean()) + "  F1 score: " + str(np.array([dict['f1'] for dict in outer_results_def_thr]).mean())
    summary = summary + "\n    Sensitivity: " + str(np.array([dict['sensitivity'] for dict in outer_results_def_thr]).mean())[:6] + "  Specificity: " + str(np.array([dict['specificity'] for dict in outer_results_def_thr]).mean())[:6]
    summary = summary + "\n    G-mean: " + str(np.array([dict['g-mean'] for dict in outer_results_def_thr]).mean())
    summary = summary + "\n  Best thresholds\n    AUROC: " + str(np.array([dict['AUROC'] for dict in outer_results_best_thr]).mean()) + "  F1 score: " + str(np.array([dict['f1'] for dict in outer_results_best_thr]).mean())
    summary = summary + "\n    Sensitivity: " + str(np.array([dict['sensitivity'] for dict in outer_results_best_thr]).mean())[:6] + "  Specificity: " + str(np.array([dict['specificity'] for dict in outer_results_best_thr]).mean())[:6]
    summary = summary + "\n    G-mean: " + str(np.array([dict['g-mean'] for dict in outer_results_best_thr]).mean()) + "\n"

print(summary) 

**** L1 ****

[0.83464972 1.12527248 1.28276824 1.46975644 1.59795145 1.59795145
 1.59795145 1.59795145 1.59795145 1.59795145 1.74975701 1.74975701
 1.80185637 1.80185637 1.7777712  1.9060327  1.9060327  1.9060327
 1.9060327  1.9060327  1.9060327  1.87341237 1.91831229 1.98563916
 2.03238611 2.03238611 2.10336759 2.10336759 2.10336759 2.10336759
 2.10336759 2.10336759 2.07558405 2.09733977 2.18664279 2.15791664
 2.210016   2.18593082 2.22587831 2.22587831 2.22587831 2.27335008
 2.27335008 2.27335008 2.27335008 2.27335008 2.2421252  2.20967442
 2.20967442 2.17584092 2.17584092 2.17584092 2.14948314 2.14948314
 2.19634716 2.19634716 2.16909514 2.16909514 2.14431045 2.11846968
 2.19628    2.16561433 2.16561433 2.16561433 2.16561433 2.20761436
 2.20761436 2.20761436 2.20761436 2.14243295 2.14243295 2.14243295
 2.14243295 2.14243295 2.11019164 2.11019164 2.04699206 2.04699206
 2.04699206 2.07923336 2.07923336 2.04547158 1.97404309 1.8985568
 1.85919528 1.84931582 1.7680456  1.7230201  1.723

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[0.         0.30151134 0.30151134 0.60302269 0.60302269 0.89649228
 0.89649228 1.01041693 1.01041693 1.01041693 1.01041693 1.01041693
 1.01041693 1.01041693 1.00237518 0.99068188 0.99068188 1.15774886
 1.14164232 1.12461153 1.12461153 1.12461153 1.12461153 1.24309267
 1.36122686 1.36122686 1.34918428 1.34918428 1.34918428 1.34918428
 1.34918428 1.40310371 1.49145602 1.54980376 1.54980376 1.62158731
 1.62158731 1.62158731 1.62158731 1.60194138 1.58710903 1.58710903
 1.65707499 1.62119898 1.62119898 1.62119898 1.66043762 1.72153576
 1.73492959 1.71269662 1.68963551 1.69233977 1.69049926 1.69049926
 1.66768871 1.68770095 1.68770095 1.68770095 1.69941219 1.71856402
 1.6896175  1.59893237 1.56718333 1.56718333 1.53217488 1.50310447
 1.54150234 1.50807076 1.50560516 1.3492574  1.3492574  1.35470115
 1.35470115 1.38813273 1.38813273 1.34973486 1.37286974 1.39674293
 1.20222088 1.2241024  1.16697544 1.1849993  1.09400215 1.02293159
 0.83970582 0.85038224 0.77746537 0.77746537 0.77746537 0.7774

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[0.         0.         0.         0.30151134 0.60302269 0.60302269
 0.88365607 0.87595668 0.87595668 0.99234064 1.11723073 1.23878982
 1.23878982 1.23878982 1.23878982 1.23878982 1.23878982 1.23878982
 1.23878982 1.22799309 1.22799309 1.31729765 1.31729765 1.31729765
 1.31729765 1.41070268 1.41070268 1.4859899  1.56473398 1.56473398
 1.54890117 1.64217674 1.62650048 1.62650048 1.62650048 1.62650048
 1.69402508 1.69402508 1.77266002 1.77266002 1.72179686 1.72179686
 1.76730201 1.76730201 1.71078511 1.83182971 1.85965137 1.83537324
 1.86013328 1.8714916  1.80430908 1.77927067 1.79193898 1.85649652
 1.8905846  1.8905846  1.86752349 1.90759428 1.90759428 1.81971107
 1.78341706 1.78341706 1.67703354 1.65426124 1.60735929 1.60735929
 1.59425578 1.50156053 1.43511845 1.40604804 1.40604804 1.44444591
 1.29113474 1.25770315 1.25770315 1.25770315 1.25770315 1.28523506
 1.20543172 1.20543172 1.19462815 1.01173873 1.01173873 1.01173873
 1.01173873 1.01173873 1.01173873 1.01173873 0.85942538 0.8594

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.30151134
 0.30151134 0.30151134 0.30151134 0.30151134 0.60302269 0.72791278
 0.9486344  0.9486344  1.31809926 1.31809926 1.30201575 1.30201575
 1.37129453 1.37129453 1.37129453 1.37129453 1.43392732 1.49152413
 1.49152413 1.48382474 1.48382474 1.48382474 1.48382474 1.48382474
 1.47245198 1.58883595 1.57764046 1.71299593 1.71299593 1.85244141
 1.90137419 1.97235567 2.0348916  2.07862419 2.07862419 2.07862419
 2.11489921 2.11489921 2.17143602 2.1508312  2.14567524 2.05368961
 2.09820726 2.1541813  2.12713975 2.07650694 2.07650694 1.99329523
 2.0165406  2.04856984 2.05361411 2.05361411 2.0117054  1.9794641
 1.99169237 1.99169237 1.96390245 1.88478888 1.87977661 1.87977661
 1.79255706 1.73839936 1.68510626 1.54236427 1.37545961 1.29626349
 1.29626349 1.18063402 1.10771715 1.01269004 0.95121828 0.87830142
 0.64888568 0.55385858 0.45883147 0.22941573 0.22941573 0.22941573
 0.2

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 2.08664209 0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.        ] 

Default threshold
 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.30151134 0.30151134 0.30151134 0.30151134 0.59018648 0.59018648
 0.59018648 0.59018648 0.59018648 0.59018648 0.59018648 0.59018648
 0.59018648 0.59018648 0.70975964 1.10302269 1.10302269 1.10302269
 1.10302269 1.10302269 1.32374431 1.32374431 1.52942412 1.52942412
 1.60677439 1.60677439 1.59137561 1.67137434 1.67137434 1.65610546
 1.65610546 1.65610546 1.6404292  1.6404292  1.7079538  1.78874352
 1.85320409 1.96568759 2.00399403 2.05855077 2.05855077 2.10320166
 2.14019364 2.09701359 2.07493427 2.14779377 2.07256085 2.12149968
 2.18212071 2.24238596 2.21303312 2.1232597  2.02048238 1.94981978
 1.83207355 1.75787931 1.63190952 1.47691466 1.1706822  1.13656367
 0.97009892 0.85498761 0.53496709 0.44234612 0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


> The most favourable result overall appears to be the one using **L2-regularised logistic regression** with F1 score as the optimising variable. This process is then applied anew using 3-fold cross validation on the full dataset in order to produce a fitted model and a best threshold specific to this model.

In [16]:
i = 'l2'
results_best_thr = list()
cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=2)
clf = GridSearchCV(pipelines[i], hyperparameters[i], cv=cv, scoring='f1', n_jobs=-1)
clf.fit(X, y)
f1_array = np.zeros(99)
for train_ix, test_is in cv.split(X, y):
    X_train, X_test = X.iloc[train_ix, :], X.iloc[test_ix, :]
    y_train, y_test = y.iloc[train_ix], y.iloc[test_ix]
    model = pipelines[i]
    model.set_params(**clf.best_params_)
    model.fit(X_train, y_train)
    if i in ['l1', 'l2', 'rf', 'gb']:
        pred = model.predict_proba(X_test)
        pred = pred[:, 1]
    else:
        pred = model.decision_function(X_test)
    for thr in range(1, 100):
        if i in ['l1', 'l2', 'rf', 'gb']:
            y_pred = [int(p > thr / 100) for p in pred]
        else:
            y_pred = [int(expit(p) > thr / 100) for p in pred]
        f1_array[int(thr - 1)] += f1_score(y_test, y_pred)
best_threshold = (np.argmax(f1_array) + 1) / 100
print(f1_array, "\n")
if i in ['svm_linear', 'svm_rbf']:
    best_threshold = logit(best_threshold)
fitted_model_ncv = clf
best_thr_ncv = best_threshold
print("Best threshold for", i.upper(), "model =", best_thr_ncv)

[2.32653061 2.32653061 2.32653061 2.32653061 2.32653061 2.34268707
 2.34268707 2.35884354 2.35884354 2.375      2.39184397 2.40868794
 2.40868794 2.40868794 2.40868794 2.40868794 2.40868794 2.40868794
 2.42626426 2.44310823 2.44310823 2.46068455 2.46068455 2.46068455
 2.46068455 2.49661836 2.49661836 2.4705314  2.4705314  2.44426877
 2.44426877 2.44426877 2.44426877 2.44426877 2.46329626 2.48147808
 2.48147808 2.50050556 2.50050556 2.50050556 2.49260042 2.49260042
 2.49260042 2.49260042 2.46491493 2.46491493 2.46491493 2.45625692
 2.45625692 2.4760014  2.4760014  2.4760014  2.46689895 2.45564192
 2.42637363 2.42637363 2.44688645 2.46847889 2.4041514  2.34700855
 2.3037653  2.31042471 2.15953453 2.08849206 2.10866013 2.06944444
 2.01201771 1.93291789 1.91330645 1.88438543 1.90408986 1.90408986
 1.85117981 1.65842389 1.60669975 1.52667256 1.34920635 1.28860029
 1.23569024 0.93939394 0.93939394 0.86693017 0.79917184 0.81102955
 0.81102955 0.81102955 0.65367965 0.57142857 0.57142857 0.5714

> **WITHOUT ABSOLUTE POWERS**
> 
> In the following sections I've repeated all the above processes with absolute powers removed from the dataset to leave only  relative powers and lateralisation variables. The presumption I've made is that different testing conditions such as different machinery and different ground lead placement may lead to variation in absolute powers, whereas if these are all varied in a similar manner then relative powers should be more comparable across testing locations.

In [17]:
np.array(df.columns)

array(['subject', 'Fp1a delta', 'Fp1a theta', 'Fp1a alpha', 'Fp1a beta',
       'Fp1a gamma', 'Fp2a delta', 'Fp2a theta', 'Fp2a alpha',
       'Fp2a beta', 'Fp2a gamma', 'F3a delta', 'F3a theta', 'F3a alpha',
       'F3a beta', 'F3a gamma', 'F4a delta', 'F4a theta', 'F4a alpha',
       'F4a beta', 'F4a gamma', 'F7a delta', 'F7a theta', 'F7a alpha',
       'F7a beta', 'F7a gamma', 'F8a delta', 'F8a theta', 'F8a alpha',
       'F8a beta', 'F8a gamma', 'C3a delta', 'C3a theta', 'C3a alpha',
       'C3a beta', 'C3a gamma', 'C4a delta', 'C4a theta', 'C4a alpha',
       'C4a beta', 'C4a gamma', 'P3a delta', 'P3a theta', 'P3a alpha',
       'P3a beta', 'P3a gamma', 'P4a delta', 'P4a theta', 'P4a alpha',
       'P4a beta', 'P4a gamma', 'O1a delta', 'O1a theta', 'O1a alpha',
       'O1a beta', 'O1a gamma', 'O2a delta', 'O2a theta', 'O2a alpha',
       'O2a beta', 'O2a gamma', 'Fp1r delta', 'Fp1r theta', 'Fp1r alpha',
       'Fp1r beta', 'Fp1r gamma', 'Fp2r delta', 'Fp2r theta',
       'Fp2r alp

In [18]:
y = df.status
X = df.drop(['subject', 'Fp1a delta', 'Fp1a theta', 'Fp1a alpha', 'Fp1a beta',
       'Fp1a gamma', 'Fp2a delta', 'Fp2a theta', 'Fp2a alpha',
       'Fp2a beta', 'Fp2a gamma', 'F3a delta', 'F3a theta', 'F3a alpha',
       'F3a beta', 'F3a gamma', 'F4a delta', 'F4a theta', 'F4a alpha',
       'F4a beta', 'F4a gamma', 'F7a delta', 'F7a theta', 'F7a alpha',
       'F7a beta', 'F7a gamma', 'F8a delta', 'F8a theta', 'F8a alpha',
       'F8a beta', 'F8a gamma', 'C3a delta', 'C3a theta', 'C3a alpha',
       'C3a beta', 'C3a gamma', 'C4a delta', 'C4a theta', 'C4a alpha',
       'C4a beta', 'C4a gamma', 'P3a delta', 'P3a theta', 'P3a alpha',
       'P3a beta', 'P3a gamma', 'P4a delta', 'P4a theta', 'P4a alpha',
       'P4a beta', 'P4a gamma', 'O1a delta', 'O1a theta', 'O1a alpha',
       'O1a beta', 'O1a gamma', 'O2a delta', 'O2a theta', 'O2a alpha',
       'O2a beta', 'O2a gamma', 'status'], axis=1)

In [19]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=1234, stratify=df.status)

> Cross validation is performed as before to produce a fitted model for each algorithm.

In [20]:
fitted_models_no_ap = {}
for i in ['l1', 'l2', 'rf', 'gb', 'svm_linear', 'svm_rbf']:
    model = GridSearchCV(pipelines[i], hyperparameters[i], cv=3, scoring='f1', n_jobs=-1)
    model.fit(X_train, y_train)
    fitted_models_no_ap[i] = model
    print(i, 'has been fitted')

l1 has been fitted
l2 has been fitted
rf has been fitted
gb has been fitted
svm_linear has been fitted
svm_rbf has been fitted


In [21]:
for name, model in fitted_models_no_ap.items():
    print(name, model.best_score_)

l1 0.703545650914072
l2 0.7780345707544899
rf 0.7447089947089948
gb 0.7540650406504065
svm_linear 0.6964076858813701
svm_rbf 0.7703401360544216


> Choosing the best model below is somewhat subjective. The support vector machine using radial basis function kernel (svm_rbf) gives the best F1 score but the model is useless because every subject is predicted as being alcoholic. I've chosen the **support vector machine** using a **linear kernel (svm_linear)** as this appears to be the most clinically useful model in terms of correctly identifying a non-alcoholic subject out of the 4 models with an equal highest accuracy of 0.68.

In [22]:
print("Scores for each model:\n")
for name, model in fitted_models_no_ap.items():
    if name in ['l1', 'l2', 'rf', 'gb']:
        pred = model.predict_proba(X_test)
        pred = [p[1] for p in pred]
        y_pred = [int(p > 0.5) for p in pred]
    else:
        pred = model.decision_function(X_test)
        y_pred = [int(p > 0) for p in pred]
    fpr, tpr, thresholds = roc_curve(y_test, pred)
    print(name, "- AUROC:", auc(fpr, tpr), " F1 score:", f1_score(y_test, y_pred))
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))
    print()

Scores for each model:

l1 - AUROC: 0.7818181818181819  F1 score: 0.6666666666666667
[[ 9  2]
 [ 9 11]]
              precision    recall  f1-score   support

           0       0.50      0.82      0.62        11
           1       0.85      0.55      0.67        20

    accuracy                           0.65        31
   macro avg       0.67      0.68      0.64        31
weighted avg       0.72      0.65      0.65        31


l2 - AUROC: 0.7363636363636364  F1 score: 0.761904761904762
[[ 5  6]
 [ 4 16]]
              precision    recall  f1-score   support

           0       0.56      0.45      0.50        11
           1       0.73      0.80      0.76        20

    accuracy                           0.68        31
   macro avg       0.64      0.63      0.63        31
weighted avg       0.67      0.68      0.67        31


rf - AUROC: 0.6522727272727272  F1 score: 0.7368421052631577
[[ 7  4]
 [ 6 14]]
              precision    recall  f1-score   support

           0       0.54   

  _warn_prf(average, modifier, msg_start, len(result))


> Nested cross validation is performed as previously.

In [23]:
summary = "Summary of best models:\n\n"
for i in ['l1', 'l2', 'rf', 'gb', 'svm_linear', 'svm_rbf']:
    print("****", i.upper(), "****\n")
    cv_outer = StratifiedKFold(n_splits=4, shuffle=True, random_state=1)
    outer_results_def_thr = list()
    outer_results_best_thr = list()
    for train_ix, test_ix in cv_outer.split(X, y):
        X_train, X_test = X.iloc[train_ix, :], X.iloc[test_ix, :]
        y_train, y_test = y.iloc[train_ix], y.iloc[test_ix]
        cv_inner = StratifiedKFold(n_splits=3, shuffle=True, random_state=1)
        clf = GridSearchCV(pipelines[i], hyperparameters[i], cv=cv_inner, scoring='f1', n_jobs=-1)
        clf.fit(X_train, y_train)
        f1_array = np.zeros(99)
        for itrain_ix, itest_ix in cv_inner.split(X_train, y_train):
            iX_train, iX_test = X_train.iloc[itrain_ix, :], X_train.iloc[itest_ix, :]
            iy_train, iy_test = y_train.iloc[itrain_ix], y_train.iloc[itest_ix]
            model = pipelines[i]
            model.set_params(**clf.best_params_)
            model.fit(iX_train, iy_train)
            if i in ['l1', 'l2', 'rf', 'gb']:
                pred = model.predict_proba(iX_test)
                pred = pred[:, 1]
            else:
                pred = model.decision_function(iX_test)
            for thr in range(1, 100):
                if i in ['l1', 'l2', 'rf', 'gb']:
                    y_pred = [int(p > thr / 100) for p in pred]
                else:
                    y_pred = [int(expit(p) > thr / 100) for p in pred]
                f1_array[int(thr - 1)] += f1_score(iy_test, y_pred)
        best_threshold = (np.argmax(f1_array) + 1) / 100
        print(f1_array, "\n")
        if i in ['svm_linear', 'svm_rbf']:
            best_threshold = logit(best_threshold)
        if i in ['l1', 'l2', 'rf', 'gb']:
            pred = clf.predict_proba(X_test)
            pred = pred[:, 1]
            y_pred = [int(p > 0.5) for p in pred]
        else:
            pred = clf.decision_function(X_test)
            y_pred = [int(p > 0) for p in pred]
        print("Default threshold")
        fpr, tpr, thresholds = roc_curve(y_test, pred)
        print(" AUROC:", auc(fpr, tpr), " F1 score:", f1_score(y_test, y_pred))
        print(confusion_matrix(y_test, y_pred))
        print(classification_report(y_test, y_pred))
        print()
        outer_results_def_thr.append({'f1': f1_score(y_test, y_pred), 'AUROC': auc(fpr, tpr),
                                      'sensitivity': recall_score(y_test, y_pred), 'specificity': recall_score(y_test, y_pred, pos_label=0),
                                      'PPV': precision_score(y_test, y_pred), 'NPV': precision_score(y_test, y_pred, pos_label=0)})
        print("Best threshold:", best_threshold)
        if i in ['l1', 'l2', 'rf', 'gb']:
            y_pred = [int(p > best_threshold) for p in pred]
        else:
            y_pred = [int(expit(p) > best_threshold) for p in pred]
        fpr, tpr, thresholds = roc_curve(y_test, pred)
        print(" AUROC:", auc(fpr, tpr), " F1 score:", f1_score(y_test, y_pred))
        print(confusion_matrix(y_test, y_pred))
        print(classification_report(y_test, y_pred))
        print()
        outer_results_best_thr.append({'f1': f1_score(y_test, y_pred), 'AUROC': auc(fpr, tpr),
                                      'sensitivity': recall_score(y_test, y_pred), 'specificity': recall_score(y_test, y_pred, pos_label=0),
                                      'PPV': precision_score(y_test, y_pred), 'NPV': precision_score(y_test, y_pred, pos_label=0)})
    print("Mean scores with default threshold:")
    for score in ['f1', 'AUROC', 'sensitivity', 'specificity', 'PPV', 'NPV']:
        print(" ", score, ":", np.array([dict[score] for dict in outer_results_def_thr]).mean())
    print("\nMean scores with best thresholds:")
    for score in ['f1', 'AUROC', 'sensitivity', 'specificity', 'PPV', 'NPV']:
        print(" ", score, ":", np.array([dict[score] for dict in outer_results_best_thr]).mean())
    print("\n")
    summary = summary + i + "\n  Default threshold\n    AUROC: " + str(np.array([dict['AUROC'] for dict in outer_results_def_thr]).mean()) + "  F1 score: " + str(np.array([dict['f1'] for dict in outer_results_def_thr]).mean())
    summary = summary + "\n    Sensitivity: " + str(np.array([dict['sensitivity'] for dict in outer_results_def_thr]).mean())[:6] + "  Specificity: " + str(np.array([dict['specificity'] for dict in outer_results_def_thr]).mean())[:6]
    summary = summary + "\n  Best thresholds\n    AUROC: " + str(np.array([dict['AUROC'] for dict in outer_results_best_thr]).mean()) + "  F1 score: " + str(np.array([dict['f1'] for dict in outer_results_best_thr]).mean())
    summary = summary + "\n    Sensitivity: " + str(np.array([dict['sensitivity'] for dict in outer_results_best_thr]).mean())[:6] + "  Specificity: " + str(np.array([dict['specificity'] for dict in outer_results_best_thr]).mean())[:6] + "\n"
print(summary) 

**** L1 ****

[2.35466962 2.31561462 2.27409988 2.29268293 2.31219512 2.30100063
 2.29129862 2.25826558 2.22777778 2.26923077 2.26923077 2.26923077
 2.26923077 2.26923077 2.23393665 2.23393665 2.23393665 2.25417957
 2.25417957 2.25417957 2.25417957 2.25417957 2.25417957 2.25417957
 2.25417957 2.25417957 2.25417957 2.25417957 2.21674641 2.21674641
 2.21674641 2.21674641 2.21674641 2.23469513 2.23469513 2.23469513
 2.23469513 2.25358852 2.25358852 2.25358852 2.25358852 2.25358852
 2.25358852 2.27492564 2.27492564 2.27492564 2.29765292 2.29765292
 2.26351351 2.26351351 2.26351351 2.26351351 2.26351351 2.26351351
 2.26351351 2.26351351 2.22319093 2.1901579  2.14714715 2.14714715
 2.14714715 2.14714715 2.14714715 2.14714715 2.16936937 2.16936937
 2.16936937 2.16936937 2.16936937 2.16936937 2.16936937 2.16936937
 2.16936937 2.16936937 2.18888889 2.18888889 2.18888889 2.18888889
 2.15238095 2.15238095 2.15238095 2.15238095 2.15238095 2.15238095
 2.11372549 2.16024341 2.16024341 2.16024341 2.1

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.31982393 2.31982393 2.31982393 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.3514906  2.3514906
 2.3514906  2.3514906  2.3514906  2.3514906  2.29895436 2.29895436
 2.29895436 2.29895436 2.31511083 2.31511083 2.31511083 2.31511083
 2.2870915  2.2870915  2.2870915  2.27365481 2.27365481 2.30618506
 2.30618506 2.32279635 2.29689441 2.29689441 2.23848238 2.21399259
 2.21399259 2.21399259 2.16188664 2.19403794 2.16575511 2.20089723
 2.20089723 2.18901764 2.20760068 2.22467385 2.24418605 2.24908425
 2.24908425 2.26932716 2.23912972 2.17470547 2.1957265  2.21367521
 2.14177979 2.12653377 2.13370548 2.13370548 2.13370548 1.94566993
 1.86233301 1.83463203 1.80921659 1.74545455 1.7661442  1.7661442
 1.71786834 1.67241379 1.62315271 1.62315271 1.51851852 1.51851852
 1.53703704 1.53703704 1.36798088 1.36798088 1.2086649  1.2086649
 1.09047619 1.1037037  1.13333333 1.13333333 1.13333333 1.13333333
 1.06666667 0.9942029  0.9942029  0.68944099 0.6942029  0.6942029

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.34623044 2.34623044 2.34623044 2.29041649 2.26051616 2.26051616
 2.22779923 2.19644034 2.19644034 2.19644034 2.19644034 2.19644034
 2.19644034 2.19644034 2.19644034 2.21351351 2.21351351 2.21351351
 2.21351351 2.21351351 2.21351351 2.21351351 2.21351351 2.21351351
 2.21351351 2.21351351 2.21351351 2.21351351 2.21351351 2.21351351
 2.21351351 2.21351351 2.21351351 2.21351351 2.21351351 2.21351351
 2.21351351 2.21351351 2.21351351 2.21351351 2.21351351 2.21351351
 2.21351351 2.21351351 2.21351351 2.21351351 2.21351351 2.21351351
 2.21351351 2.21351351 2.21351351 2.21351351 2.21351351 2.21351351
 2.21351351 2.21351351 2.21351351 2.21351351 2.21351351 2.21351351
 2.21351351 2.21351351 2.21351351 2.21351351 2.21351351 2.21351351
 2.17897898 2.17897898 2.17897898 2.17897898 2.14444444 2.14444444
 2.14444444 2.10793651 2.10793651 2.10793651 2.10793651 2.06928105
 2.03594771 2.03594771 2.03594771 1.99494949 1.99494949 1.99494949
 1.99494949 2.01558442 2.01558442 2.03312828 2.03312828 2.0331

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.33533413 2.33533413 2.3514906
 2.3514906  2.3514906  2.34193784 2.34193784 2.35858909 2.34870766
 2.30230495 2.250199   2.26691332 2.27147414 2.29636591 2.3029525
 2.25952381 2.15160526 2.10567451 2.05815508 1.97402751 1.89338235
 1.72820513 1.64516129 1.55172414 1.3125937  1.14975845 1.01087801
 1.02797203 0.74857143 0.5942029  0.52173913 0.44268775 0.36363636
 0.0952381  0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.33598039 2.33598039 2.3514906
 2.3514906  2.32598039 2.32598039 2.3579932  2.3579932  2.37395064
 2.37395064 2.34824142 2.34824142 2.34824142 2.34824142 2.36489267
 2.33806566 2.33806566 2.3134058  2.26086957 2.29468599 2.33003953
 2.34646465 2.36443505 2.33564214 2.35281385 2.37184134 2.38181671
 2.4013289  2.42184172 2.39080259 2.42985843 2.40198387 2.34194634
 2.30955768 2.30955768 2.30955768 2.24765292 2.26560163 2.34795322
 2.34795322 2.26431978 2.26431978 2.2853408  2.21298701 2.21298701
 2.1362395  2.09702381 2.09702381 2.00952381 1.96124795 1.94566353
 1.94566353 1.72580645 1.68064516 1.4789272  1.38363297 1.23076923
 1.23076923 1.20128205 1.06215162 1.07753623 1.01086957 0.95652174
 0.95652174 0.88142292 0.79917184 0.79917184 0.53359684 0.53359

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.31102041 2.2855102  2.30146765 2.30146765 2.31762411 2.31762411
 2.31762411 2.35166667 2.34166667 2.34166667 2.31507092 2.31507092
 2.28779819 2.30464217 2.32261257 2.35716488 2.35716488 2.39333449
 2.39333449 2.4109108  2.4109108  2.38164251 2.38164251 2.38164251
 2.35335968 2.35335968 2.35335968 2.370273   2.370273   2.34418605
 2.36469887 2.4005994  2.36956026 2.33684334 2.33684334 2.35542638
 2.35542638 2.32089185 2.33991933 2.38075881 2.400271   2.400271
 2.42222222 2.39145299 2.37371582 2.39435074 2.39435074 2.35737595
 2.35737595 2.35737595 2.35737595 2.35737595 2.32750397 2.30912162
 2.30912162 2.30912162 2.28828829 2.28828829 2.27777778 2.27777778
 2.2984127  2.18371487 2.10283251 2.10283251 2.10283251 2.10283251
 2.10283251 2.10283251 2.10283251 2.06921907 2.03356844 1.99568966
 1.99568966 1.97552836 1.97552836 1.93036707 1.85399415 1.7998631
 1.7998631  1.7998631  1.7998631  1.75685234 1.75685234 1.70512821
 1.64957265 1.64957265 1.54376658 1.54376658 1.55915119 1.4432234

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.36764706 2.34193784 2.37427536 2.36439394 2.35787455 2.32908163
 2.31717687 2.31717687 2.31717687 2.31717687 2.31717687 2.31717687
 2.33402084 2.3081189  2.3081189  2.28106576 2.28106576 2.26893939
 2.26893939 2.25988142 2.24474346 2.18395213 2.1556693  2.1556693
 2.12607099 2.12607099 2.14378971 2.10822783 2.12681087 2.12681087
 2.09580312 2.12913645 2.12913645 2.09159892 2.05826558 2.05826558
 2.05826558 2.05826558 2.02317786 2.02317786 2.02317786 2.02317786
 1.99269006 1.96063878 1.96063878 1.96063878 1.96063878 1.96063878
 1.92690058 1.9443609  1.92389307 1.84387439 1.84387439 1.84387439
 1.80831251 1.80831251 1.76629571 1.72173243 1.62398072 1.58644318
 1.58644318 1.55966387 1.55966387 1.51680672 1.51680672 1.51680672
 1.55143876 1.55143876 1.52110566 1.52110566 1.52110566 1.52110566
 1.491939   1.42926094 1.44708625 1.44708625 1.44708625 1.44708625
 1.44708625 1.39651153 1.41375291 1.32382134 1.21880342 1.21880342
 1.17823755 1.19301587 1.15190883 1.09615385 1.10702341 0.99420

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.29431373 2.28496505 2.28496505
 2.28496505 2.28496505 2.30027117 2.31595745 2.34856614 2.34856614
 2.34856614 2.34856614 2.34856614 2.34856614 2.34856614 2.34856614
 2.34856614 2.34856614 2.36489267 2.38154392 2.38154392 2.38154392
 2.38154392 2.38154392 2.43333333 2.40606061 2.40606061 2.40606061
 2.36342495 2.38139535 2.34391275 2.34391275 2.34391275 2.33463203
 2.29304029 2.26284285 2.20093809 2.21923077 2.21923077 2.23717949
 2.22193347 2.18954481 2.15749353 2.08961593 2.10810811 2.14714715
 2.11063921 2.09159159 1.99754902 1.9202381  1.87588326 1.83421659
 1.80529557 1.65767974 1.59785068 1.42045455 1.28152493 1.28152493
 1.28152493 1.28152493 1.23607038 1.23607038 1.14252199 1.17826541
 1.17826541 1.17826541 1.19435737 1.19435737 1.08196912 0.94179894
 0.94179894 0.88888889 0.83190883 0.70769231 0.72       0.72
 0.65333333 0.65333333 0.65333333 0.65333333 0.65333333 0.65333333
 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.36764706 2.36764706 2.36764706
 2.36764706 2.36764706 2.36764706 2.36764706 2.34333333 2.34333333
 2.33446809 2.32673828 2.29991127 2.29991127 2.29991127 2.27382431
 2.27382431 2.28913043 2.30555556 2.30555556 2.321513   2.321513
 2.321513   2.321513   2.33816425 2.32828283 2.3        2.34069767
 2.32753978 2.32753978 2.34636591 2.33551173 2.35502392 2.33991201
 2.33991201 2.28319051 2.24217639 2.22518797 2.22518797 2.19313669
 2.19313669 2.19313669 2.15448123 2.12074303 2.08802611 2.10651828
 2.10651828 2.06552007 2.08445946 1.98858167 1.98858167 1.95238095
 1.95238095 1.91071429 1.86635945 1.86635945 1.78039216 1.59583333
 1.48970307 1.48970307 1.50784823 1.52508961 1.52508961 1.5146871
 1.47741148 1.42450142 1.36752137 1.30598291 1.30598291 1.23931624
 1.11794872 1.11794872 1.05333333 0.98333333 0.93115942 0.8629776
 0.79051383 0.62488236 0.62488236 0.54978355 0.54978355 0.54978355

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.31102041 2.31102041 2.31102041 2.31102041 2.31102041 2.31102041
 2.31102041 2.31102041 2.31102041 2.31102041 2.31102041 2.31102041
 2.31102041 2.31102041 2.31102041 2.31102041 2.31102041 2.31102041
 2.31102041 2.31102041 2.31102041 2.31102041 2.31102041 2.31102041
 2.31102041 2.31102041 2.30187075 2.30187075 2.30187075 2.30187075
 2.29146765 2.30762411 2.30762411 2.32446809 2.32446809 2.32446809
 2.31521739 2.32330637 2.37483694 2.34604403 2.31921701 2.30735931
 2.30735931 2.24675325 2.30387597 2.30387597 2.30387597 2.32438879
 2.29186846 2.35829826 2.35829826 2.32676673 2.34871795 2.36491228
 2.37671882 2.3013923  2.3013923  2.26986077 2.27807487 2.16645022
 2.04435484 1.88735632 1.65772669 1.51703704 1.51703704 1.34153846
 1.13907469 1.08939394 0.95388669 0.87878788 0.64502165 0.55454545
 0.47272727 0.29047619 0.29047619 0.2        0.1        0.1
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413 2.33533413
 2.30102041 0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.     

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.     

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.3355102  2.3355102
 2.3355102  2.35183673 2.35183673 2.38435374 2.38435374 2.35953105
 2.35953105 2.34978557 2.2944496  2.26874038 2.26874038 2.2856537
 2.30230495 2.27240461 2.23449025 2.23962504 2.27471273 2.21666667
 2.18461538 2.04637592 1.90465449 1.87073525 1.67857143 1.43387647
 1.28296296 1.25641026 1.19794872 1.         0.54620742 0.45573123
 0.28181818 0.1        0.1        0.1        0.1        0.1
 0.         0.         0.         0.         0.         0.
 0.         

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [24]:
def g_mean(y_test, y_pred):
    g_mean = np.sqrt(recall_score(y_test, y_pred) * recall_score(y_test, y_pred, pos_label=0))
    return g_mean

summary = "Summary of best models:\n\n"

for i in ['l1', 'l2', 'rf', 'gb', 'svm_linear', 'svm_rbf']:
    print("****", i.upper(), "****\n")
    cv_outer = StratifiedKFold(n_splits=4, shuffle=True, random_state=1)
    outer_results_def_thr = list()
    outer_results_best_thr = list()
    for train_ix, test_ix in cv_outer.split(X, y):
        X_train, X_test = X.iloc[train_ix, :], X.iloc[test_ix, :]
        y_train, y_test = y.iloc[train_ix], y.iloc[test_ix]
        cv_inner = StratifiedKFold(n_splits=3, shuffle=True, random_state=1)
        clf = GridSearchCV(pipelines[i], hyperparameters[i], cv=cv_inner, scoring='roc_auc', n_jobs=-1)
        clf.fit(X_train, y_train)
        gm_array = np.zeros(99)
        for itrain_ix, itest_ix in cv_inner.split(X_train, y_train):
            iX_train, iX_test = X_train.iloc[itrain_ix, :], X_train.iloc[itest_ix, :]
            iy_train, iy_test = y_train.iloc[itrain_ix], y_train.iloc[itest_ix]
            model = pipelines[i]
            model.set_params(**clf.best_params_)
            model.fit(iX_train, iy_train)
            if i in ['l1', 'l2', 'rf', 'gb']:
                pred = model.predict_proba(iX_test)
                pred = pred[:, 1]
            else:
                pred = model.decision_function(iX_test)
            for thr in range(1, 100):
                if i in ['l1', 'l2', 'rf', 'gb']:
                    y_pred = [int(p > thr / 100) for p in pred]
                else:
                    y_pred = [int(expit(p) > thr / 100) for p in pred]
                gm_array[int(thr - 1)] += g_mean(iy_test, y_pred)
        best_threshold = (np.argmax(gm_array) + 1) / 100
        print(gm_array, "\n")
        if i in ['svm_linear', 'svm_rbf']:
            best_threshold = logit(best_threshold)
        if i in ['l1', 'l2', 'rf', 'gb']:
            pred = clf.predict_proba(X_test)
            pred = pred[:, 1]
            y_pred = [int(p > 0.5) for p in pred]
        else:
            pred = clf.decision_function(X_test)
            y_pred = [int(p > 0) for p in pred]
        print("Default threshold")
        fpr, tpr, thresholds = roc_curve(y_test, pred)
        print(" AUROC:", auc(fpr, tpr), " F1 score:", f1_score(y_test, y_pred))
        print(" G-mean:", g_mean(y_test, y_pred))
        print(confusion_matrix(y_test, y_pred))
        print(classification_report(y_test, y_pred))
        print()
        outer_results_def_thr.append({'f1': f1_score(y_test, y_pred), 'AUROC': auc(fpr, tpr), 'g-mean': g_mean(y_test, y_pred),
                                      'sensitivity': recall_score(y_test, y_pred), 'specificity': recall_score(y_test, y_pred, pos_label=0),
                                      'PPV': precision_score(y_test, y_pred), 'NPV': precision_score(y_test, y_pred, pos_label=0)})
        print("Best threshold:", best_threshold)
        if i in ['l1', 'l2', 'rf', 'gb']:
            y_pred = [int(p > best_threshold) for p in pred]
        else:
            y_pred = [int(expit(p) > best_threshold) for p in pred]
        fpr, tpr, thresholds = roc_curve(y_test, pred)
        print(" AUROC:", auc(fpr, tpr), " F1 score:", f1_score(y_test, y_pred))
        print(" G-mean:", g_mean(y_test, y_pred))
        print(confusion_matrix(y_test, y_pred))
        print(classification_report(y_test, y_pred))
        print()
        outer_results_best_thr.append({'f1': f1_score(y_test, y_pred), 'AUROC': auc(fpr, tpr), 'g-mean': g_mean(y_test, y_pred),
                                      'sensitivity': recall_score(y_test, y_pred), 'specificity': recall_score(y_test, y_pred, pos_label=0),
                                      'PPV': precision_score(y_test, y_pred), 'NPV': precision_score(y_test, y_pred, pos_label=0)})
    print("Mean scores with default threshold:")
    for score in ['f1', 'AUROC', 'g-mean', 'sensitivity', 'specificity', 'PPV', 'NPV']:
        print(" ", score, ":", np.array([dict[score] for dict in outer_results_def_thr]).mean())
    print("\nMean scores with best thresholds:")
    for score in ['f1', 'AUROC', 'g-mean', 'sensitivity', 'specificity', 'PPV', 'NPV']:
        print(" ", score, ":", np.array([dict[score] for dict in outer_results_best_thr]).mean())
    print("\n")
    summary = summary + i + "\n  Default threshold\n    AUROC: " + str(np.array([dict['AUROC'] for dict in outer_results_def_thr]).mean()) + "  F1 score: " + str(np.array([dict['f1'] for dict in outer_results_def_thr]).mean())
    summary = summary + "\n    Sensitivity: " + str(np.array([dict['sensitivity'] for dict in outer_results_def_thr]).mean())[:6] + "  Specificity: " + str(np.array([dict['specificity'] for dict in outer_results_def_thr]).mean())[:6]
    summary = summary + "\n    G-mean: " + str(np.array([dict['g-mean'] for dict in outer_results_def_thr]).mean())
    summary = summary + "\n  Best thresholds\n    AUROC: " + str(np.array([dict['AUROC'] for dict in outer_results_best_thr]).mean()) + "  F1 score: " + str(np.array([dict['f1'] for dict in outer_results_best_thr]).mean())
    summary = summary + "\n    Sensitivity: " + str(np.array([dict['sensitivity'] for dict in outer_results_best_thr]).mean())[:6] + "  Specificity: " + str(np.array([dict['specificity'] for dict in outer_results_best_thr]).mean())[:6]
    summary = summary + "\n    G-mean: " + str(np.array([dict['g-mean'] for dict in outer_results_best_thr]).mean()) + "\n"

print(summary) 

**** L1 ****

[1.96877233 2.07378745 2.10484486 2.08603532 2.08603532 2.14077685
 2.11947296 2.11947296 2.09222094 2.09222094 2.06397697 2.06397697
 2.06397697 2.06397697 2.06397697 2.06397697 2.06397697 2.06397697
 2.03994295 2.03994295 2.03994295 2.03994295 2.05786952 2.05786952
 2.05786952 2.05786952 2.05786952 2.05786952 2.05786952 2.05786952
 2.05786952 2.05786952 2.05786952 2.05786952 2.05786952 2.05786952
 2.05786952 2.05786952 2.05786952 2.05786952 2.05786952 2.02726596
 2.07589905 2.07589905 2.07589905 2.07589905 2.15839467 2.15839467
 2.15839467 2.15839467 2.15839467 2.15839467 2.15839467 2.15839467
 2.15839467 2.15839467 2.15839467 2.15839467 2.15839467 2.15839467
 2.15839467 2.15839467 2.15839467 2.15839467 2.15839467 2.15839467
 2.15839467 2.15839467 2.12463289 2.12463289 2.12463289 2.12463289
 2.12463289 2.12463289 2.12463289 2.12463289 2.12463289 2.12463289
 2.16903867 2.16903867 2.16903867 2.16903867 2.16903867 2.16903867
 2.16903867 2.16903867 2.16903867 2.16903867 2.1

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.28867513 0.70975964 1.08205251 1.55247675
 1.86262058 2.13944012 1.98678247 2.01717711 1.80880928 1.54846984
 1.39218168 1.13540735 0.69509538 0.32444284 0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.        ] 

Def

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.30151134
 0.42640143 0.91725802 0.90368198 0.9367196  1.35944931 1.54129509
 1.58736621 2.13560066 2.0750634  1.84882879 1.74376037 1.59314734
 1.45342674 1.15073117 0.68718097 0.43884269 0.22941573 0.22941573
 0.22941573 0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.30151134 0.30151134 0.30151134 0.30151134 0.30151134 0.42640143
 0.8109081  1.11241945 1.11241945 1.19320917 1.19320917 1.17746803
 1.16138452 1.16138452 1.16138452 1.16138452 1.16138452 1.2306633
 1.2306633  1.35222238 1.44549795 1.38832229 1.50470625 1.59401081
 1.67264575 1.67264575 1.73210016 1.79656073 1.79656073 1.79656073
 1.8658395  1.81031086 1.96059067 1.9177486  1.89566927 1.92136272
 1.92136272 1.87461746 1.9482844  1.92533823 1.92533823 1.92533823
 1.96733826 2.01436586 2.07431025 2.0445385  2.05882226 2.09650805
 2.07289972 2.04538473 1.93369488 1.93369488 1.80892057 1.76896219
 1.79962786 1.75669773 1.73931468 1.73931468 1.73931468 1.58010399
 1.38428478 1.16553529 1.03114666 1.03114666 0.93611955 0.70670382
 0.70670382 0.64888568 0.55385858 0.55385858 0.55385858 0.22941573
 0.22941573 0.229415

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.29346959 0.29346959
 0.29346959 0.29346959 0.57444534 0.87595668 1.11389973 1.11389973
 1.11389973 1.11389973 1.11389973 1.11389973 1.11389973 1.11389973
 1.11389973 1.11389973 1.23878982 1.32809437 1.32809437 1.32809437
 1.32809437 1.32809437 1.5924887  1.57816739 1.57816739 1.57816739
 1.6883133  1.75393541 1.7770993  1.7770993  1.7770993  1.81299632
 1.83656973 1.81899788 1.78202032 1.83676185 1.83676185 1.89785998
 1.93003582 1.90661035 1.88530646 1.838949   1.88581302 1.97383896
 1.94734301 1.95740385 1.9158691  1.86105274 1.82881144 1.7975932
 1.79390251 1.68839104 1.64369477 1.51064452 1.39193106 1.39193106
 1.39193106 1.39193106 1.36286065 1.36286065 1.29731131 1.36444747
 1.36444747 1.36444747 1.39787905 1.39787905 1.31772599 1.06993417
 1.06993417 1.02908662 0.9852012  0.89129279 0.91502868 0.91502868
 0.86224227 0.86224227 0.86224227 0.86224227 0.86224227 0.86224227
 0.8

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[0.         0.         0.         0.         0.         0.
 0.30151134 0.30151134 0.30151134 0.30151134 0.30151134 0.57907809
 0.69721228 0.69721228 0.78785984 0.78785984 1.16579061 1.16579061
 1.16579061 1.16579061 1.15774886 1.15774886 1.14948045 1.21680731
 1.21680731 1.20896918 1.18992818 1.28796074 1.28796074 1.26831481
 1.35625586 1.35625586 1.47139883 1.45108638 1.53943869 1.59920078
 1.57814897 1.57814897 1.57814897 1.64346562 1.62158731 1.69384783
 1.65901292 1.65308085 1.65308085 1.68598148 1.66099378 1.68840389
 1.68840389 1.68840389 1.67581088 1.67581088 1.6286073  1.60672899
 1.58391845 1.58391845 1.56004526 1.56004526 1.60297538 1.57327687
 1.57750721 1.61022488 1.61022488 1.57882525 1.57882525 1.58856098
 1.6314911  1.6314911  1.56836101 1.56836101 1.56836101 1.56836101
 1.56209674 1.56209674 1.53069711 1.49726553 1.45409885 1.44958056
 1.50499647 1.53469497 1.58621082 1.49222301 1.43943661 1.29733785
 1.24491456 1.20399868 1.20399868 1.09400215 0.93031062 0.83970582
 0.

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.30151134 0.30151134 0.30151134 0.59538825
 0.87636399 1.11763805 1.10979992 1.10979992 1.10979992 1.10979992
 1.09054538 1.30151989 1.38830832 1.38830832 1.38830832 1.38830832
 1.54226407 1.54226407 1.60672463 1.60672463 1.60672463 1.69507694
 1.67899344 1.84880892 1.89286056 2.01951435 1.97692838 2.12071906
 2.12071906 2.19721182 2.12578369 2.09874215 2.07063895 2.04407697
 1.98793855 1.95863788 1.98514787 1.98514787 1.98514787 1.95531921
 1.95531921 1.95531921 1.92307791 1.92307791 1.90082673 1.84654738
 1.81245931 1.75219882 1.7132056  1.6882179  1.68839104 1.6108725
 1.58187728 1.54717496 1.48832023 1.43175435 1.43175435 1.50109735
 1.46766577 1.46269948 1.42231974 1.42231974 1.33133821 1.29049066
 1.29049066 1.29049066 1.24873546 1.24873546 1.24873546 1.24873546
 1.24873546 1.21461693 1.16689437 1.00800686 0.70752056 0.70752056
 0.65509726 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.28867513 0.40824829 0.40824829
 0.40824829 0.40824829 0.40824829 0.40824829 0.40824829 0.40824829
 0.70975964 1.12478831 1.12478831 1.12478831 1.12478831 1.12478831
 1.11715388 1.23888168 1.33063339 1.40798365 1.49017443 1.6320188
 1.6320188  1.71065374 1.67472176 1.75136562 1.75136562 1.75136562
 1.75136562 1.73255607 1.7872976  1.8548222  1.82026478 1.86855065
 1.90657369 1.90657369 1.9414845  1.87102779 1.87102779 1.82367076
 1.92660736 1.96466291 1.97776031 1.97776031 1.99521105 1.96481366
 1.9780012  1.91693495 1.88826585 1.82108283 1.85831188 1.85831188
 1.85831188 1.8245501  1.78905359 1.78489813 1.7473684  1.73861924
 1.73861924 1.65486743 1.58760897 1.58760897 1.53482257 1.53482257
 1.49582935 1.49582935 1.43591408 1.29039146 1.29039146 1.29039146
 1.12878164 1.14739617 1.14739617 1.00791562 0.93684505 0.85009756
 0.54804964 

> Again the choice of a best process is subjective, and **L1-regularised linear regression** optimised for AUROC (and with thresholds optimised for geometric mean) is chosen on the basis of having a sensitivity and a specificity which are both of some clinical use.

In [25]:
i = 'l1'
results_best_thr = list()
cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=2)
clf = GridSearchCV(pipelines[i], hyperparameters[i], cv=cv, scoring='roc_auc', n_jobs=-1)
clf.fit(X, y)
gm_array = np.zeros(99)
for train_ix, test_is in cv.split(X, y):
    X_train, X_test = X.iloc[train_ix, :], X.iloc[test_ix, :]
    y_train, y_test = y.iloc[train_ix], y.iloc[test_ix]
    model = pipelines[i]
    model.set_params(**clf.best_params_)
    model.fit(X_train, y_train)
    if i in ['l1', 'l2', 'rf', 'gb']:
        pred = model.predict_proba(X_test)
        pred = pred[:, 1]
    else:
        pred = model.decision_function(X_test)
    for thr in range(1, 100):
        if i in ['l1', 'l2', 'rf', 'gb']:
            y_pred = [int(p > thr / 100) for p in pred]
        else:
            y_pred = [int(expit(p) > thr / 100) for p in pred]
        gm_array[int(thr - 1)] += g_mean(y_test, y_pred)
best_threshold = (np.argmax(f1_array) + 1) / 100
print(f1_array, "\n")
if i in ['svm_linear', 'svm_rbf']:
    best_threshold = logit(best_threshold)
fitted_model_no_ap_ncv = clf
best_thr_no_ap_ncv = best_threshold
print("Best threshold for", i.upper(), "model =", best_thr_no_ap_ncv)

[2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393 2.31982393
 2.31982393 2.31982393 2.31982393 2.31982393 2.3355102  2.3355102
 2.3355102  2.35183673 2.35183673 2.38435374 2.38435374 2.35953105
 2.35953105 2.34978557 2.2944496  2.26874038 2.26874038 2.2856537
 2.30230495 2.27240461 2.23449025 2.23962504 2.27471273 2.21666667
 2.18461538 2.04637592 1.90465449 1.87073525 1.67857143 1.43387647
 1.28296296 1.25641026 1.19794872 1.         0.54620742 0.45573123
 0.28181818 0.1        0.1        0.1        0.1        0.1
 0.         0.         0.         0.         0.         0.
 0.         

> Hyperparameters of the best models are shown below, then the models and best thresholds are saved for further use using the Pickle module.

In [26]:
fitted_models['l2'].best_estimator_

Pipeline(memory=None,
         steps=[('standardscaler',
                 StandardScaler(copy=True, with_mean=True, with_std=True)),
                ('logisticregression',
                 LogisticRegression(C=0.0031622776601683794, class_weight=None,
                                    dual=False, fit_intercept=True,
                                    intercept_scaling=1, l1_ratio=None,
                                    max_iter=100, multi_class='auto',
                                    n_jobs=None, penalty='l2', random_state=123,
                                    solver='lbfgs', tol=0.0001, verbose=0,
                                    warm_start=False))],
         verbose=False)

In [27]:
fitted_model_ncv.best_estimator_

Pipeline(memory=None,
         steps=[('standardscaler',
                 StandardScaler(copy=True, with_mean=True, with_std=True)),
                ('logisticregression',
                 LogisticRegression(C=0.01, class_weight=None, dual=False,
                                    fit_intercept=True, intercept_scaling=1,
                                    l1_ratio=None, max_iter=100,
                                    multi_class='auto', n_jobs=None,
                                    penalty='l2', random_state=123,
                                    solver='lbfgs', tol=0.0001, verbose=0,
                                    warm_start=False))],
         verbose=False)

In [28]:
fitted_models_no_ap['svm_linear'].best_estimator_

Pipeline(memory=None,
         steps=[('standardscaler',
                 StandardScaler(copy=True, with_mean=True, with_std=True)),
                ('linearsvc',
                 LinearSVC(C=0.1, class_weight=None, dual=True,
                           fit_intercept=True, intercept_scaling=1,
                           loss='squared_hinge', max_iter=1000,
                           multi_class='ovr', penalty='l2', random_state=123,
                           tol=0.0001, verbose=0))],
         verbose=False)

In [29]:
fitted_model_no_ap_ncv.best_estimator_

Pipeline(memory=None,
         steps=[('standardscaler',
                 StandardScaler(copy=True, with_mean=True, with_std=True)),
                ('logisticregression',
                 LogisticRegression(C=1.0, class_weight=None, dual=False,
                                    fit_intercept=True, intercept_scaling=1,
                                    l1_ratio=None, max_iter=100,
                                    multi_class='auto', n_jobs=None,
                                    penalty='l1', random_state=123,
                                    solver='liblinear', tol=0.0001, verbose=0,
                                    warm_start=False))],
         verbose=False)

In [30]:
with open("eeg_model.pkl", "wb") as f:
    pickle.dump(fitted_models['l2'].best_estimator_, f)
with open("eeg_model_ncv.pkl", "wb") as f:
    pickle.dump(fitted_model_ncv.best_estimator_, f)
with open("eeg_best_thr_ncv.pkl", "wb") as f:
    pickle.dump(best_thr_ncv, f)
with open("eeg_model_no_ap.pkl", "wb") as f:
    pickle.dump(fitted_models_no_ap['svm_linear'].best_estimator_, f)
with open("eeg_model_no_ap_ncv.pkl", "wb") as f:
    pickle.dump(fitted_model_no_ap_ncv.best_estimator_, f)
with open("eeg_best_thr_no_ap_ncv.pkl", "wb") as f:
    pickle.dump(best_thr_no_ap_ncv, f)