In [2]:
!pip install dtaidistance
from dtaidistance import dtw
from dtaidistance import dtw_visualisation as dtwvis
import numpy as np
import pickle
import os
import pandas as pd
import random
import librosa
from scipy.stats import kurtosis, skew
import warnings
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score,f1_score,confusion_matrix,roc_auc_score,ConfusionMatrixDisplay,precision_score,recall_score
from preprocess import preprocess

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
# get file list for 1 patient
def get_file_list(path):
  """List the stroke files in ``path`` together with their class labels.

  Only directory entries whose name starts with ``'E'`` are kept. The label
  of a file is the integer found after its last underscore and before the
  next dot, minus one (so ``..._1.<ext>`` maps to label 0).

  Returns ``(label, file_list)``: two parallel lists, in whatever order
  ``os.listdir`` yields the entries.
  """
  file_list = [name for name in os.listdir(path) if name[0] == 'E']
  label = [int(name.rsplit("_", 1)[-1].split(".")[0]) - 1 for name in file_list]
  return label, file_list

In [4]:
# get feature from 1 file and preprocess
def get_feature(path):
    """Load one recording and flatten it into a single feature list.

    The CSV at ``path`` is read with the explicit column names ``vertical``
    and ``horizontal``, converted to an array and passed through
    ``preprocess`` (imported from the project's ``preprocess`` module; what it
    does is not visible here). The result is returned as a flat list: every
    value of processed column 0 followed by every value of processed column 1.
    """
    raw = pd.read_csv(path, names=["vertical", "horizontal"])
    signal = preprocess(np.array(raw))
    # Column 0 first, then column 1, joined end to end.
    return list(signal[:, 0]) + list(signal[:, 1])

In [5]:
# self identified test_split
def my_train_test_split_user_dependent(path,test_split,val_split,file_list,label):
    """Partition one patient's files into train / validation / test arrays.

    Every file is routed by the third underscore-separated field of its name:
    to the test set when that field is contained in ``test_split``, to the
    validation set when it equals ``val_split``, and to the training set
    otherwise. The features of a file come from ``get_feature(path + file)``
    and its label from the matching position in ``label``.

    Returns ``X_train, X_test, X_val, y_val, y_train, y_test`` (note the
    order), each converted with ``np.array``.
    """
    train_x, train_y = [], []
    val_x, val_y = [], []
    test_x, test_y = [], []

    for idx, file_name in enumerate(file_list):
        file_label = label[idx]
        feature = get_feature(str(path + file_name))

        # The routing key is compared against both split specifications.
        key = file_name.split('_')[2]
        if key in test_split:
            target_x, target_y = test_x, test_y
        elif key == val_split:
            target_x, target_y = val_x, val_y
        else:
            target_x, target_y = train_x, train_y

        target_x.append(feature)
        target_y.append(file_label)

    X_train = np.array(train_x)
    X_test = np.array(test_x)
    X_val = np.array(val_x)
    y_train = np.array(train_y)
    y_val = np.array(val_y)
    y_test = np.array(test_y)

    return X_train, X_test, X_val, y_val, y_train, y_test

In [6]:
def evaluate(y_true, y_pred):
    """Compute the classification scores reported by this notebook.

    Parameters:
        y_true: ground-truth class labels.
        y_pred: predicted class labels, aligned with ``y_true``.

    Returns:
        A 7-tuple ``(f1_micro, f1_macro, precision_micro, precision_macro,
        recall_micro, recall_macro, acc)``.

    ``zero_division=0`` is passed explicitly: when a class is never predicted
    (or never present) the per-class score is ill-defined, and scikit-learn's
    default ``"warn"`` policy already scores it as 0 but also emits a warning
    on every call. Stating the policy keeps the numbers identical and the
    notebook output free of those warnings.
    """
    ill_defined = {"zero_division": 0}

    f1_micro = f1_score(y_true, y_pred, average='micro', **ill_defined)
    f1_macro = f1_score(y_true, y_pred, average='macro', **ill_defined)
    precision_micro = precision_score(y_true, y_pred, average='micro', **ill_defined)
    precision_macro = precision_score(y_true, y_pred, average='macro', **ill_defined)
    recall_micro = recall_score(y_true, y_pred, average='micro', **ill_defined)
    recall_macro = recall_score(y_true, y_pred, average='macro', **ill_defined)
    acc = accuracy_score(y_true, y_pred)

    return f1_micro,f1_macro,precision_micro,precision_macro,recall_micro,recall_macro,acc

