In [6]:
import os
import pickle
import scipy.signal
from scipy import fft
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import classification_report

In [3]:
DATA_PATH = os.path.realpath("../data/WESAD")

In [4]:
class Subject:

    def __init__(self, main_path, subject_number):
        self.name = f'S{subject_number}'
        self.subject_keys = ['signal', 'label', 'subject']
        self.signal_keys = ['chest', 'wrist']
        self.chest_keys = ['ACC', 'ECG', 'EMG', 'EDA', 'Temp', 'Resp']
        self.wrist_keys = ['ACC', 'BVP', 'EDA', 'TEMP']
        with open(os.path.join(main_path, self.name) + '/' + self.name + '.pkl', 'rb') as file:
            self.data = pickle.load(file, encoding='latin1')
        self.labels = self.data['label']

    def get_wrist_data(self):
        data = self.data['signal']['wrist']
        return data

    def get_chest_data(self):
        return self.data['signal']['chest']
    
    def get_subject_dataframe(self):
        wrist_data = self.get_wrist_data()
        bvp_signal = wrist_data['BVP'][:,0]
        eda_signal = wrist_data['EDA'][:,0]
        acc_x_signal = wrist_data['ACC'][:,0]
        acc_y_signal = wrist_data['ACC'][:,1]
        acc_z_signal = wrist_data['ACC'][:,2]
        temp_signal = wrist_data['TEMP'][:,0]
        # Upsampling data to match BVP data sampling rate using fourier method as described in Paper/dataset
        eda_upsampled = scipy.signal.resample(eda_signal, len(bvp_signal))
        temp_upsampled = scipy.signal.resample(temp_signal, len(bvp_signal))
        acc_x_upsampled = scipy.signal.resample(acc_x_signal, len(bvp_signal))
        acc_y_upsampled = scipy.signal.resample(acc_y_signal, len(bvp_signal))
        acc_z_upsampled = scipy.signal.resample(acc_z_signal, len(bvp_signal))
        label_df = pd.DataFrame(self.labels, columns=['label'])
        label_df.index = [(1 / 700) * i for i in range(len(label_df))] # 700 is the sampling rate of the label
        label_df.index = pd.to_datetime(label_df.index, unit='s')
        data_arrays = zip(bvp_signal, eda_upsampled, acc_x_upsampled, acc_y_upsampled, acc_z_upsampled, temp_upsampled)
        df = pd.DataFrame(data=data_arrays, columns=['BVP', 'EDA', 'ACC_x', 'ACC_y', 'ACC_z', 'TEMP'])
        df.index = [(1 / 64) * i for i in range(len(df))] # 64 = sampling rate of BVP
        df.index = pd.to_datetime(df.index, unit='s')
        df = df.join(label_df)
        df['label'] = df['label'].fillna(method='ffill')
        df.reset_index(drop=True, inplace=True)
        # df.drop(df[df['label'].isin([0.0, 4.0, 5.0, 6.0, 7.0])].index, inplace=True)
        # df['label'] = df['label'].replace([1.0, 2.0, 3.0], [0, 1, 0])
        df.reset_index(drop=True, inplace=True)
        return df


In [5]:
subject_ids = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17] # ids for subjects in WESAD dataset
for subject_id in subject_ids:
    print(f'SUBJECT {subject_id}')
    df_subject = Subject(DATA_PATH, subject_id).get_subject_dataframe()
    y = df_subject['label']
    df_subject.drop('label', axis=1, inplace=True)
    normalized_x=(df_subject-df_subject.min())/(df_subject.max()-df_subject.min())
    x_train, x_test, y_train, y_test=train_test_split(df_subject,y,test_size=0.2)
    norm_x_train,norm_x_test,norm_y_train,norm_y_test=train_test_split(normalized_x,y,test_size=0.2)
    LDA= LinearDiscriminantAnalysis(solver = 'svd')
    y_out = LDA.fit(x_train, y_train).predict(x_test)
    #confusion_matrix(y_test, y_out)
    print(f'{classification_report(y_test, y_out, digits=4)}\n')

SUBJECT 2
              precision    recall  f1-score   support

         0.0     0.8714    0.7116    0.7834     39124
         1.0     0.7913    0.9129    0.8477     14547
         2.0     0.7407    0.9726    0.8409      7852
         3.0     0.6764    0.3767    0.4839      4667
         4.0     0.6320    0.9906    0.7717     10059
         6.0     0.1881    0.0933    0.1247       815
         7.0     0.0000    0.0000    0.0000       748

    accuracy                         0.7782     77812
   macro avg     0.5571    0.5797    0.5503     77812
weighted avg     0.7851    0.7782    0.7673     77812


SUBJECT 3


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

         0.0     0.8018    0.7081    0.7520     42967
         1.0     0.7096    0.9823    0.8240     14604
         2.0     0.6242    0.8473    0.7189      8208
         3.0     0.4611    0.2865    0.3534      4795
         4.0     0.8048    0.8783    0.8399      9916
         5.0     0.0000    0.0000    0.0000       929
         6.0     0.0000    0.0000    0.0000       854
         7.0     0.0000    0.0000    0.0000       838

    accuracy                         0.7437     83111
   macro avg     0.4252    0.4628    0.4360     83111
