In [1]:
import scipy
import logging
import pickle
import numpy as np
import pandas as pd
import neurokit2 as nk
from pathlib import Path
from tqdm import tqdm

# Configure the root logger to emit INFO-level (and higher) records, then create
# the module-level logger that the helper functions in this notebook write to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [2]:
# Subject folders to load: S2 through S17 with S12 left out (15 subjects in total).
SUBJECT_IDS = [f'S{i}' for i in range(2, 18) if i != 12]

# Number of subjects (taken from the end of SUBJECT_IDS) held out for testing.
TEST_SUBJECTS = 5

# Common rate every signal is resampled to before feature extraction.
SAMPLING_RATE = 4

# Native rate of each recorded channel, keyed by device and then channel name
# (presumably in Hz - confirm against the dataset documentation).
SAMPLING_RATES = {
    'chest': dict.fromkeys(('ACC', 'ECG', 'EMG', 'EDA', 'Resp', 'Temp'), 700),
    'wrist': {'ACC': 32, 'BVP': 64, 'EDA': 4, 'TEMP': 4},
    'label': 700,
}

In [3]:
def ecg_process(ecg_signal, sampling_rate=1000, method="neurokit", **kwargs):
    """Extract ECG features with NeuroKit2 without applying a cleaning filter.

    The sanitized input is used as-is for the "clean" signal. R-peaks are
    always detected with the ``"rodrigues2021"`` method and artifact
    correction switched on.

    Parameters
    ----------
    ecg_signal : array-like
        ECG samples; passed through ``nk.signal.signal_sanitize``.
    sampling_rate : int
        Sampling rate forwarded to peak detection, rate, quality and
        delineation.
    method : str
        Accepted for signature compatibility; not used by this function.
    **kwargs
        Accepted for signature compatibility; not used by this function.

    Returns
    -------
    tuple
        ``(signals, info)``. ``signals`` is a DataFrame with ``ECG_Raw``,
        ``ECG_Clean``, ``ECG_Rate`` and ``ECG_Quality`` followed by the
        columns produced by peak detection, delineation and cardiac-phase
        estimation. ``info`` is the peak-detection dict updated with the
        delineation dict.
    """
    sanitized = nk.signal.signal_sanitize(ecg_signal)
    # No filtering stage: downstream steps operate on the sanitized samples.
    cleaned = sanitized

    # R-peak detection (method is fixed, independent of the `method` argument).
    peak_markers, info = nk.ecg_peaks(
        ecg_cleaned=cleaned,
        sampling_rate=sampling_rate,
        method="rodrigues2021",
        correct_artifacts=True,
    )

    # Heart rate interpolated to one value per input sample.
    heart_rate = nk.signal.signal_rate(
        info, sampling_rate=sampling_rate, desired_length=len(cleaned)
    )

    # Per-sample quality estimate based on the detected R-peaks.
    signal_quality = nk.ecg_quality(
        cleaned, rpeaks=info["ECG_R_Peaks"], sampling_rate=sampling_rate
    )

    base_columns = {
        "ECG_Raw": sanitized,
        "ECG_Clean": cleaned,
        "ECG_Rate": heart_rate,
        "ECG_Quality": signal_quality,
    }
    base_frame = pd.DataFrame(base_columns)

    # Wave delineation; its index dict is folded into `info`.
    wave_markers, wave_info = nk.ecg_delineate(
        ecg_cleaned=cleaned, rpeaks=info["ECG_R_Peaks"], sampling_rate=sampling_rate
    )
    info.update(wave_info)

    # Cardiac phase derived from the R-peaks and the delineation result.
    phase_markers = nk.ecg_phase(
        ecg_cleaned=cleaned,
        rpeaks=info["ECG_R_Peaks"],
        delineate_info=wave_info,
    )

    # Column-wise merge of every per-sample output.
    merged = pd.concat(
        [base_frame, peak_markers, wave_markers, phase_markers], axis=1
    )
    return merged, info