# User Dependent

In [7]:
# Patient folder names evaluated in the user-dependent experiment.
patient = ["001", "002", "003", "004", "005", "006"]

# Fold definitions: entry t of test_split and entry t of val_split belong together.
# Both are matched against the third underscore-separated field of each file name.
test_split = [
    ["01", "02"],
    ["03", "04"],
    ["05", "06"],
    ["07", "08"],
    ["09", "10"],
]
val_split = ["03", "01", "04", "05", "06"]

In [8]:
class DTW_clf():
  """Template-matching classifier built on dynamic time warping (DTW).

  Training stores every sample as a template of its class; prediction picks
  the class whose templates have the smallest mean DTW distance to the query.
  """

  def __init__(self, template, num_class):
    """Store the template container and the class count.

    template: mapping from class index to a list of template series. It is
      kept by reference and extended in place by ``train``.
    num_class: number of classes; ``predict`` looks up keys 0 .. num_class-1.
    """
    self.template = template
    self.num_class = num_class

  def train(self, X, y):
    """Append each row of ``X`` (as a list) to the templates of its label in ``y``."""
    for i in range(X.shape[0]):
      label = y[i]
      self.template[label].append(list(X[i,:]))

  def predict(self, X):
    """Return the index of the class closest to the series ``X``.

    The score of a class is the mean of ``dtw.distance_fast`` between ``X``
    and each of its templates. A class without templates can never be
    selected: its score stays at ``inf`` instead of triggering a division by
    zero.

    Raises:
      ValueError: if none of the classes has any template.
    """
    class_dist = np.full(self.num_class, np.inf)
    any_templates = False
    for k in range(self.num_class):
      templates = self.template[k]
      if not templates:
        # Nothing to compare against; leave this class at infinite distance.
        continue
      any_templates = True
      total = 0
      for templ in templates:
        total += dtw.distance_fast(np.array(templ), X, use_pruning=True)
      class_dist[k] = total / len(templates)
    if not any_templates:
      raise ValueError("DTW_clf.predict called before any template was added")
    predicted_class = np.argmin(class_dist)
    return predicted_class

In [15]:
# User-dependent experiment: for every patient, run all folds defined by
# test_split / val_split, pool the predictions over the folds and report the scores.

# The class count is identical for every patient and fold.
num_class = 12
# Row captions in the order evaluate() returns its scores; the trailing colons
# (present on the f1 rows only) reproduce the report format exactly.
metric_labels = ("f1 micro:", "f1 macro:", "precision micro", "precision macro",
                 "recall micro", "recall macro", "accuracy")

for p in patient:
  path = "drive/MyDrive/Colab_Notebooks/EOG_data/isolated/"+p+"/isolated_strokes/"
  label,file_list = get_file_list(path)
  # Predictions and labels are accumulated across all folds of this patient.
  val_pred = []
  val_label = []
  test_pred = []
  test_label = []
  for fold_test, fold_val in zip(test_split, val_split):
    X_train,X_test,X_val,y_val, y_train,y_test= my_train_test_split_user_dependent(path,fold_test,fold_val,file_list,label)
    # Each fold starts from an empty template store so folds do not share templates.
    template = {i: [] for i in range(num_class)}
    clf = DTW_clf(template=template, num_class=num_class)
    clf.train(X_train, y_train)
    for sample, target in zip(X_val, y_val):
      val_pred.append(clf.predict(sample))
      val_label.append(target)
    for sample, target in zip(X_test, y_test):
      test_pred.append(clf.predict(sample))
      test_label.append(target)

  val_scores = evaluate(np.array(val_label), np.array(val_pred))
  test_scores = evaluate(np.array(test_label), np.array(test_pred))

  print("patient id:", p)
  for split_name, scores in (("validation", val_scores), ("test", test_scores)):
    for metric_label, score in zip(metric_labels, scores):
      print(split_name, metric_label, score)

  # NOTE(review): clf is the classifier of the last fold only; the models of the
  # earlier folds are not saved. Confirm that this is the intended artefact.
  with open("DTW__sub"+p+".pck", "wb") as output_file:
    pickle.dump(clf, output_file)

