In [36]:
import wfdb
import glob
import os
import tqdm
import pandas as pd
import numpy as np

In [37]:
base_dir = "/Users/taniapazospuig/Desktop/bio/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3/"
csv_path = os.path.join(base_dir, "ptbxl_database.csv")
record_dir = os.path.join(base_dir, "records500")

In [38]:
# Load metadata
variables = pd.read_csv(csv_path, index_col=0)

# Find all .dat files from records500 with the raw ECG signals
files = glob.glob(os.path.join(record_dir, "**", "*.dat"), recursive=True)

# Extract ecg_id from filenames
labels = [os.path.splitext(os.path.basename(f))[0] for f in files]
ecg_ids = [int(label.split("_")[0]) for label in labels]

# Filter metadata to keep only rows for which we have actual ECG waveform files
variables = variables.loc[variables.index.isin(ecg_ids)]

# Reorder filtered metadata to match the order of the waveform files
ordered_indices = [id for id in ecg_ids if id in variables.index]
variables = variables.loc[ordered_indices]

In [39]:
# Shape and preview
print("Shape of variables:", variables.shape) # Rows are ECGs and columns are metadata
variables.head()

Shape of variables: (21799, 27)


Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
20,13619.0,56.0,0,,,2.0,0.0,CS-12 E,1985-01-23 12:55:32,supraventrikulÄre ersatzsystole(n) interponier...,...,True,,,,,VES,,9,records100/00000/00020_lr,records500/00000/00020_hr
771,3063.0,63.0,0,,,10.0,1.0,AT-6 6,1987-05-10 17:22:51,trace only requested.,...,True,,,,,,,9,records100/00000/00771_lr,records500/00000/00771_hr
297,4845.0,73.0,1,170.0,103.0,1.0,1.0,AT-6 C 5.5,1986-09-12 10:22:10,premature ventricular contraction(s). sinus rh...,...,True,,,,,,,7,records100/00000/00297_lr,records500/00000/00297_hr
120,11860.0,45.0,1,,57.0,2.0,0.0,CS-12 E,1986-01-16 06:41:58,sinusrhythmus normales ekg,...,True,,", alles,",,,,,4,records100/00000/00120_lr,records500/00000/00120_hr
671,3977.0,76.0,1,167.0,45.0,3.0,1.0,AT-6 6,1987-04-25 19:35:42,sinus rhythm. normal ecg.,...,True,,,,,,,3,records100/00000/00671_lr,records500/00000/00671_hr


To ensure the dataset is ready for model training, we first handled the missing values in the metadata. Columns with more than 50% missing values (`electrodes_problems`, `infarction_stadium1`, `infarction_stadium2`, `pacemaker`, etc.) or low relevance (`nurse`, `site`, `device`,etc.) were dropped as they offered little value for the classification task and would introduce unnecessary noise.

For columns with moderate missingness but potential predictive value, such as `height` and `weight`, we imputed missing values using the median, which is a standard and robust method for handling numerical missing data.

Categorical columns, such as `sex` and `report`, had their missing values imputed using the mode, ensuring that the dataset remained complete and consistent.

In [40]:
# TODO: Give credit to https://github.com/huseyincavusbi/SE_ECGNet/blob/main/SE_ECGNet.ipynb
# Identify columns with more than 50% missing values
missing_percentages = (variables.isnull().sum() / len(variables)) * 100
high_missing_cols = missing_percentages[missing_percentages > 50].index.tolist()

# Define additional low-relevance columns to drop manually
additional_cols_to_drop = ['nurse', 'site', 'device', 'recording_date', 'validated_by']

# Combine both lists of columns to drop
cols_to_drop = high_missing_cols + additional_cols_to_drop

# Print columns to be dropped
print("Columns being dropped due to missing or low relevance:")
print("-" * 50)
for col in cols_to_drop:
    if col in variables.columns:
        print(f"{col}: {missing_percentages.get(col, 0):.2f}% missing")

# Drop the identified columns
variables_cleaned = variables.drop(columns=cols_to_drop)

# Handle remaining missing values
# Fill numeric columns with median
numeric_columns = variables_cleaned.select_dtypes(include=['float64', 'int64']).columns
variables_cleaned[numeric_columns] = variables_cleaned[numeric_columns].fillna(variables_cleaned[numeric_columns].median())

# Fill categorical columns with mode
categorical_columns = variables_cleaned.select_dtypes(include=['object']).columns
variables_cleaned[categorical_columns] = variables_cleaned[categorical_columns].fillna(variables_cleaned[categorical_columns].mode().iloc[0])

print("\nFinal cleaned metadata info:")
print(variables_cleaned.info())