In [4]:
def emg_process(emg_signal, sampling_rate=1000, report=None, **kwargs):
    """Extract EMG amplitude and activation markers without a cleaning filter.

    Parameters
    ----------
    emg_signal : array-like
        EMG samples; passed through ``nk.signal.signal_sanitize``.
    sampling_rate : int
        Sampling rate forwarded to the method lookup and activation detection.
    report : Any
        Accepted for signature compatibility; not used by this function.
    **kwargs
        Forwarded to NeuroKit2's ``emg_methods`` lookup.

    Returns
    -------
    tuple
        ``(signals, info)``. ``signals`` holds ``EMG_Raw``, ``EMG_Clean`` and
        ``EMG_Amplitude`` followed by the activation columns. ``info`` is the
        activation dict with an added ``"sampling_rate"`` entry.
    """
    sanitized = nk.signal.signal_sanitize(emg_signal)
    settings = nk.emg.emg_methods.emg_methods(sampling_rate=sampling_rate, **kwargs)

    # No filtering stage: the sanitized samples double as the "clean" signal.
    cleaned = sanitized

    envelope = nk.emg_amplitude(cleaned)

    # Onsets, offsets and activity periods, using the looked-up method settings.
    activation, info = nk.emg_activation(
        emg_amplitude=envelope,
        emg_cleaned=cleaned,
        sampling_rate=sampling_rate,
        method=settings["method_activation"],
        **settings["kwargs_activation"]
    )
    info["sampling_rate"] = sampling_rate

    base_frame = pd.DataFrame(
        {"EMG_Raw": sanitized, "EMG_Clean": cleaned, "EMG_Amplitude": envelope}
    )
    merged = pd.concat([base_frame, activation], axis=1)

    return merged, info


In [5]:
def eda_process(
    eda_signal, sampling_rate=1000, method="neurokit", report=None, **kwargs
):
    """Decompose an EDA signal and detect its peaks without a cleaning filter.

    Parameters
    ----------
    eda_signal : array-like
        EDA samples; passed through ``nk.signal.signal_sanitize``.
    sampling_rate : int
        Sampling rate forwarded to the method lookup, decomposition and peak
        detection.
    method : str
        Forwarded to NeuroKit2's ``eda_methods`` lookup.
    report : Any
        Accepted for signature compatibility; not used by this function.
    **kwargs
        Forwarded to NeuroKit2's ``eda_methods`` lookup.

    Returns
    -------
    tuple
        ``(signals, info)``. ``signals`` holds ``EDA_Raw`` and ``EDA_Clean``
        followed by the decomposition and peak columns. ``info`` is the
        peak-detection dict with an added ``"sampling_rate"`` entry.
    """
    sanitized = nk.signal.signal_sanitize(eda_signal)
    settings = nk.eda.eda_methods.eda_methods(sampling_rate=sampling_rate, method=method, **kwargs)

    # No filtering stage: the sanitized samples double as the "clean" signal.
    cleaned = sanitized

    phasic_method = settings["method_phasic"]
    if phasic_method is not None and phasic_method.lower() != "none":
        decomposed = nk.eda_phasic(
            cleaned,
            sampling_rate=sampling_rate,
            method=phasic_method,
            **settings["kwargs_phasic"],
        )
    else:
        # No decomposition requested: treat the whole signal as the phasic part.
        decomposed = pd.DataFrame({"EDA_Phasic": cleaned})

    # Peak detection on the phasic component with a fixed minimum amplitude.
    peak_markers, info = nk.eda_peaks(
        decomposed["EDA_Phasic"].values,
        sampling_rate=sampling_rate,
        method=settings["method_peaks"],
        amplitude_min=0.1,
        **settings["kwargs_peaks"],
    )
    info["sampling_rate"] = sampling_rate

    base_frame = pd.DataFrame({"EDA_Raw": sanitized, "EDA_Clean": cleaned})
    merged = pd.concat([base_frame, decomposed, peak_markers], axis=1)

    return merged, info