weighted avg     0.7235    0.7437    0.7252     83111


SUBJECT 4


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

         0.0     0.7885    0.8727    0.8285     42282
         1.0     0.8897    0.9237    0.9064     14703
         2.0     0.8775    1.0000    0.9347      8108
         3.0     0.6612    0.7367    0.6969      4885
         4.0     0.8278    0.4282    0.5644     10393
         5.0     0.0000    0.0000    0.0000       659
         6.0     0.0568    0.0093    0.0159       539
         7.0     0.1429    0.0015    0.0031       646

    accuracy                         0.8106     82215
   macro avg     0.5305    0.4965    0.4937     82215
weighted avg     0.7966    0.8106    0.7932     82215


SUBJECT 5
              precision    recall  f1-score   support

         0.0     0.8245    0.7536    0.7875     39143
         1.0     0.7185    0.8641    0.7846     15398
         2.0     0.7655    0.9206    0.8359      8249
         3.0     0.6789    0.9333    0.7860      4798
         4.0     0.6620    0.6005    0.6298     10148
         5.0 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

         0.0     0.7559    0.6959    0.7247     29626
         1.0     0.8457    0.9109    0.8771     14954
         2.0     0.7310    0.9139    0.8123      8422
         3.0     0.5716    0.9263    0.7070      4752
         4.0     0.6969    0.5258    0.5994     10222
         5.0     0.0000    0.0000    0.0000       641
         6.0     0.0000    0.0000    0.0000       667
         7.0     0.0000    0.0000    0.0000       681

    accuracy                         0.7391     69965
   macro avg     0.4501    0.4966    0.4650     69965
weighted avg     0.7295    0.7391    0.7277     69965


SUBJECT 9
              precision    recall  f1-score   support

         0.0     0.9284    0.4663    0.6208     26367
         1.0     0.8493    0.9920    0.9151     15035
         2.0     0.8192    0.9974    0.8996      8171
         3.0     0.6909    0.9987    0.8167      4670
         4.0     0.6946    0.9918    0.8170     10221
         5.0 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

         0.0     0.7348    0.6426    0.6856     30566
         1.0     0.5045    0.5378    0.5206     15160
         2.0     0.9090    0.9894    0.9475      8461
         3.0     0.2684    0.0804    0.1237      4951
         4.0     0.5944    0.9932    0.7437     10241
         5.0     0.0000    0.0000    0.0000       610
         6.0     0.0000    0.0000    0.0000       640
         7.0     0.0000    0.0000    0.0000       245

    accuracy                         0.6594     70874
   macro avg     0.3764    0.4054    0.3776     70874
weighted avg     0.6380    0.6594    0.6363     70874


SUBJECT 14
              precision    recall  f1-score   support

         0.0     0.6718    0.5536    0.6070     30396
         1.0     0.7356    0.9399    0.8253     15028
         2.0     0.5739    0.6243    0.5980      8529
         3.0     0.5816    0.9983    0.7350      4807
         4.0     0.5744    0.5043    0.5370     10116
         5.0

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

         0.0     0.8687    0.6755    0.7600     26857
         1.0     0.8367    0.9668    0.8971     14895
         2.0     0.7402    0.9166    0.8190      8791
         3.0     0.6548    0.9931    0.7892      4779
         4.0     0.8454    0.8542    0.8498     10145
         5.0     0.0000    0.0000    0.0000       607
         6.0     0.2767    0.3578    0.3120       573
         7.0     0.0000    0.0000    0.0000       579

    accuracy                         0.8065     67226
   macro avg     0.5278    0.5955    0.5534     67226
weighted avg     0.8057    0.8065    0.7965     67226


SUBJECT 16
              precision    recall  f1-score   support

         0.0     0.8242    0.7349    0.7770     31346
         1.0     0.8661    0.8193    0.8421     15042
         2.0     0.8734    0.9986    0.9318      8639
         3.0     0.6532    0.9771    0.7829      4707
         4.0     0.7845    0.9358    0.8535     10249
         5.0

In [46]:
LDA= LinearDiscriminantAnalysis(solver = 'svd')
y_out = LDA.fit(norm_x_train, norm_y_train).predict(norm_x_test)
#confusion_matrix(y_test, y_out)
print(classification_report(norm_y_test, y_out, digits=4))#target_names=['Not Six', 'Six']

              precision    recall  f1-score   support

         0.0     0.8702    0.7069    0.7801     39252
         1.0     0.7888    0.9143    0.8469     14605
         2.0     0.7351    0.9721    0.8371      7839
         3.0     0.6720    0.3787    0.4844      4632
         4.0     0.6237    0.9893    0.7651      9821
         6.0     0.2091    0.1112    0.1452       827
         7.0     0.0000    0.0000    0.0000       836

    accuracy                         0.7748     77812
   macro avg     0.5570    0.5818    0.5513     77812
weighted avg     0.7820    0.7748    0.7638     77812

