In [1]:
import logging
import pickle
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm

# Emit INFO-level records from this notebook's logger (e.g. the load/split progress messages).
logging.basicConfig(level='INFO')
logger = logging.getLogger(name=__name__)

In [2]:
# Subjects to load: S2..S17 with S12 skipped (S1 is not included either;
# presumably those recordings are unusable in WESAD — confirm against the dataset readme).
SUBJECT_IDS = [f'S{i}' for i in range(2, 18) if i != 12]

# Fractions of *subjects* assigned to the validation and test splits (see data_split).
VAL_RATIO, TEST_RATIO = 0.15, 0.15

# Target rate in Hz that every stream is decimated to by load_data.
SAMPLING_RATE = 4

# Rate in Hz at which each stream is stored in the subject pickles, keyed by
# device and then by the signal key used in the pickle; load_data divides these
# by the target rate to get the decimation step.
SAMPLING_RATES = {
    'chest': dict.fromkeys(('ACC', 'ECG', 'EMG', 'EDA', 'Resp', 'Temp'), 700),
    'wrist': {'ACC': 32, 'BVP': 64, 'EDA': 4, 'TEMP': 4},
    'label': 700,
}

In [3]:
def _check_sampling_rate(sampling_rate: int) -> None:
    """Raise ValueError unless ``sampling_rate`` is a positive divisor of every rate in SAMPLING_RATES."""
    native_rates = [
        *SAMPLING_RATES['chest'].values(),
        *SAMPLING_RATES['wrist'].values(),
        SAMPLING_RATES['label'],
    ]
    # `sampling_rate <= 0` short-circuits before the modulo so 0 cannot divide by zero.
    bad_rates = sorted({rate for rate in native_rates if sampling_rate <= 0 or rate % sampling_rate})
    if bad_rates:
        raise ValueError(
            f'sampling_rate must be a positive integer that evenly divides every native '
            f'rate; {sampling_rate} Hz does not divide {bad_rates} Hz'
        )