In [6]:
def rsp_process(
    rsp_signal,
    sampling_rate=1000,
    method="khodadad2018",
    method_rvt="harrison2021",
    report=None,
    **kwargs
):
    """Extract respiration features with NeuroKit2 without a cleaning filter.

    Parameters
    ----------
    rsp_signal : array-like
        Respiration samples; converted with ``nk.misc.as_vector``.
    sampling_rate : int
        Sampling rate forwarded to the method lookup, peak detection, rate and
        RVT computation.
    method : str
        Forwarded to NeuroKit2's ``rsp_methods`` lookup.
    method_rvt : str
        Forwarded to NeuroKit2's ``rsp_methods`` lookup.
    report : Any
        Accepted for signature compatibility; not used by this function.
    **kwargs
        Forwarded to NeuroKit2's ``rsp_methods`` lookup.

    Returns
    -------
    tuple
        ``(signals, info)``. ``signals`` holds ``RSP_Raw``, ``RSP_Clean``,
        ``RSP_Amplitude``, ``RSP_Rate`` and ``RSP_RVT`` followed by the phase,
        symmetry and peak columns. ``info`` is the peak-detection dict with an
        added ``"sampling_rate"`` entry.
    """
    raw = nk.misc.as_vector(rsp_signal)
    settings = nk.rsp.rsp_methods(
        sampling_rate=sampling_rate, method=method, method_rvt=method_rvt, **kwargs
    )

    # No filtering stage: the vectorized input doubles as the "clean" signal.
    cleaned = raw
    n_samples = len(raw)

    # Peak/trough extraction with a fixed minimum amplitude.
    extrema, info = nk.rsp_peaks(
        cleaned,
        sampling_rate=sampling_rate,
        method=settings["method_peaks"],
        amplitude_min=0.3,
        **settings["kwargs_peaks"],
    )
    info["sampling_rate"] = sampling_rate

    # Derived per-sample series.
    breathing_phase = nk.rsp_phase(extrema, desired_length=n_samples)
    breath_amplitude = nk.rsp_amplitude(cleaned, extrema)
    breathing_rate = nk.signal.signal_rate(
        info["RSP_Troughs"], sampling_rate=sampling_rate, desired_length=n_samples
    )
    breath_symmetry = nk.rsp_symmetry(cleaned, extrema)
    volume_per_time = nk.rsp_rvt(
        cleaned,
        method=settings["method_rvt"],
        sampling_rate=sampling_rate,
        silent=True,
    )

    base_frame = pd.DataFrame(
        {
            "RSP_Raw": raw,
            "RSP_Clean": cleaned,
            "RSP_Amplitude": breath_amplitude,
            "RSP_Rate": breathing_rate,
            "RSP_RVT": volume_per_time,
        }
    )
    merged = pd.concat([base_frame, breathing_phase, breath_symmetry, extrema], axis=1)

    return merged, info


In [7]:
def ppg_process(
    ppg_signal, sampling_rate=1000, method="elgendi", method_quality="templatematch", report=None, **kwargs
):
    """Extract PPG rate, quality and peak markers without a cleaning filter.

    Peaks are always detected with the ``"bishop"`` method; ``method`` only
    reaches the ``ppg_methods`` lookup.

    Parameters
    ----------
    ppg_signal : array-like
        PPG samples; converted with ``nk.misc.as_vector``.
    sampling_rate : int
        Sampling rate forwarded to the method lookup, peak detection, rate and
        quality computation.
    method : str
        Forwarded to NeuroKit2's ``ppg_methods`` lookup.
    method_quality : str
        Forwarded to NeuroKit2's ``ppg_methods`` lookup; the resolved value is
        used for the quality assessment.
    report : Any
        Accepted for signature compatibility; not used by this function.
    **kwargs
        Forwarded to NeuroKit2's ``ppg_methods`` lookup.

    Returns
    -------
    tuple
        ``(signals, info)``. ``signals`` holds ``PPG_Raw``, ``PPG_Clean``,
        ``PPG_Rate``, ``PPG_Quality`` and ``PPG_Peaks``. ``info`` is the
        peak-detection dict with an added ``"sampling_rate"`` entry.
    """
    raw = nk.misc.as_vector(ppg_signal)
    settings = nk.ppg.ppg_methods(sampling_rate=sampling_rate, method=method, method_quality=method_quality, **kwargs)

    # No filtering stage: the vectorized input doubles as the "clean" signal.
    cleaned = raw

    peak_markers, info = nk.ppg_peaks(
        cleaned,
        sampling_rate=sampling_rate,
        method="bishop",
    )
    info["sampling_rate"] = sampling_rate

    # Pulse rate interpolated to one value per input sample.
    pulse_rate = nk.signal.signal_rate(
        info["PPG_Peaks"], sampling_rate=sampling_rate, desired_length=len(cleaned)
    )

    # Quality estimate using the looked-up quality method and its kwargs.
    signal_quality = nk.ppg_quality(
        cleaned,
        peaks=info["PPG_Peaks"],
        sampling_rate=sampling_rate,
        method=settings["method_quality"],
        **settings["kwargs_quality"]
    )

    output_columns = {
        "PPG_Raw": raw,
        "PPG_Clean": cleaned,
        "PPG_Rate": pulse_rate,
        "PPG_Quality": signal_quality,
        "PPG_Peaks": peak_markers["PPG_Peaks"].values,
    }
    return pd.DataFrame(output_columns), info