Columns being dropped due to missing or low relevance:
--------------------------------------------------
height: 68.01% missing
weight: 56.78% missing
infarction_stadium1: 74.26% missing
infarction_stadium2: 99.53% missing
baseline_drift: 92.67% missing
static_noise: 85.05% missing
burst_noise: 97.19% missing
electrodes_problems: 99.86% missing
extra_beats: 91.06% missing
pacemaker: 98.67% missing
nurse: 6.76% missing
site: 0.08% missing
device: 0.00% missing
recording_date: 0.00% missing
validated_by: 43.02% missing

Final cleaned metadata info:
<class 'pandas.core.frame.DataFrame'>
Index: 21799 entries, 20 to 17905
Data columns (total 12 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   patient_id                    21799 non-null  float64
 1   age                           21799 non-null  float64
 2   sex                           21799 non-null  int64  
 3   report                        21799 non-n

In [41]:
print("\nMissing values per column:")
print(variables_cleaned.isnull().sum())


Missing values per column:
patient_id                      0
age                             0
sex                             0
report                          0
scp_codes                       0
heart_axis                      0
second_opinion                  0
initial_autogenerated_report    0
validated_by_human              0
strat_fold                      0
filename_lr                     0
filename_hr                     0
dtype: int64


In the PTB-XL dataset, patient age is provided at the time of ECG recording. However, in compliance with HIPAA privacy standards, all patients older than 89 years are assigned a value of 300. This is a form of pseudonymization to prevent potential re-identification of elderly individuals. Since this value does not represent a real age and could skew the model or statistical summaries, we cap all age values at 89.

In [42]:
print("Sex distribution:\n", variables_cleaned["sex"].value_counts(), "\n")
print("Age summary:\n", variables_cleaned["age"].describe())

Sex distribution:
 sex
0    11354
1    10445
Name: count, dtype: int64 

Age summary:
 count    21799.000000
mean        62.769301
std         32.308813
min          2.000000
25%         50.000000
50%         62.000000
75%         72.000000
max        300.000000
Name: age, dtype: float64


In [43]:
# Cap age at 89
variables_cleaned["age"] = variables_cleaned["age"].apply(lambda x: 89 if x == 300 else x)

In [44]:
variables_cleaned["scp_codes"]

ecg_id
20                 {'AFLT': 100.0, 'ABQRS': 0.0}
771                              {'NORM': 100.0}
297        {'NORM': 80.0, 'PVC': 0.0, 'SR': 0.0}
120                   {'NORM': 100.0, 'SR': 0.0}
671                   {'NORM': 100.0, 'SR': 0.0}
                          ...                   
17141                  {'NDT': 100.0, 'SR': 0.0}
17710                  {'NORM': 80.0, 'SR': 0.0}
17041                {'CLBBB': 100.0, 'SR': 0.0}
17805    {'NORM': 80.0, 'HVOLT': 0.0, 'SR': 0.0}
17905                  {'NDT': 100.0, 'SR': 0.0}
Name: scp_codes, Length: 21799, dtype: object

Multi-label classification is required for this task because each ECG record can be associated with multiple diagnostic superclasses (`NORM`, `MI`, `STTC`, `CD`, `HYP`). To simplify the task and focus on the required classes for this assignment, we decided to focus on the three target labels: `NORM`, `MI`, and `STTC`.

To ensure the dataset only contains records relevant to these classes, we filtered out any records that did not include at least one of the three target labels (`NORM`, `MI`, `STTC`), ensuring that the model will only be trained and evaluated on these three superclasses.

In [48]:
from ast import literal_eval

def safe_literal_eval(val):
    if isinstance(val, str):
        return literal_eval(val)
    return val  # Already a dict, no need to convert

variables_cleaned["scp_codes"] = variables_cleaned["scp_codes"].apply(safe_literal_eval)

In [49]:
# Load statement reference table
scp_df = pd.read_csv(os.path.join(base_dir, "scp_statements.csv"), index_col=0)

# Keep only rows with a diagnostic_class
scp_diagnostic_map = scp_df[scp_df["diagnostic_class"].notnull()]["diagnostic_class"].to_dict()

# Map each scp_codes dict to diagnostic superclasses
def map_to_superclasses(scp_code_dict):
    return list({scp_diagnostic_map[code] for code in scp_code_dict if code in scp_diagnostic_map})

variables_cleaned["diagnostic_superclass_mapped"] = variables_cleaned["scp_codes"].apply(map_to_superclasses)

variables_cleaned[["scp_codes", "diagnostic_superclass_mapped"]].head(10)

Unnamed: 0_level_0,scp_codes,diagnostic_superclass_mapped
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1
20,"{'AFLT': 100.0, 'ABQRS': 0.0}",[]
771,{'NORM': 100.0},[NORM]
297,"{'NORM': 80.0, 'PVC': 0.0, 'SR': 0.0}",[NORM]
120,"{'NORM': 100.0, 'SR': 0.0}",[NORM]
671,"{'NORM': 100.0, 'SR': 0.0}",[NORM]
397,"{'NORM': 100.0, 'SR': 0.0}",[NORM]
964,"{'ASMI': 100.0, 'ILBBB': 100.0, 'LVH': 100.0, ...","[HYP, MI, STTC, CD]"
864,"{'ISCIN': 100.0, '1AVB': 100.0, 'PVC': 100.0, ...","[STTC, CD]"
919,"{'SEHYP': 50.0, 'ISCAS': 100.0, 'INVT': 0.0, '...","[HYP, STTC]"
819,"{'NORM': 100.0, 'SR': 0.0}",[NORM]


In [57]:
# Define the 3 target labels
target_labels = {"NORM", "MI", "STTC"}

# Keep records that have at least one of the target labels
variables_filtered = variables_cleaned[variables_cleaned["diagnostic_superclass_mapped"].apply(lambda x: bool(set(x) & target_labels))]

variables_filtered[["scp_codes", "diagnostic_superclass_mapped"]].head(10)

Unnamed: 0_level_0,scp_codes,diagnostic_superclass_mapped
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1
771,{'NORM': 100.0},[NORM]
297,"{'NORM': 80.0, 'PVC': 0.0, 'SR': 0.0}",[NORM]
120,"{'NORM': 100.0, 'SR': 0.0}",[NORM]
671,"{'NORM': 100.0, 'SR': 0.0}",[NORM]
397,"{'NORM': 100.0, 'SR': 0.0}",[NORM]
964,"{'ASMI': 100.0, 'ILBBB': 100.0, 'LVH': 100.0, ...","[HYP, MI, STTC, CD]"
864,"{'ISCIN': 100.0, '1AVB': 100.0, 'PVC': 100.0, ...","[STTC, CD]"
919,"{'SEHYP': 50.0, 'ISCAS': 100.0, 'INVT': 0.0, '...","[HYP, STTC]"
819,"{'NORM': 100.0, 'SR': 0.0}",[NORM]
289,"{'NORM': 80.0, 'VCLVH': 0.0, 'SBRAD': 0.0}",[NORM]


The multi-hot encoding approach was used to represent the target labels for each ECG. `MultiLabelBinarizer` was applied to convert the `diagnostic_superclass_mapped` column (which contains lists of superclasses) into a binary matrix, where each row represents an ECG and each column corresponds to the presence (1) or absence (0) of one of the target labels. Labels such as `CD` and `HYP` were intentionally excluded from this encoding, as they are outside the scope of the task and should not influence model predictions or evaluation.

This approach allows the model to learn which of the target classes (`NORM`, `MI`, `STTC`) are associated with each ECG, even if multiple labels are present.

In [58]:
from sklearn.preprocessing import MultiLabelBinarizer

mlb = MultiLabelBinarizer(classes=["NORM", "MI", "STTC"])
y = mlb.fit_transform(variables_filtered["diagnostic_superclass_mapped"])



To ensure proper evaluation and generalization of the classification models, we followed the official 10-fold stratified split provided by the PTB-XL dataset authors. This split keeps all records from the same patient within the same fold, avoiding data leakage.

Specifically, we used:
- Folds 1–8 for training
- Fold 9 for validation
- Fold 10 for testing

We then loaded the corresponding ECG signal data using the paths provided in the `filename_hr` column. Signals were loaded at a sampling rate of 500 Hz, preserving full resolution for better model performance.
The label vectors (`y_train`, `y_val`, `y_test`) were already multi-hot encoded using the `MultiLabelBinarizer`, allowing the model to learn from ECGs with multiple diagnostic labels (`NORM`, `MI`, `STTC`).

In [None]:
fs = 500

# Load ECG signal data
def load_raw_data(df, fs, base_path):
    if fs == 100:
        paths = df["filename_lr"]
    else:
        paths = df["filename_hr"]
    signals = []
    for f in paths:
        full_path = os.path.join(base_path, f)
        signal, _ = wfdb.rdsamp(full_path)
        signals.append(signal)
    return np.array(signals)

# Split variables_filtered by strat_fold
train_df = variables_filtered[variables_filtered["strat_fold"] < 9]
val_df   = variables_filtered[variables_filtered["strat_fold"] == 9]
test_df  = variables_filtered[variables_filtered["strat_fold"] == 10]

# Load raw signal data
X_train = load_raw_data(train_df, fs, base_dir)
X_val   = load_raw_data(val_df, fs, base_dir)
X_test  = load_raw_data(test_df, fs, base_dir)

# Get binarized labels using the same mlb
y_train = mlb.transform(train_df["diagnostic_superclass_mapped"])
y_val   = mlb.transform(val_df["diagnostic_superclass_mapped"])
y_test  = mlb.transform(test_df["diagnostic_superclass_mapped"])



We evaluated the quality of ECG signals using the Signal-to-Noise Ratio (SNR) across all 12 leads in the training set.
Since raw ECGs do not have clean ground truth references, we computed:
- Pre-filter SNR by estimating signal power vs. variance (as a proxy for noise)
- Post-filter SNR by treating the filtered signal as the clean reference and computing noise as the difference between the raw and filtered signals

The results show a consistent and significant improvement in SNR after applying standard filtering (0.5 Hz high-pass, 45 Hz low-pass, and detrending).

In [63]:
batch_size = 100
num_leads = X_train.shape[2]

# Compute SNR of raw signals (before filtering)
# Before filtering, we do not have a clean signal to compare to
# We use the raw signal to compute its power, and its variance to estimate noise
def compute_raw_snr_batch(X):
    signal_power_total = np.zeros(num_leads)
    noise_power_total = np.zeros(num_leads)

    for i in range(0, len(X), batch_size):
        batch = X[i:i+batch_size]

        for signal_raw in batch:
            # Compute average signal power per lead
            signal_power = np.mean(np.square(signal_raw), axis=0)

            # Estimate noise power per lead as the variance (std^2)
            noise_power = np.var(signal_raw, axis=0)

            signal_power_total += signal_power
            noise_power_total += noise_power

    # Average over all ECGs
    P_signal_avg = signal_power_total / len(X)
    P_noise_avg = noise_power_total / len(X)

    # Compute SNR and convert to decibels
    snr = P_signal_avg / P_noise_avg
    snr_db = 10 * np.log10(snr)

    return snr, snr_db

In [64]:
from scipy import signal

# Filtering function
def apply_filters(ecg, fs):
    # High-pass filter
    ecg = signal.filtfilt(*signal.butter(2, 0.5, 'high', fs=fs), ecg, axis=0)
    # Low-pass filter
    ecg = signal.filtfilt(*signal.butter(2, 45.0, 'low', fs=fs), ecg, axis=0)
    # Remove linear trend
    ecg = signal.detrend(ecg, axis=0)
    return ecg

# Compute SNR in batches to avoid memory issues
def compute_power_snr_batch(X, fs):
    signal_power_total = np.zeros(num_leads)
    noise_power_total = np.zeros(num_leads)
    
    for i in range(0, len(X), batch_size):
        batch = X[i:i+batch_size]
        
        # Split the data into batches
        for signal_raw in batch:
            signal_filtered = apply_filters(signal_raw, fs)
            noise = signal_raw - signal_filtered

            # Compute average power per lead for clean and noisy signals
            signal_power = np.mean(np.square(signal_filtered), axis=0)
            noise_power = np.mean(np.square(noise), axis=0)

            signal_power_total += signal_power
            noise_power_total += noise_power

    # Average over all processed signals
    P_signal_avg = signal_power_total / len(X)
    P_noise_avg = noise_power_total / len(X)
    snr = P_signal_avg / P_noise_avg
    snr_db = 10 * np.log10(snr)

    return snr, snr_db

In [67]:
# Compute SNR before filtering
snr_raw, snr_raw_db = compute_raw_snr_batch(X_train)

# Compute SNR after filtering
snr_filtered, snr_filtered_db = compute_power_snr_batch(X_train, fs=500)

# Display the results per lead
print(f"{'Lead':<6}{'Raw SNR (dB)':<15}{'Filtered SNR (dB)':<20}{'Delta SNR (dB)':<10}")
print("-" * 50)
for i in range(num_leads):
    delta = snr_filtered_db[i] - snr_raw_db[i]
    print(f"{i+1:<6}{snr_raw_db[i]:<15.2f}{snr_filtered_db[i]:<20.2f}{delta:<10.2f}")

Lead  Raw SNR (dB)   Filtered SNR (dB)   Delta SNR (dB)
--------------------------------------------------
1     0.20           1.62                1.42      
2     0.09           6.04                5.96      
3     0.30           2.77                2.47      
4     0.15           6.34                6.19      
5     0.35           3.07                2.72      
6     0.16           4.09                3.93      
7     0.02           4.84                4.81      
8     0.01           7.93                7.93      
9     0.01           6.50                6.49      
10    0.00           7.38                7.38      
11    0.39           3.56                3.17      
12    0.01           2.68                2.67      


In [68]:
# Store filtered signals for later steps
X_train_filtered = np.array([apply_filters(x, fs=500) for x in X_train])
X_val_filtered   = np.array([apply_filters(x, fs=500) for x in X_val])
X_test_filtered  = np.array([apply_filters(x, fs=500) for x in X_test])