In [33]:
import os
import pickle
import scipy.signal
from scipy import fft
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import classification_report

In [34]:
DATA_PATH = os.path.realpath("../data/WESAD")

In [38]:
class Subject:

    def __init__(self, main_path, subject_number):
        self.name = f'S{subject_number}'
        self.subject_keys = ['signal', 'label', 'subject']
        self.signal_keys = ['chest', 'wrist']
        self.chest_keys = ['ACC', 'ECG', 'EMG', 'EDA', 'Temp', 'Resp']
        self.wrist_keys = ['ACC', 'BVP', 'EDA', 'TEMP']
        with open(os.path.join(main_path, self.name) + '/' + self.name + '.pkl', 'rb') as file:
            self.data = pickle.load(file, encoding='latin1')
        self.labels = self.data['label']

    def get_wrist_data(self):
        data = self.data['signal']['wrist']
        return data

    def get_chest_data(self):
        return self.data['signal']['chest']
    
    def get_subject_dataframe(self):
        wrist_data = self.get_wrist_data()
        bvp_signal = wrist_data['BVP'][:,0]
        eda_signal = wrist_data['EDA'][:,0]
        acc_x_signal = wrist_data['ACC'][:,0]
        acc_y_signal = wrist_data['ACC'][:,1]
        acc_z_signal = wrist_data['ACC'][:,2]
        temp_signal = wrist_data['TEMP'][:,0]
        # Upsampling data to match BVP data sampling rate using fourier method as described in Paper/dataset
        eda_upsampled = scipy.signal.resample(eda_signal, len(bvp_signal))
        temp_upsampled = scipy.signal.resample(temp_signal, len(bvp_signal))
        acc_x_upsampled = scipy.signal.resample(acc_x_signal, len(bvp_signal))
        acc_y_upsampled = scipy.signal.resample(acc_y_signal, len(bvp_signal))
        acc_z_upsampled = scipy.signal.resample(acc_z_signal, len(bvp_signal))
        label_df = pd.DataFrame(self.labels, columns=['label'])
        label_df.index = [(1 / 700) * i for i in range(len(label_df))] # 700 is the sampling rate of the label
        label_df.index = pd.to_datetime(label_df.index, unit='s')
        data_arrays = zip(bvp_signal, eda_upsampled, acc_x_upsampled, acc_y_upsampled, acc_z_upsampled, temp_upsampled)
        df = pd.DataFrame(data=data_arrays, columns=['BVP', 'EDA', 'ACC_x', 'ACC_y', 'ACC_z', 'TEMP'])
        df.index = [(1 / 64) * i for i in range(len(df))] # 64 = sampling rate of BVP
        df.index = pd.to_datetime(df.index, unit='s')
        df = df.join(label_df)
        df['label'] = df['label'].fillna(method='ffill')
        df.reset_index(drop=True, inplace=True)
        # df.drop(df[df['label'].isin([0.0, 4.0, 5.0, 6.0, 7.0])].index, inplace=True)
        # df['label'] = df['label'].replace([1.0, 2.0, 3.0], [0, 1, 0])
        df.reset_index(drop=True, inplace=True)
        return df


In [41]:
subject_ids = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17] # ids for subjects in WESAD dataset
for subject_id in subject_ids:
    print(f'SUBJECT {subject_id}')
    df_subject = Subject(DATA_PATH, subject_id).get_subject_dataframe()
    y = df_subject['label']
    df_subject.drop('label', axis=1, inplace=True)
    normalized_x=(df_subject-df_subject.min())/(df_subject.max()-df_subject.min())
    #x_train, x_test, y_train, y_test=train_test_split(df_subject,y,test_size=0.2)
    X_train, x_test, y_train, y_test=train_test_split(normalized_x,y,test_size=0.2)
    LDA= LinearDiscriminantAnalysis(solver = 'svd')
    y_out = LDA.fit(X_train, y_train).predict(x_test)
    #confusion_matrix(y_test, y_out)
    print(f'{classification_report(y_test, y_out, digits=4)}\n')

SUBJECT 2
              precision    recall  f1-score   support

         0.0     0.8707    0.7042    0.7787     39228
         1.0     0.7915    0.9111    0.8471     14716
         2.0     0.7360    0.9752    0.8389      7808
         3.0     0.6441    0.3591    0.4611      4662
         4.0     0.6181    0.9925    0.7617      9808
         6.0     0.1875    0.1050    0.1346       800
         7.0     0.0000    0.0000    0.0000       790

    accuracy                         0.7729     77812
   macro avg     0.5497    0.5781    0.5460     77812
weighted avg     0.7809    0.7729    0.7620     77812