In [13]:
def load_data(path: str, subject_id: str, sampling_rate: int = 4):
    """Load one subject's recording and return a labelled per-sample feature frame.

    The function reads ``<path>/<subject_id>/<subject_id>.pkl``, resamples every
    chest and wrist channel to ``sampling_rate``, brings the label track to the
    same rate, runs the ``*_process`` feature extractors, concatenates all
    outputs column-wise, keeps only rows labelled 1, 2 or 3, and remaps the
    label to a binary target.

    ``sampling_rate`` is used consistently for the signals, the feature
    extractors and the labels, so the label track stays aligned with the
    signals for any requested rate.

    Parameters
    ----------
    path : str
        Directory containing one sub-folder per subject.
    subject_id : str
        Subject folder / file stem, e.g. ``'S2'``.
    sampling_rate : int
        Target rate for all signals and labels.

    Returns
    -------
    pandas.DataFrame
        One row per retained sample with the feature columns (suffixed
        ``_chest`` / ``_wrist``), a ``label`` column (0 = original label 1 or
        3, 1 = original label 2) and a ``subject_id`` column.
    """
    data_path = Path(path) / subject_id / f'{subject_id}.pkl'
    # SECURITY NOTE: unpickling can execute arbitrary code. Only point this at
    # dataset files obtained from a trusted source.
    with open(data_path, 'rb') as file:
        data = pickle.load(file, encoding='latin1')

    data_chest = data['signal']['chest']
    data_wrist = data['signal']['wrist']
    data_label = data['label']

    # NOTE(review): at very low target rates (such as 4) the ECG/PPG/EMG
    # waveforms are heavily decimated before peak detection - confirm that the
    # derived features are still meaningful for the intended model.
    for key in data_chest.keys():
        source_rate = SAMPLING_RATES['chest'][key]
        if key == 'ACC':
            # ACC is multi-channel: resample (and high-pass filter) each axis
            # separately, then append the per-sample Euclidean norm.
            channels = []
            for i in range(data_chest[key].shape[1]):
                resampled = nk.signal_resample(
                    data_chest[key][:, i],
                    sampling_rate=source_rate,
                    desired_sampling_rate=sampling_rate,
                )
                filtered = nk.signal_filter(resampled, sampling_rate=sampling_rate, lowcut=0.4, method='fir')
                channels.append(filtered)
            channels.append(np.linalg.norm(np.array(channels), axis=0))
            data_chest[key] = np.stack(channels, axis=1)
        else:
            # Single-channel signals are stored as (n, 1); take the only column.
            data_chest[key] = nk.signal_resample(
                data_chest[key][:, 0],
                sampling_rate=source_rate,
                desired_sampling_rate=sampling_rate,
            )

    for key in data_wrist.keys():
        source_rate = SAMPLING_RATES['wrist'][key]
        if key == 'ACC':
            # Wrist ACC: resample each axis (no filtering here) and append the norm.
            channels = []
            for i in range(data_wrist[key].shape[1]):
                channels.append(
                    nk.signal_resample(
                        data_wrist[key][:, i],
                        sampling_rate=source_rate,
                        desired_sampling_rate=sampling_rate,
                    )
                )
            channels.append(np.linalg.norm(np.array(channels), axis=0))
            data_wrist[key] = np.stack(channels, axis=1)
        else:
            data_wrist[key] = nk.signal_resample(
                data_wrist[key][:, 0],
                sampling_rate=source_rate,
                desired_sampling_rate=sampling_rate,
            )

    # Bring the label track to the target rate: stride-based decimation when
    # labels are denser, nearest-neighbour interpolation when they are sparser.
    if SAMPLING_RATES['label'] > sampling_rate:
        data_label = data_label[::SAMPLING_RATES['label'] // sampling_rate]
    elif SAMPLING_RATES['label'] < sampling_rate:
        time = np.round(np.arange(0, len(data_label)) / SAMPLING_RATES['label'], 3)
        f = scipy.interpolate.interp1d(time, data_label, kind='nearest', fill_value='extrapolate')
        time_new = np.arange(0, time[-1], 1 / sampling_rate)
        data_label = f(time_new)

    # Feature extraction on the resampled signals.
    ecg_chest, _ = ecg_process(data_chest['ECG'], sampling_rate=sampling_rate)
    emg_chest, _ = emg_process(data_chest['EMG'], sampling_rate=sampling_rate)
    eda_chest, _ = eda_process(data_chest['EDA'], sampling_rate=sampling_rate)
    resp_chest, _ = rsp_process(data_chest['Resp'], sampling_rate=sampling_rate)

    bvp_wrist, _ = ppg_process(data_wrist['BVP'], sampling_rate=sampling_rate)
    eda_wrist, _ = eda_process(data_wrist['EDA'], sampling_rate=sampling_rate)

    # Drop the *_Raw columns (identical to *_Clean here) and tag the device.
    acc_chest = pd.DataFrame(data_chest['ACC'], columns=['ACC_x_chest', 'ACC_y_chest', 'ACC_z_chest', 'ACC_net_chest'])
    ecg_chest = ecg_chest.drop(columns=['ECG_Raw']).add_suffix('_chest')
    emg_chest = emg_chest.drop(columns=['EMG_Raw']).add_suffix('_chest')
    eda_chest = eda_chest.drop(columns=['EDA_Raw']).add_suffix('_chest')
    resp_chest = resp_chest.drop(columns=['RSP_Raw']).add_suffix('_chest')
    temp_chest = pd.DataFrame(data_chest['Temp'], columns=['Temp_chest'])

    acc_wrist = pd.DataFrame(data_wrist['ACC'], columns=['ACC_x_wrist', 'ACC_y_wrist', 'ACC_z_wrist', 'ACC_net_wrist'])
    bvp_wrist = bvp_wrist.drop(columns=['PPG_Raw']).add_suffix('_wrist')
    eda_wrist = eda_wrist.drop(columns=['EDA_Raw']).add_suffix('_wrist')
    temp_wrist = pd.DataFrame(data_wrist['TEMP'], columns=['TEMP_wrist'])

    label = pd.DataFrame(data_label, columns=['label'])

    # Align everything on the positional sample index; the outer join keeps
    # rows even when individual frames differ in length by a sample.
    df = pd.concat([acc_chest, ecg_chest, emg_chest, eda_chest, resp_chest, temp_chest,
                    acc_wrist, bvp_wrist, eda_wrist, temp_wrist, label], axis=1, join='outer')

    # Keep only labels 1 (baseline), 2 (stress) and 3 (amusement).
    df = df[df['label'].isin([1, 2, 3])]

    # Reset index to have a clean integer index.
    df.reset_index(drop=True, inplace=True)

    df = df.assign(
        subject_id=subject_id,
        label=df['label'].map({1: 0, 2: 1, 3: 0})  # Binary target: 0 (non-stress), 1 (stress)
    )

    return df

In [9]:
def load_all_data(path: str, subject_ids: list, sampling_rate: int = 64):
    """Load every requested subject and return the frames keyed by subject id.

    Parameters
    ----------
    path : str
        Dataset root directory, forwarded to ``load_data``.
    subject_ids : list
        Subject identifiers to load, in the order they should appear in the
        returned dict.
    sampling_rate : int
        Forwarded to ``load_data``. Note that the default here (64) differs
        from ``load_data``'s own default (4); pass the value explicitly.

    Returns
    -------
    dict
        ``{subject_id: DataFrame}`` in the iteration order of ``subject_ids``.
    """
    # tqdm wraps the iterable so a progress bar is shown while subjects load.
    frames_by_subject = {
        sid: load_data(path, sid, sampling_rate) for sid in tqdm(subject_ids)
    }

    logger.info('All data loaded')

    return frames_by_subject

In [10]:
def standardize_data(data, mean=None, std=None):
    """Z-score ``data`` using the given (or self-computed) mean and std.

    A standard deviation of exactly zero (a constant feature) is treated as 1
    so that the feature maps to ``data - mean`` instead of NaN/inf from a
    division by zero.

    Parameters
    ----------
    data : pandas.DataFrame, pandas.Series or numpy.ndarray
        Values to standardize; must support ``.mean()`` and ``.std()`` when
        ``mean`` / ``std`` are not supplied.
    mean : scalar or array-like, optional
        Centre to subtract. Defaults to ``data.mean()``.
    std : scalar or array-like, optional
        Scale to divide by. Defaults to ``data.std()``.

    Returns
    -------
    Same container type as ``data - mean``
        The standardized values.
    """
    if mean is None:
        mean = data.mean()
    if std is None:
        std = data.std()

    # Replace zero spread with 1 to avoid 0/0 -> NaN (or x/0 -> inf).
    if isinstance(std, (pd.Series, pd.DataFrame)):
        safe_std = std.mask(std == 0, 1.0)
    elif np.ndim(std) == 0:
        safe_std = 1.0 if std == 0 else std
    else:
        std_array = np.asarray(std)
        safe_std = np.where(std_array == 0, 1.0, std_array)

    data_normalized = (data - mean) / safe_std
    return data_normalized

In [11]:
def data_split(all_data, test_subjects: int):
    """Split per-subject frames into train / train-with-anomaly / test sets.

    Subjects are taken in the iteration order of ``all_data``: the last
    ``test_subjects`` subjects form the test set and the remaining ones form
    the training pool.

    Parameters
    ----------
    all_data : dict
        ``{subject_id: DataFrame}``; every frame must have a ``label`` column.
    test_subjects : int
        Number of subjects held out for testing. Must satisfy
        ``0 < test_subjects < len(all_data)``.

    Returns
    -------
    tuple
        ``(df_train, df_train_with_anomaly, df_test)`` where ``df_train``
        contains only the ``label == 0`` rows of the training subjects,
        ``df_train_with_anomaly`` contains all rows of the training subjects,
        and ``df_test`` contains all rows of the test subjects. All three have
        a fresh 0..n-1 index.

    Raises
    ------
    ValueError
        If fewer than two subjects are supplied or ``test_subjects`` would
        leave the train or test set empty. (Without this check a value larger
        than the subject count makes the slice bounds negative and silently
        yields a wrong split.)
    """
    subject_frames = list(all_data.values())
    n_subjects = len(subject_frames)

    if n_subjects < 2:
        raise ValueError(
            f'data_split needs at least 2 subjects to build train and test sets (got {n_subjects})'
        )
    if not 0 < test_subjects < n_subjects:
        raise ValueError(
            f'test_subjects must be between 1 and {n_subjects - 1} '
            f'for {n_subjects} subjects (got {test_subjects})'
        )

    n_train = n_subjects - test_subjects

    # ignore_index=True already yields a clean 0..n-1 index for both frames.
    df_train_with_anomaly = pd.concat(subject_frames[:n_train], ignore_index=True, axis=0)
    df_test = pd.concat(subject_frames[n_train:], ignore_index=True, axis=0)

    # Train only on non-stress rows; boolean indexing yields an independent frame.
    is_non_stress = df_train_with_anomaly['label'] == 0
    df_train = df_train_with_anomaly[is_non_stress].reset_index(drop=True)

    logger.info('Data split into train and test sets')
    logger.info('Train data shape: %s', df_train.shape)
    logger.info('Test data shape: %s', df_test.shape)

    return df_train, df_train_with_anomaly, df_test

In [14]:
# Load every subject at the notebook-wide sampling rate, then hold out the
# last TEST_SUBJECTS subjects for testing.
all_data = load_all_data(
    path='WESAD',
    subject_ids=SUBJECT_IDS,
    sampling_rate=SAMPLING_RATE,
)
df_train, df_train_with_anomaly, df_test = data_split(
    all_data=all_data,
    test_subjects=TEST_SUBJECTS,
)

  warn(
  warn(
  warn(
100%|██████████| 15/15 [15:59<00:00, 63.99s/it]
INFO:__main__:All data loaded
INFO:__main__:Data split into train and test sets
INFO:__main__:Train data shape: (61696, 69)
INFO:__main__:Test data shape: (44732, 69)


In [None]:
# Forward-fill missing values in each split (rebinding instead of mutating in place).
# NOTE(review): each split is a row-wise concatenation of several subjects, so a
# gap at the start of one subject is filled from the previous subject's last
# row - confirm this is acceptable.
df_train, df_train_with_anomaly, df_test = (
    split.ffill() for split in (df_train, df_train_with_anomaly, df_test)
)

  df_train.fillna(method='ffill', inplace=True)
  df_train_with_anomaly.fillna(method='ffill', inplace=True)
  df_test.fillna(method='ffill', inplace=True)


In [22]:
# Persist the three splits as index-free CSV files under ./data (created if missing).
Path('data').mkdir(parents=True, exist_ok=True)
for frame, csv_path in (
    (df_train, 'data/train.csv'),
    (df_train_with_anomaly, 'data/train_with_anomaly.csv'),
    (df_test, 'data/test.csv'),
):
    frame.to_csv(csv_path, index=False)

In [26]:
# Class balance of the training split (non-stress rows only).
df_train['label'].value_counts()

label
0    61696
Name: count, dtype: int64

In [27]:
# Class balance of the training subjects with stress rows included.
df_train_with_anomaly['label'].value_counts()

label
0    61696
1    26180
Name: count, dtype: int64

In [28]:
# Class balance of the held-out test subjects.
df_test['label'].value_counts()

label
0    31048
1    13684
Name: count, dtype: int64