In [27]:
import scipy
import logging
import pickle
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm

# Configure the root logger so INFO-level (and higher) records are emitted
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the loading and splitting functions in this notebook
logger = logging.getLogger(__name__)

In [28]:
# Subjects S2..S17 with S12 left out (15 subjects in total)
SUBJECT_IDS = [f'S{i}' for i in range(2, 18) if i != 12]
# Number of subjects held out for the test split
TEST_SUBJECTS = 5
# Common target rate (Hz) every signal is resampled to
SAMPLING_RATE = 4

# Sampling rate (Hz) of each raw signal, grouped by device, plus the label track
SAMPLING_RATES = {
    'chest': dict.fromkeys(['ACC', 'ECG', 'EMG', 'EDA', 'Resp', 'Temp'], 700),
    'wrist': {'ACC': 32, 'BVP': 64, 'EDA': 4, 'TEMP': 4},
    'label': 700,
}

In [29]:
def resample_data(data: np.ndarray, original_rate: int, target_rate: int):
    """Resample a single-channel signal from ``original_rate`` to ``target_rate``.

    * Downsampling: the signal is low-pass filtered (first-order Butterworth,
      cutoff at the target Nyquist frequency ``target_rate / 2``) and then
      decimated. For an integer rate ratio every k-th sample is kept; for a
      non-integer ratio the filtered signal is linearly interpolated onto the
      target time grid so the output really is at ``target_rate``.
    * Upsampling: linear interpolation onto the finer time grid.
    * Equal rates: the (squeezed) input is returned unchanged.

    Parameters
    ----------
    data : array-like
        One channel of samples. Singleton dimensions are squeezed away, so a
        column of shape ``(n, 1)`` is accepted. Multi-channel arrays are not
        supported here; resample them one column at a time.
    original_rate : int
        Sampling rate of ``data`` in Hz.
    target_rate : int
        Desired sampling rate in Hz.

    Returns
    -------
    np.ndarray
        1-D resampled signal. Empty input yields an empty array.

    Raises
    ------
    ValueError
        If either rate is not positive.
    """
    if original_rate <= 0 or target_rate <= 0:
        raise ValueError('Sampling rates must be positive')

    # atleast_1d keeps a single-sample input indexable after squeeze()
    data = np.atleast_1d(np.asarray(data).squeeze())
    if data.size == 0 or original_rate == target_rate:
        return data

    if original_rate > target_rate:  # Sampling frequency is higher than target frequency
        # The causal filter starts from a zero state; padding both ends with the
        # edge values lets the start-up transient happen inside the padding,
        # which is trimmed off again afterwards.
        pad_amount = 1000
        padded = np.pad(data, pad_amount, mode='edge')
        # True division keeps the cutoff at the exact Nyquist frequency of the
        # target rate (floor division gave 0 Hz for target_rate == 1).
        b, a = scipy.signal.butter(1, target_rate / 2, fs=original_rate,
                                   btype='low', analog=False)
        filtered = scipy.signal.lfilter(b, a, padded)
        # Trim padding from signal
        filtered = filtered[pad_amount:-pad_amount]

        if original_rate % target_rate == 0:
            # Integer ratio: plain decimation
            return filtered[::int(original_rate // target_rate)]

        # Non-integer ratio: a fixed stride would produce the wrong rate, so
        # sample the filtered signal on the exact target time grid instead.
        n_out = int(np.ceil(len(filtered) * target_rate / original_rate))
        time_old = np.arange(len(filtered)) / original_rate
        time_new = np.arange(n_out) / target_rate
        return np.interp(time_new, time_old, filtered)

    # Sampling frequency is lower than target frequency
    # Time stamps of the original samples (rounded to milliseconds)
    time = np.round(np.arange(0, len(data)) / original_rate, 3)
    interpolate = scipy.interpolate.interp1d(time, data, kind='linear',
                                             fill_value='extrapolate')
    # Target time grid; stops before the last original time stamp
    time_new = np.arange(0, time[-1], 1 / target_rate)
    return interpolate(time_new)

In [30]:
def _resample_device(signals: dict, rates: dict, target_rate: int) -> dict:
    """Resample every signal of one device to ``target_rate``.

    The 'ACC' entry is resampled column by column because ``resample_data``
    handles a single channel at a time. The input dict is left untouched.
    """
    resampled = {}
    for key, values in signals.items():
        if key == 'ACC':
            channels = [resample_data(values[:, i], rates[key], target_rate)
                        for i in range(values.shape[1])]
            resampled[key] = np.stack(channels, axis=1)
        else:
            resampled[key] = resample_data(values, rates[key], target_rate)
    return resampled


def _resample_labels(labels, original_rate: int, target_rate: int):
    """Bring the label track to ``target_rate`` without blending label values."""
    labels = np.asarray(labels)
    if original_rate > target_rate:
        if original_rate % target_rate == 0:
            # Integer ratio: keep every k-th label
            return labels[::original_rate // target_rate]
        # Non-integer ratio: pick the label in effect at each target time stamp
        n_out = int(np.ceil(len(labels) * target_rate / original_rate))
        idx = (np.arange(n_out) * original_rate / target_rate).astype(int)
        return labels[np.minimum(idx, len(labels) - 1)]
    if original_rate < target_rate:
        # Nearest-neighbour interpolation so no new label values are invented
        time = np.round(np.arange(0, len(labels)) / original_rate, 3)
        nearest = scipy.interpolate.interp1d(time, labels, kind='nearest',
                                             fill_value='extrapolate')
        return nearest(np.arange(0, time[-1], 1 / target_rate))
    return labels


def load_data(path: str, subject_id: str, sampling_rate: int = 4):
    """Load one subject's recording and return it as a resampled DataFrame.

    Reads ``<path>/<subject_id>/<subject_id>.pkl``, resamples every chest and
    wrist signal as well as the label track to ``sampling_rate``, and
    concatenates them column-wise. Only rows labelled 1 (baseline), 2 (stress)
    or 3 (amusement) are kept, and the label is mapped to a binary target:
    0 = non-stress (baseline/amusement), 1 = stress.

    Parameters
    ----------
    path : str
        Root directory that contains one sub-directory per subject.
    subject_id : str
        Subject identifier, e.g. ``'S2'``.
    sampling_rate : int, default 4
        Target rate in Hz applied to the signals *and* the labels, so both
        stay aligned whatever rate the caller requests.

    Returns
    -------
    pd.DataFrame
        One row per retained sample with the signal columns, ``label`` and
        ``subject_id``.
    """
    data_path = Path(path) / subject_id / f'{subject_id}.pkl'
    # NOTE: pickle.load can execute arbitrary code - only load trusted files.
    with open(data_path, 'rb') as file:
        data = pickle.load(file, encoding='latin1')

    chest = _resample_device(data['signal']['chest'], SAMPLING_RATES['chest'], sampling_rate)
    wrist = _resample_device(data['signal']['wrist'], SAMPLING_RATES['wrist'], sampling_rate)
    labels = _resample_labels(data['label'], SAMPLING_RATES['label'], sampling_rate)

    # (values, column names) in the order the columns appear in the output
    column_specs = [
        (chest['ACC'], ['ACC_x_chest', 'ACC_y_chest', 'ACC_z_chest']),
        (chest['ECG'], ['ECG_chest']),
        (chest['EMG'], ['EMG_chest']),
        (chest['EDA'], ['EDA_chest']),
        (chest['Resp'], ['Resp_chest']),
        (chest['Temp'], ['Temp_chest']),
        (wrist['ACC'], ['ACC_x_wrist', 'ACC_y_wrist', 'ACC_z_wrist']),
        (wrist['BVP'], ['BVP_wrist']),
        (wrist['EDA'], ['EDA_wrist']),
        (wrist['TEMP'], ['TEMP_wrist']),
        (labels, ['label']),
    ]
    frames = [pd.DataFrame(values, columns=columns) for values, columns in column_specs]

    # Align all signals on their positional index; signals of unequal length
    # leave NaN in the trailing rows of the shorter ones.
    df = pd.concat(frames, axis=1, join='outer')

    # Keep only labels 1 (baseline), 2 (stress) and 3 (amusement); rows with a
    # missing label are dropped here as well. Then rebuild a clean integer index.
    df = df[df['label'].isin([1, 2, 3])].reset_index(drop=True)

    return df.assign(
        subject_id=subject_id,
        label=df['label'].map({1: 0, 2: 1, 3: 0})  # Binary classification: 0 (non-stress), 1 (stress)
    )

In [31]:
def load_all_data(path: str, subject_ids: list, sampling_rate: int = 64):
    """Load every subject in ``subject_ids`` and key the frames by subject ID.

    Parameters
    ----------
    path : str
        Root directory of the dataset, passed through to ``load_data``.
    subject_ids : list
        Subject identifiers to load; a progress bar tracks the iteration.
    sampling_rate : int, default 64
        Forwarded unchanged to ``load_data``. Note that this default differs
        from ``load_data``'s own default of 4.

    Returns
    -------
    dict
        ``{subject_id: DataFrame}`` in the order of ``subject_ids``.
    """
    frames_by_subject = {
        sid: load_data(path, sid, sampling_rate)
        for sid in tqdm(subject_ids)
    }

    logger.info('All data loaded')

    return frames_by_subject

In [32]:
def standardize_data(data, mean=None, std=None):
    """Standardize ``data`` to zero mean and unit spread: ``(data - mean) / std``.

    Parameters
    ----------
    data : array-like with ``mean()`` / ``std()`` (e.g. ndarray, Series, DataFrame)
        Values to standardize.
    mean, std : optional
        Statistics to use. When omitted they are computed from ``data`` itself;
        pass the training-set statistics to transform other data consistently.

    Returns
    -------
    Same kind of object as ``data - mean``, standardized.

    Notes
    -----
    A spread of exactly zero (a constant feature) is treated as 1, so the
    feature becomes 0 instead of NaN/inf from a division by zero.
    """
    if mean is None:
        mean = data.mean()
    if std is None:
        std = data.std()
    if isinstance(std, (list, tuple)):
        std = np.asarray(std)

    # (std == 0) is 1 exactly where the spread is zero and 0 elsewhere, so the
    # sum replaces zero spreads by 1 and leaves all other entries untouched.
    # This works unchanged for scalars, arrays and (index-aligned) pandas objects.
    safe_std = std + (std == 0)
    return (data - mean) / safe_std

In [33]:
def data_split(all_data, test_subjects: int):
    """Split per-subject frames into training and test sets by subject.

    The last ``test_subjects`` entries of ``all_data`` (in its iteration order)
    form the test set; all earlier entries form the training pool.

    Parameters
    ----------
    all_data : dict
        ``{subject_id: DataFrame}``; every frame needs a ``label`` column.
    test_subjects : int
        Number of subjects held out for testing. Must leave at least one
        subject on each side of the split.

    Returns
    -------
    df_train : pd.DataFrame
        Training rows with ``label == 0`` only (non-stress data).
    df_train_with_anomaly : pd.DataFrame
        All training rows, stress included.
    df_test : pd.DataFrame
        All rows of the held-out subjects.

    Raises
    ------
    ValueError
        If ``test_subjects`` is not in ``1 .. n_subjects - 1``. Without this
        check an oversized value produced a negative slice bound and a
        silently wrong split.
    """
    all_data_list = list(all_data.values())
    n_subjects = len(all_data_list)
    if not 0 < test_subjects < n_subjects:
        raise ValueError(
            f'test_subjects must be between 1 and {n_subjects - 1} '
            f'(got {test_subjects} for {n_subjects} subjects)'
        )
    n_train = n_subjects - test_subjects

    # ignore_index=True already yields a clean 0..n-1 index for both frames
    df_train_with_anomaly = pd.concat(all_data_list[:n_train], ignore_index=True, axis=0)
    df_test = pd.concat(all_data_list[n_train:], ignore_index=True, axis=0)  # Use all data for testing

    # Use only non-stress data for training
    df_train = df_train_with_anomaly[df_train_with_anomaly['label'] == 0].reset_index(drop=True)

    logger.info('Data split into train and test sets')
    logger.info('Train data shape: %s', df_train.shape)
    logger.info('Test data shape: %s', df_test.shape)

    return df_train, df_train_with_anomaly, df_test

In [34]:
# Load every subject from the 'WESAD' directory at SAMPLING_RATE, then hold out
# the last TEST_SUBJECTS subjects (in SUBJECT_IDS order) as the test set.
all_data = load_all_data('WESAD', SUBJECT_IDS, SAMPLING_RATE)
df_train, df_train_with_anomaly, df_test = data_split(all_data, TEST_SUBJECTS)

  0%|          | 0/15 [00:00<?, ?it/s]

100%|██████████| 15/15 [00:32<00:00,  2.17s/it]
INFO:__main__:All data loaded
INFO:__main__:Data split into train and test sets
INFO:__main__:Train data shape: (61696, 16)
INFO:__main__:Test data shape: (44732, 16)


In [40]:
# Persist the three splits as CSV files (without the index column) under data/
Path('data').mkdir(parents=True, exist_ok=True)
for frame, csv_path in (
    (df_train, 'data/train.csv'),
    (df_train_with_anomaly, 'data/train_with_anomaly.csv'),
    (df_test, 'data/test.csv'),
):
    frame.to_csv(csv_path, index=False)

In [41]:
# Preview the first rows of the training set
df_train.head()

Unnamed: 0,ACC_x_chest,ACC_y_chest,ACC_z_chest,ECG_chest,EMG_chest,EDA_chest,Resp_chest,Temp_chest,ACC_x_wrist,ACC_y_wrist,ACC_z_wrist,BVP_wrist,EDA_wrist,TEMP_wrist,label,subject_id
0,0.884488,-0.10873,-0.242231,0.058284,-0.003539,5.69968,0.276868,29.128135,34.943832,21.163192,35.383769,-67.118687,1.640539,35.81,0,S2
1,0.881354,-0.110234,-0.269737,0.082774,-0.001684,5.662452,-1.569285,29.126827,52.382129,-23.997765,43.102392,36.422618,1.634132,35.81,0,S2
2,0.880647,-0.11591,-0.261053,-0.113248,-0.002079,5.627947,-2.616351,29.13068,54.24667,-23.057112,29.790993,49.543609,1.614912,35.81,0,S2
3,0.881459,-0.112778,-0.263934,0.013752,-0.004435,5.585941,-2.196373,29.134117,52.306091,-21.094121,28.189838,67.064383,1.591848,35.81,0,S2
4,0.888935,-0.132658,-0.239563,0.057006,-0.002548,5.548919,-0.262218,29.134844,54.790854,-17.086637,29.087397,-42.069433,1.558534,35.81,0,S2


In [42]:
# Sanity check: the training set should contain only label 0 (non-stress)
df_train.label.value_counts()

label
0    61696
Name: count, dtype: int64

In [43]:
# Class balance of the unfiltered training subjects (0 = non-stress, 1 = stress)
df_train_with_anomaly.label.value_counts()

label
0    61696
1    26180
Name: count, dtype: int64

In [44]:
# Class balance of the held-out test subjects (0 = non-stress, 1 = stress)
df_test.label.value_counts()

label
0    31048
1    13684
Name: count, dtype: int64