patient id: 001
validation f1 micro: 0.7166666666666667
validation f1 macro: 0.6851343101343103
validation precision micro 0.7166666666666667
validation precision macro 0.7273478835978836
validation recall micro 0.7166666666666667
validation recall macro 0.7166666666666667
validation accuracy 0.7166666666666667
test f1 micro: 0.775
test f1 macro: 0.7469248158102649
test precision micro 0.775
test precision macro 0.809352453102453
test recall micro 0.775
test recall macro 0.775
test accuracy 0.775
patient id: 002
validation f1 micro: 0.7166666666666667
validation f1 macro: 0.6883523883523885
validation precision micro 0.7166666666666667
validation precision macro 0.8011574074074076
validation recall micro 0.7166666666666667
validation recall macro 0.7166666666666667
validation accuracy 0.7166666666666667
test f1 micro: 0.7166666666666667
test f1 macro: 0.6826206402293359
test precision micro 0.7166666666666667
test precision macro 0.7720425407925409
test recall micro 0.7166666666666667


  _warn_prf(average, modifier, msg_start, len(result))


patient id: 003
validation f1 micro: 0.5833333333333334
validation f1 macro: 0.5371933621933621
validation precision micro 0.5833333333333334
validation precision macro 0.5202380952380952
validation recall micro 0.5833333333333334
validation recall macro 0.5833333333333334
validation accuracy 0.5833333333333334
test f1 micro: 0.6
test f1 macro: 0.5551165858390922
test precision micro 0.6
test precision macro 0.5506507381507382
test recall micro 0.6
test recall macro 0.6000000000000001
test accuracy 0.6


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


patient id: 004
validation f1 micro: 0.5
validation f1 macro: 0.42344599844599845
validation precision micro 0.5
validation precision macro 0.39652777777777776
validation recall micro 0.5
validation recall macro 0.5
validation accuracy 0.5
test f1 micro: 0.4166666666666667
test f1 macro: 0.3398363929898227
test precision micro 0.4166666666666667
test precision macro 0.3327892116317164
test recall micro 0.4166666666666667
test recall macro 0.4166666666666667
test accuracy 0.4166666666666667
patient id: 005
validation f1 micro: 0.6833333333333333
validation f1 macro: 0.6727906352906352
validation precision micro 0.6833333333333333
validation precision macro 0.7552579365079364
validation recall micro 0.6833333333333333
validation recall macro 0.6833333333333332
validation accuracy 0.6833333333333333
test f1 micro: 0.6890756302521008
test f1 macro: 0.6844776434159375
test precision micro 0.6890756302521008
test precision macro 0.7674680704285967
test recall micro 0.6890756302521008
test re

# User Independent

In [10]:
# self identified test_split
def my_train_test_split_user_independent(test_patient,val_patient,train_patient):
    """Build train / test / validation arrays from whole patients.

    ``train_patient`` is an iterable of patient folder names; ``test_patient``
    and ``val_patient`` are single folder names. For every patient, all files
    returned by ``get_file_list`` are loaded with ``get_feature`` and appended
    to the set the patient belongs to.

    Returns ``X_train, X_test, X_val, y_val, y_train, y_test`` (note the
    order), each converted with ``np.array``.
    """
    train_x, train_y = [], []
    test_x, test_y = [], []
    val_x, val_y = [], []

    # Loading order: all training patients, then the test patient, then the
    # validation patient.
    groups = (
        (train_patient, train_x, train_y),
        ([test_patient], test_x, test_y),
        ([val_patient], val_x, val_y),
    )
    for patients, features, labels in groups:
        for p in patients:
            path = str("drive/MyDrive/Colab_Notebooks/EOG_data/isolated/"+p+"/isolated_strokes/")
            label, file_list = get_file_list(path)
            for file_p, file_label in zip(file_list, label):
                features.append(get_feature(str(path+file_p)))
                labels.append(file_label)

    X_train = np.array(train_x)
    X_test = np.array(test_x)
    y_train = np.array(train_y)
    y_test = np.array(test_y)
    X_val = np.array(val_x)
    y_val = np.array(val_y)
    return X_train, X_test, X_val, y_val, y_train, y_test