SUBJECT 3
              precision    recall  f1-score   support

         0.0     0.7995    0.7072    0.7506     42950
         1.0     0.7086    0.9809    0.8228     14561
         2.0     0.6216    0.8450    0.7163      8123
         3.0     0.4677    0.2944    0.3613      4892
         4.0     0.8059    0.8764    0.8397      9960
         5.0     0.0000    0.0000    0.0000       927
   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

         0.0     0.7918    0.8700    0.8291     42333
         1.0     0.8956    0.9268    0.9109     14775
         2.0     0.8738    1.0000    0.9327      8263
         3.0     0.6393    0.7469    0.6889      4667
         4.0     0.8325    0.4374    0.5735     10284
         5.0     0.0000    0.0000    0.0000       675
         6.0     0.0435    0.0070    0.0121       568
         7.0     0.0000    0.0000    0.0000       650

    accuracy                         0.8122     82215
   macro avg     0.5096    0.4985    0.4934     82215
weighted avg     0.7972    0.8122    0.7953     82215


SUBJECT 5
              precision    recall  f1-score   support

         0.0     0.8272    0.7527    0.7882     39470
         1.0     0.7118    0.8616    0.7796     15349
         2.0     0.7673    0.9281    0.8401      8267
         3.0     0.6789    0.9364    0.7871      4698
         4.0     0.6525    0.5920    0.6208      9987
         5.0 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

         0.0     0.7566    0.6957    0.7249     29652
         1.0     0.8434    0.9143    0.8774     14977
         2.0     0.7384    0.9156    0.8175      8529
         3.0     0.5752    0.9286    0.7104      4723
         4.0     0.7011    0.5218    0.5983     10172
         5.0     0.0000    0.0000    0.0000       616
         6.0     0.0000    0.0000    0.0000       647
         7.0     0.0000    0.0000    0.0000       649

    accuracy                         0.7407     69965
   macro avg     0.4518    0.4970    0.4661     69965
weighted avg     0.7320    0.7407    0.7296     69965


SUBJECT 9
              precision    recall  f1-score   support

         0.0     0.9260    0.4687    0.6224     26523
         1.0     0.8486    0.9917    0.9146     14985
         2.0     0.8197    0.9972    0.8998      8178
         3.0     0.6945    0.9987    0.8193      4783
         4.0     0.6881    0.9894    0.8117      9992
         5.0 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

         0.0     0.7388    0.6410    0.6864     30665
         1.0     0.5012    0.5393    0.5196     15033
         2.0     0.9147    0.9873    0.9496      8560
         3.0     0.2580    0.0793    0.1213      4881
         4.0     0.5913    0.9929    0.7412     10236
         5.0     0.0000    0.0000    0.0000       653
         6.0     0.0000    0.0000    0.0000       596
         7.0     0.0000    0.0000    0.0000       250

    accuracy                         0.6598     70874
   macro avg     0.3755    0.4050    0.3773     70874
weighted avg     0.6396    0.6598    0.6373     70874


SUBJECT 14
              precision    recall  f1-score   support

         0.0     0.6679    0.5604    0.6094     30048
         1.0     0.7398    0.9371    0.8269     15222
         2.0     0.5880    0.6189    0.6031      8665
         3.0     0.5851    0.9986    0.7379      4835
         4.0     0.5682    0.4965    0.5300     10062
         5.0

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

         0.0     0.8635    0.6803    0.7610     26768
         1.0     0.8397    0.9632    0.8972     15140
         2.0     0.7388    0.9103    0.8156      8636
         3.0     0.6500    0.9928    0.7856      4732
         4.0     0.8435    0.8476    0.8456     10121
         5.0     0.0000    0.0000    0.0000       650
         6.0     0.2908    0.3522    0.3186       602
         7.0     0.0000    0.0000    0.0000       577

    accuracy                         0.8054     67226
   macro avg     0.5283    0.5933    0.5530     67226
weighted avg     0.8032    0.8054    0.7953     67226


SUBJECT 16


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

         0.0     0.8266    0.7431    0.7826     31578
         1.0     0.8664    0.8130    0.8389     15089
         2.0     0.8760    0.9981    0.9331      8526
         3.0     0.6545    0.9791    0.7846      4681
         4.0     0.7917    0.9345    0.8572     10154
         5.0     0.0000    0.0000    0.0000       656
         6.0     0.0000    0.0000    0.0000       694
         7.0     0.0000    0.0000    0.0000       699

    accuracy                         0.8091     72077
   macro avg     0.5019    0.5585    0.5245     72077
weighted avg     0.8012    0.8091    0.8006     72077


SUBJECT 17
              precision    recall  f1-score   support

         0.0     0.6252    0.8109    0.7060     34785
         1.0     0.8976    0.9254    0.9113     15143
         2.0     0.6698    0.4833    0.5615      9365
         3.0     0.6876    0.9781    0.8076      4850
         4.0     0.0549    0.0016    0.0031      9378
         5.0

In [42]:
y_out._binary_repr()

AttributeError: 'numpy.ndarray' object has no attribute '_binary_repr'