def load_data(path: str, subject_id: str, sampling_rate: int = 4):
    """Load one WESAD subject and return its chest, wrist and label streams as one frame.

    Each stream in the subject's pickle is decimated (every k-th sample is kept)
    from its rate in ``SAMPLING_RATES`` down to ``sampling_rate``. The streams are
    placed side by side by sample position, only labels 1, 2 and 3 are kept, and
    the label is binarised (2 -> 1, 1 and 3 -> 0).

    Args:
        path: Root directory containing ``<subject_id>/<subject_id>.pkl``.
        subject_id: Subject identifier such as ``'S2'``.
        sampling_rate: Target rate in Hz. Must be a positive integer that evenly
            divides every rate in ``SAMPLING_RATES`` (1, 2 or 4 for the rates
            configured in this notebook). Otherwise the integer step would leave
            streams at different effective rates and misalign the rows.

    Returns:
        DataFrame with the chest and wrist signal columns, a binary ``label``
        column and a ``subject_id`` column, indexed 0..n-1.

    Raises:
        ValueError: If ``sampling_rate`` does not evenly divide every native rate.
    """
    # Fail fast, before the (slow) pickle load.
    _check_sampling_rate(sampling_rate)

    data_path = Path(path) / subject_id / f'{subject_id}.pkl'
    # NOTE: pickle.load can execute arbitrary code; only load trusted dataset files.
    with open(data_path, 'rb') as file:
        data = pickle.load(file, encoding='latin1')

    # Keep every k-th sample so each stream ends up at `sampling_rate` Hz.
    # New dicts are built so the loaded pickle object is not mutated.
    data_chest = {
        key: signal[::SAMPLING_RATES['chest'][key] // sampling_rate]
        for key, signal in data['signal']['chest'].items()
    }
    data_wrist = {
        key: signal[::SAMPLING_RATES['wrist'][key] // sampling_rate]
        for key, signal in data['signal']['wrist'].items()
    }
    data_label = data['label'][::SAMPLING_RATES['label'] // sampling_rate]

    acc_chest = pd.DataFrame(data_chest['ACC'], columns=['ACC_x_chest', 'ACC_y_chest', 'ACC_z_chest'])
    ecg_chest = pd.DataFrame(data_chest['ECG'], columns=['ECG_chest'])
    emg_chest = pd.DataFrame(data_chest['EMG'], columns=['EMG_chest'])
    eda_chest = pd.DataFrame(data_chest['EDA'], columns=['EDA_chest'])
    resp_chest = pd.DataFrame(data_chest['Resp'], columns=['Resp_chest'])
    temp_chest = pd.DataFrame(data_chest['Temp'], columns=['Temp_chest'])

    acc_wrist = pd.DataFrame(data_wrist['ACC'], columns=['ACC_x_wrist', 'ACC_y_wrist', 'ACC_z_wrist'])
    bvp_wrist = pd.DataFrame(data_wrist['BVP'], columns=['BVP_wrist'])
    eda_wrist = pd.DataFrame(data_wrist['EDA'], columns=['EDA_wrist'])
    temp_wrist = pd.DataFrame(data_wrist['TEMP'], columns=['TEMP_wrist'])

    label = pd.DataFrame(data_label, columns=['label'])

    # Streams are aligned by sample position (each frame has a 0..n-1 index).
    # join='outer' keeps the longest stream: rows past the end of a shorter
    # stream get NaN in that stream's columns. Rows without a label are dropped
    # by the filter below.
    df = pd.concat([acc_chest, ecg_chest, emg_chest, eda_chest, resp_chest, temp_chest,
                    acc_wrist, bvp_wrist, eda_wrist, temp_wrist, label], axis=1, join='outer')

    # Keep labels 1 (baseline), 2 (stress) and 3 (amusement) only, with a clean index.
    df = df[df['label'].isin([1, 2, 3])].reset_index(drop=True)

    return df.assign(
        subject_id=subject_id,
        # Binary target: 1 = stress (label 2), 0 = non-stress (labels 1 and 3).
        label=df['label'].map({1: 0, 2: 1, 3: 0}),
    )

In [4]:
def load_all_data(path: str, subject_ids: list, sampling_rate: int = 4):
    """Load every subject in ``subject_ids`` with ``load_data``.

    Args:
        path: Root WESAD directory, passed through to ``load_data``.
        subject_ids: Subject identifiers to load, e.g. ``SUBJECT_IDS``.
        sampling_rate: Target rate in Hz for every stream. The default matches
            ``load_data``. 64 Hz (the old default) is unreachable because the
            wrist EDA/TEMP streams are stored at only 4 Hz, so decimating to 64 Hz
            gives a slice step of 0 and raises.

    Returns:
        Dict mapping each subject id to its DataFrame, in ``subject_ids`` order.
        ``data_split`` relies on this order to assign subjects to splits.
    """
    all_data = {
        subject_id: load_data(path, subject_id, sampling_rate)
        for subject_id in tqdm(subject_ids, desc='Loading subjects')
    }

    logger.info('All data loaded')

    return all_data

In [5]:
def standardize_data(data, mean=None, std=None):
    """Return ``(data - mean) / std`` (z-score standardisation).

    Statistics that are not supplied are computed from ``data`` itself. To
    standardise validation or test data, pass the training set's ``mean`` and
    ``std``. For a DataFrame, pandas computes and aligns the statistics per
    column. Note that pandas ``.std()`` uses ddof=1, while numpy arrays use ddof=0.

    A standard deviation of 0 (a constant feature) is replaced by 1, so that
    feature becomes 0 instead of NaN (0/0) or +/-inf.

    Args:
        data: DataFrame, Series or numpy array to standardise.
        mean: Value(s) to subtract; defaults to ``data.mean()``.
        std: Value(s) to divide by; defaults to ``data.std()``.

    Returns:
        The standardised data.
    """
    if mean is None:
        mean = data.mean()
    if std is None:
        std = data.std()

    # Guard against division by zero without losing pandas label alignment.
    if isinstance(std, (pd.Series, pd.DataFrame)):
        std = std.mask(std == 0, 1)
    elif np.ndim(std) == 0:
        std = 1 if std == 0 else std
    else:
        std = np.where(np.asarray(std) == 0, 1, std)

    return (data - mean) / std

In [6]:
def data_split(all_data, val_ratio: float, test_ratio: float):
    """Split per-subject frames into train, validation and test sets by subject.

    Subjects are assigned in the iteration order of ``all_data``, without
    shuffling. The first ``n_train`` subjects go to train, the next ``n_val`` to
    validation and the remaining ``n_test`` to test, so no subject's rows appear
    in more than one split. ``n_val`` and ``n_test`` are
    ``int(n_subjects * ratio)`` (rounded down); train gets the rest.

    Args:
        all_data: Mapping of subject id -> DataFrame with a ``label`` column
            (0 = non-stress, 1 = stress), e.g. the output of ``load_all_data``.
        val_ratio: Fraction of subjects used for validation.
        test_ratio: Fraction of subjects used for testing.

    Returns:
        ``(df_train, df_val, df_test)``. Train and validation contain only
        ``label == 0`` rows; test contains every row. All are indexed 0..n-1.

    Raises:
        ValueError: If the ratios leave any split without at least one subject.
    """
    subject_frames = list(all_data.values())
    n_subjects = len(subject_frames)
    n_test = int(n_subjects * test_ratio)
    n_val = int(n_subjects * val_ratio)
    n_train = n_subjects - n_val - n_test

    # Without this check, an empty split fails in pd.concat with an opaque
    # "No objects to concatenate", and a negative n_train silently slices
    # from the end of the list instead of failing.
    if min(n_train, n_val, n_test) < 1:
        raise ValueError(
            f'Cannot split {n_subjects} subjects with val_ratio={val_ratio} and '
            f'test_ratio={test_ratio}: got {n_train} train, {n_val} validation and '
            f'{n_test} test subjects, but each split needs at least one.'
        )

    df_train = pd.concat(subject_frames[:n_train], ignore_index=True)
    df_val = pd.concat(subject_frames[n_train:n_train + n_val], ignore_index=True)
    # ignore_index=True already gives the test frame a fresh 0..n-1 index.
    df_test = pd.concat(subject_frames[n_train + n_val:], ignore_index=True)

    # Fit and validate on non-stress rows only (presumably a one-class /
    # anomaly-detection setup); stress rows therefore appear only in test.
    df_train = df_train[df_train['label'] == 0].reset_index(drop=True)
    df_val = df_val[df_val['label'] == 0].reset_index(drop=True)

    logger.info('Data split into train, validation and test sets')
    logger.info('Train data shape: %s', df_train.shape)
    logger.info('Validation data shape: %s', df_val.shape)
    logger.info('Test data shape: %s', df_test.shape)

    return df_train, df_val, df_test

In [7]:
# Load every subject at the target rate, then split subject-wise.
all_data = load_all_data(path='WESAD', subject_ids=SUBJECT_IDS, sampling_rate=SAMPLING_RATE)
df_train, df_val, df_test = data_split(all_data, val_ratio=VAL_RATIO, test_ratio=TEST_RATIO)

100%|██████████| 15/15 [00:31<00:00,  2.10s/it]
INFO:__main__:All data loaded
INFO:__main__:Data split into train and test sets
INFO:__main__:Train data shape: (67944, 16)
INFO:__main__:Validation data shape: (12396, 16)
INFO:__main__:Test data shape: (17988, 16)


In [8]:
# Persist each split as CSV (without the integer index) under ./data.
Path('data').mkdir(parents=True, exist_ok=True)
for split_frame, csv_path in (
    (df_train, 'data/train.csv'),
    (df_val, 'data/val.csv'),
    (df_test, 'data/test.csv'),
):
    split_frame.to_csv(csv_path, index=False)

In [9]:
df_train.head(n=5)

Unnamed: 0,ACC_x_chest,ACC_y_chest,ACC_z_chest,ECG_chest,EMG_chest,EDA_chest,Resp_chest,Temp_chest,ACC_x_wrist,ACC_y_wrist,ACC_z_wrist,BVP_wrist,EDA_wrist,TEMP_wrist,label,subject_id
0,0.8826,-0.1138,-0.231,0.036484,-0.025589,5.679321,-0.52948,29.119537,65.0,12.0,21.0,-33.58,1.640539,35.81,0,S2
1,0.8878,-0.1122,-0.2758,0.009201,0.001144,5.642319,-2.055359,29.123871,55.0,-32.0,44.0,30.52,1.634132,35.81,0,S2
2,0.8918,-0.1166,-0.271,-0.162506,0.011765,5.607224,-2.735901,29.119537,49.0,-21.0,30.0,52.54,1.614912,35.81,0,S2
3,0.879,-0.0914,-0.2654,0.014969,0.001053,5.566788,-1.780701,29.126709,52.0,-22.0,27.0,94.31,1.591848,35.81,0,S2
4,0.893,-0.1208,-0.2402,0.061707,0.013321,5.531693,0.511169,29.131042,58.0,-14.0,28.0,-38.86,1.558534,35.81,0,S2


In [10]:
df_train['label'].value_counts()

label
0    67944
Name: count, dtype: int64

In [11]:
df_val['label'].value_counts()

label
0    12396
Name: count, dtype: int64

In [12]:
df_test['label'].value_counts()

label
0    12404
1     5584
Name: count, dtype: int64