In [16]:
# User-independent experiment: in fold t one patient is held out for testing,
# another for validation, and the remaining patients provide the templates.
# Predictions are pooled over all folds before scoring.
test_patien_list = ["001","002","003","004","005","006"]
val_patient_list = ["002","003","004","005","006","001"]
val_pred_all = []
val_label_all = []
test_pred_all = []
test_label_all = []

for t in range(len(test_patien_list)):
  test_patient = test_patien_list[t]
  val_patient = val_patient_list[t]
  # Training patients are everybody except the two held-out patients.
  train_patient = list(test_patien_list)
  train_patient.remove(test_patient)
  train_patient.remove(val_patient)
  X_train,X_test, X_val,y_val,y_train,y_test = my_train_test_split_user_independent(test_patient,val_patient,train_patient)

  num_class = 12
  # Fresh, empty template store for this fold.
  template = {i: [] for i in range(num_class)}
  clf = DTW_clf(template=template, num_class=num_class)
  clf.train(X_train, y_train)

  val_pred = [clf.predict(X_val[j,:]) for j in range(X_val.shape[0])]
  val_label = [y_val[j] for j in range(X_val.shape[0])]
  test_pred = [clf.predict(X_test[j,:]) for j in range(X_test.shape[0])]
  test_label = [y_test[j] for j in range(X_test.shape[0])]

  val_pred_all.extend(val_pred)
  val_label_all.extend(val_label)
  test_pred_all.extend(test_pred)
  test_label_all.extend(test_label)


val_pred_all = np.array(val_pred_all)
val_label_all = np.array(val_label_all)
test_pred_all = np.array(test_pred_all)
test_label_all = np.array(test_label_all)

val_scores = evaluate(val_label_all,val_pred_all)
test_scores = evaluate(test_label_all,test_pred_all)
val_f1_micro,val_f1_macro,val_precision_micro,val_precision_macro,val_recall_micro,val_recall_macro,val_acc = val_scores
test_f1_micro,test_f1_macro,test_precision_micro,test_precision_macro,test_recall_micro,test_recall_macro,test_acc = test_scores

# Captions follow the order of the tuple returned by evaluate(); only the f1
# rows carry a trailing colon.
metric_captions = ("f1 micro:", "f1 macro:", "precision micro", "precision macro",
                   "recall micro", "recall macro", "accuracy")
for split_name, scores in (("validation", val_scores), ("test", test_scores)):
  for caption, score in zip(metric_captions, scores):
    print(split_name, caption, score)


# NOTE(review): clf is the classifier of the last fold only, i.e. it does not
# hold templates from that fold's test and validation patients.
with open(str("DTW__subAll.pck"), "wb") as output_file:
    pickle.dump(clf, output_file)

validation f1 micro: 0.4544198895027624
validation f1 macro: 0.3860815739200112
validation precision micro 0.4544198895027624
validation precision macro 0.41313415134431924
validation recall micro 0.4544198895027624
validation recall macro 0.4527210490568368
validation accuracy 0.4544198895027624
test f1 micro: 0.45994475138121543
test f1 macro: 0.3900769155530128
test precision micro 0.4599447513812155
test precision macro 0.4292798089342084
test recall micro 0.4599447513812155
test recall macro 0.45834491062332133
test accuracy 0.4599447513812155
