In [1]:
from database.data import Data
from sklearn.svm import SVC
from tqdm import tqdm
import numpy as np

from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold
from models.scoring_metrics import get_all_metrics, scoring_function

In [2]:
#creating DATA object
ptb_binary_SVM = Data(database = 'ptbdb', denoise_method='DWT', estimation_method = 'SVM', train_splits=None, binary = True, parameterisation = True)



In [3]:
ptb_binary_SVM.run()

Filtering Database


100%|██████████| 549/549 [00:58<00:00,  9.31it/s]
100%|██████████| 221/221 [00:01<00:00, 132.77it/s]


denoising signals through Discrete Wavelet Transform
normalising signals
calculating time domain parameters
calculating frequency domain parameters
calculating non linear domain parameters
selecting 4 most important features
Selected features for channel 1:
['rr_amps', 'mean', 'lf', 'age']
Selected features for channel 2:
['rr_amps', 'shannon_en', 'sd2', 'age']
Selected features for channel 3:
['rr_std', 'sd1', 'sd2', 'age']
Selected features for channel 4:
['rr_amps', 'std', 'sd_ratio', 'age']
Selected features for channel 5:
['skews', 'sd1', 'sd2', 'age']
Selected features for channel 6:
['RMSSD', 'mean', 'sd2', 'age']


In [4]:
input_data = ptb_binary_SVM.input_data
labels = ptb_binary_SVM.labels

labels_encoded = []

for i in range(6):
    encoded = [0 if label == 'Unhealthy' else 1 for label in labels[i]]
    labels_encoded.append(np.array(encoded))



## Use this for average channel scores 

In [25]:
# define hyperparameter grid to test
param_grid = {
    'C': [0.01, 0.1, 1, 10],
    'kernel': ['linear', 'rbf', 'poly'],
    'gamma': ['scale']#including 'auto' aswell takes forever
}

y_tests_list = []
probs_list = []
metrics = []
thresholds_list = []

for i in range(6):

    #define classifier
    svc = SVC(class_weight='balanced', probability = True)

    #find the best set of hyperparameters for each channel, tuned on the desired scoring function
    best_svc = tune_hyperparams(input_data[i], labels_encoded[i], param_grid, svc, scorer='balanced_accuracy')

    #perform 3 way skfold to get scores for each channel as well as their probabilities
    n_splits = 3
    all_score_metrics, thresholds, probabilities, y_tests, test_indices = perform_skfold(input_data[i], labels_encoded[i], n_splits, best_svc, get_probabilities=True)

    #monitoring scores for each channel
    metrics.append(all_score_metrics)

    #calculating average threshold from all splits
    threshold = np.mean(thresholds)

    #reconstructing calculated probabilities so can optimise over all channels
    reconstructed_probs = reconstruct_probs(probabilities, test_indices, ptb_binary_SVM.nan_indices[i], ptb_binary_SVM.allowed_patients.count_patients(), n_splits)

    #need for ROC curve and confusion matrix later
    probs_list.append(reconstructed_probs)
    y_tests_list.append(y_tests)
    thresholds_list.append(threshold)


In [26]:
from sklearn.metrics import classification_report

for i in range(6):
    print(metrics[i].items())
    #print(classification_report(y_tests[i], y_preds[i]))

dict_items([('F1 score', 0.3089130113720278), ('Objective score', 0.35734274941700206), ('Bal Acc', 0.4703454715219421), ('Accuracy', 0.488966588966589), ('precision', 0.24318435188000406), ('recall', 0.4365079365079365)])
dict_items([('F1 score', 0.514484126984127), ('Objective score', 0.5717442400594575), ('Bal Acc', 0.7053511705685619), ('Accuracy', 0.6986167932982661), ('precision', 0.40554363966342183), ('recall', 0.717948717948718)])
dict_items([('F1 score', 0.5198830409356724), ('Objective score', 0.5820458342826763), ('Bal Acc', 0.7270923520923521), ('Accuracy', 0.6876310272536688), ('precision', 0.3888076673164392), ('recall', 0.7954545454545455)])
dict_items([('F1 score', 0.4721254355400697), ('Objective score', 0.5310047255152285), ('Bal Acc', 0.6683897354572655), ('Accuracy', 0.7055555555555556), ('precision', 0.4005291005291005), ('recall', 0.6007326007326007)])
dict_items([('F1 score', 0.5125432629192027), ('Objective score', 0.5742564745196325), ('Bal Acc', 0.71825396825