In [77]:
import numpy as np
import mne
import matplotlib.pyplot as plt
import pickle
from src.features import feature_extractor as fe

In [78]:
from sklearn.model_selection import train_test_split, GridSearchCV, PredefinedSplit
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn import metrics
from sklearn.svm import SVC

In [79]:
# Directory holding the pre-processed epoch files and their metadata pickle.
Chosen_set_dir = "data/processed/1.0sec_overlap0.5_fp1"

# Load the metadata dictionary; later cells read its 'subjects' and 'classes'
# entries.  The name is reset first so a failed read does not leave an earlier
# run's metadata in the namespace.
# NOTE(review): pickle.load can execute arbitrary code -- only point this at
# files produced by this project.
loaded_metadata = None
metadata_path = Chosen_set_dir + "/metadata.pkl"
with open(metadata_path, 'rb') as file:
    loaded_metadata = pickle.load(file)

# Read one epochs file per subject, in the order listed in the metadata.
loaded_dataset = [
    mne.read_epochs(Chosen_set_dir + f'/sub{subject_id}_epo.fif', verbose = False)
    for subject_id in loaded_metadata['subjects']
]

In [80]:
# Feature names handed to fe.extract_feature_set.  Other selections that were
# kept in this cell as commented-out calls: ['peak_to_peak'], ['kurtosis'],
# ['hjorth_complexity'], and the combined
# ['mean', 'rms', 'peak_to_peak', 'kurtosis', 'hjorth_complexity'].
selected_features = ['rms']
feature_set = fe.extract_feature_set(selected_features, loaded_dataset, loaded_metadata)

In [81]:
# Assemble, per subject, the extracted features and the matching class-index
# labels, then stack everything into regular arrays for training.
per_subject_features = []
per_subject_labels = []

for subj_idx in range(len(loaded_metadata['subjects'])):
    class_chunks = []
    class_labels = []
    for cls_idx in range(len(loaded_metadata['classes'])):
        # Skip (subject, class) pairs for which nothing was extracted.
        if not feature_set[subj_idx][cls_idx]:
            continue
        class_chunks.append(np.array(feature_set[subj_idx][cls_idx]))
        # One label (the class index) per entry of this class.
        class_labels.extend(cls_idx for _entry in feature_set[subj_idx][cls_idx])
    # Entries of all classes are laid end to end along the first axis.
    per_subject_features.append(np.concatenate(class_chunks, axis = 0))
    per_subject_labels.append(class_labels)

# Stacking into one array presumes every subject contributes the same number
# of entries -- TODO confirm against the extractor's output.
stacked = np.array(per_subject_features)
# Swap the last two of the four axes, then merge them into a single feature
# axis, giving a 3-D array indexed (subject, entry, feature).
dataset = stacked.transpose((0, 1, 3, 2)).reshape(stacked.shape[0], stacked.shape[1], -1)
labelset = np.array(per_subject_labels)

In [82]:
# Merge the subject axis into the sample axis: one row per sample, pooled over
# all subjects, with a flat label vector of the same length.
flat_features = dataset.reshape(-1, dataset.shape[2])
flat_labels = labelset.reshape(-1)

# First hold out 20% of all rows as the test set ...
x, x_test, y, y_test = train_test_split(
    flat_features, flat_labels, test_size=0.2, random_state=41
)
# ... then carve 25% of the remainder off for validation, i.e. roughly
# 60% train / 20% validation / 20% test overall.
# NOTE(review): rows are shuffled across subjects; if neighbouring epochs
# overlap in time (the data directory name suggests so -- confirm), related
# windows can end up on both sides of a split.
x_train, x_val, y_train, y_val = train_test_split(
    x, y, test_size=0.25, random_state=1
)

# SVM classifier

In [83]:
# Hyper-parameter grid for the SVM: only the regularisation strength C is
# searched; the kernel is fixed to RBF.
param_grid = {
    'C': [0.1, 1, 10, 100, 1000],
    'kernel': ['rbf']
}

# train_test_split shuffles its input, so the first len(x_train) rows of `x`
# are NOT the rows of x_train.  Stack the two partitions explicitly so the
# fold vector below really designates x_train as the training part and x_val
# as the single validation fold used for model selection.
x_search = np.concatenate([x_train, x_val], axis=0)
y_search = np.concatenate([y_train, y_val], axis=0)

# PredefinedSplit convention: -1 = never used for validation (always trained
# on), 0 = member of validation fold 0.
split_index = np.concatenate([
    np.full(len(x_train), -1, dtype=int),
    np.zeros(len(x_val), dtype=int),
])
ps = PredefinedSplit(test_fold=split_index)

# refit=True retrains the best configuration on all of x_search (train + val)
# once the search finishes; svm_clf.predict then uses that refitted model.
# NOTE(review): no feature scaling is applied in the cells shown here.
svm_clf = GridSearchCV(SVC(), param_grid, cv=ps, refit=True)
svm_clf.fit(x_search, y_search)

In [84]:
# Ground-truth labels of the held-out test rows.
y_true = y_test
# Predictions of the estimator refitted by the grid search, on the same rows.
y_pred = svm_clf.predict(x_test)

In [85]:
# Per-class precision / recall / F1 on the test set, followed by the confusion
# matrix (rows: true class, columns: predicted class).
report_text = metrics.classification_report(y_true, y_pred)
confusion = metrics.confusion_matrix(y_true, y_pred)
print(report_text)
print(confusion)

              precision    recall  f1-score   support

           1       0.79      0.84      0.81       602
           2       0.82      0.76      0.79       574

    accuracy                           0.80      1176
   macro avg       0.80      0.80      0.80      1176
weighted avg       0.80      0.80      0.80      1176

[[506  96]
 [135 439]]
