In [91]:
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Load the pickled dataset from disk.
# NOTE(review): pickle.load can execute arbitrary code while deserializing,
# so only open pickle files that come from a trusted source.
with open('data/full_data_corrected_2024.pkl', 'rb') as pkl_handle:
    full_data = pickle.load(pkl_handle)

# Show which container type the pickle holds
print(type(full_data))

In [None]:
# List the field names available in the loaded dataset
print(full_data.keys())

In [None]:
# Position of the record to inspect (0 = first sample)
i = 0

# Dump every field of that record as "name: value"
for key, column in full_data.items():
    value = column[i]
    print(f"{key}: {value}")

This is a full ECG record, including:
* Raw 12-lead signals, each one a NumPy array
* Patient metadata: Sex, HTA, Age, PVC_transition, SOO_chamber, Height, Weight, BMI, DM, DLP, Smoker, COPD, Sleep_apnea, CLINICAL_SCORE, SOO, OTorigin.

In [None]:
# Number of samples (same for all keys)
print(len(full_data['PVC_transition']))

# Gather the per-patient metadata fields into a DataFrame for easier manipulation
metadata_keys = [
    'Sex', 'HTA', 'Age', 'PVC_transition', 'SOO_chamber', 'Height', 'Weight', 'BMI',
    'DM', 'DLP', 'Smoker', 'COPD', 'Sleep_apnea', 'CLINICAL_SCORE', 'SOO', 'OTorigin',
]

# One DataFrame column per metadata field, in the order listed above
metadata_columns = {}
for key in metadata_keys:
    metadata_columns[key] = full_data[key]
df_meta = pd.DataFrame(metadata_columns)

df_meta.head(10)

In [None]:
# Distinct values present in the SOO_chamber column
print(df_meta['SOO_chamber'].unique())

In [None]:
# Assemble all recordings into one array of shape (n_samples, n_leads, signal_length)
ecg_leads = ['I', 'II', 'III', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
n_samples = len(full_data['I'])
signal_length = len(full_data['I'][0])  # assumes every lead of every sample has this length

multi_lead_ecgs = np.zeros((n_samples, len(ecg_leads), signal_length))
for sample_pos in range(n_samples):
    for lead_pos, lead_name in enumerate(ecg_leads):
        # Copy one lead of one recording into its slot
        multi_lead_ecgs[sample_pos, lead_pos, :] = full_data[lead_name][sample_pos]

# Check the shape of the multi-lead ECG array
print(multi_lead_ecgs.shape)  # should be (n_samples, ecg_leads, signal_length)

Hence, there are:
* 181 ECG samples
* 12 leads per sample
* 2500 time points per lead

In [None]:
# Visualize the first ECG sample (12 leads)
sample_idx = 0
fig, axs = plt.subplots(6, 2, figsize=(12, 10))
fig.suptitle(f'ECG Sample {sample_idx}', fontsize=16)

# One panel per lead, filled in row-major order over the 6x2 grid
for lead_pos, ax in enumerate(axs.flat):
    lead_trace = multi_lead_ecgs[sample_idx, lead_pos]
    ax.plot(lead_trace)
    ax.set_title(ecg_leads[lead_pos])
    ax.set_xlim([0, signal_length])
    ax.grid(True)

# Reserve headroom at the top for the figure-level title
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

In [99]:
import scipy.signal as sp
from scipy.interpolate import interp1d

# Processes one sample ECG with all 12 leads
def preprocess_ecg_signal(ecg_signals, fs=1000, target_fs=250, high=0.5, low=100.0):
    """
    Preprocess a multi-lead ECG signal of shape [timepoints, leads].

    Steps:
    1. Resample from `fs` to `target_fs` with a polyphase filter. Unlike plain
       point-wise interpolation, this low-pass filters before decimating, so
       content above the new Nyquist frequency (target_fs / 2) is suppressed
       instead of being aliased into the ECG band.
    2. Zero-phase high-pass at `high` Hz (removes slow baseline drift).
    3. Zero-phase low-pass at `low` Hz (removes high-frequency noise).

    Parameters:
    - ecg_signals: 2D array, time along axis 0 and leads along axis 1.
    - fs: sampling rate of the input, in Hz.
    - target_fs: sampling rate of the output, in Hz.
    - high: cutoff of the high-pass filter, in Hz.
    - low: cutoff of the low-pass filter, in Hz.

    Returns:
    - 2D float array with int(timepoints * target_fs / fs) rows and one column
      per lead (e.g. [2500, 12] at 1000 Hz becomes [625, 12] at 250 Hz).

    Raises:
    - ValueError: if `high` or `low` is not strictly between 0 and target_fs / 2.
    """
    from fractions import Fraction

    ecg_signals = np.asarray(ecg_signals, dtype=float)
    nyquist = target_fs / 2
    if not (0 < high < nyquist and 0 < low < nyquist):
        raise ValueError(
            f"Filter cutoffs must lie strictly between 0 and target_fs/2 = {nyquist} Hz "
            f"(got high={high}, low={low})"
        )

    timepoints = ecg_signals.shape[0]  # e.g. 2500
    new_timepoints = int(timepoints * target_fs / fs)  # e.g. 625

    # Express the rate change as a reduced integer ratio up/down for the polyphase resampler
    ratio = (Fraction(target_fs) / Fraction(fs)).limit_denominator(10_000)
    ecg_resampled = sp.resample_poly(
        ecg_signals, ratio.numerator, ratio.denominator, axis=0, padtype='line'
    )
    # resample_poly rounds the length up; trim so the output length matches the
    # int(...) convention used above
    ecg_resampled = ecg_resampled[:new_timepoints]

    # High-pass: remove slow drifts below `high` Hz
    b_high, a_high = sp.butter(2, high / nyquist, btype='high')
    ecg_filtered = sp.filtfilt(b_high, a_high, ecg_resampled, axis=0)

    # Low-pass: remove noise above `low` Hz
    b_low, a_low = sp.butter(2, low / nyquist, btype='low')
    ecg_filtered = sp.filtfilt(b_low, a_low, ecg_filtered, axis=0)

    return ecg_filtered

In [None]:
# Apply preprocessing to all ECGs
preprocessed_ecgs = []
for i in range(multi_lead_ecgs.shape[0]):
    signal_raw = multi_lead_ecgs[i].T  # transpose [leads, time] -> [time, leads]
    processed = preprocess_ecg_signal(signal_raw)  # one recording at a time
    preprocessed_ecgs.append(processed)
    # preprocessed_ecgs is a list with one [time, leads] array per recording

# Stack into a single 3D array when every processed signal has the same shape.
# np.stack raises ValueError on mismatched shapes (or an empty list); only that
# case is handled here, so unrelated failures are no longer silently swallowed.
try:
    preprocessed_ecgs = np.stack(preprocessed_ecgs)
    print("All signals successfully preprocessed to shape:", preprocessed_ecgs.shape)
except ValueError:
    print("Signals have different lengths. Stored as a list.")

In [None]:
# Compare the raw and preprocessed signals
# Leads that might show more noise or differences after preprocessing
leads_to_plot = ['I', 'AVR', 'V2']
lead_indices = [ecg_leads.index(lead) for lead in leads_to_plot]

# Recording to compare, chosen explicitly instead of relying on whatever value
# a loop index happened to hold after an earlier cell ran.
compare_idx = 0
raw_fs = 1000   # Hz, sampling rate of the raw recordings
proc_fs = 250   # Hz, sampling rate after preprocessing

# Raw and processed signals, both arranged as [leads, time].
# Indexing by position works whether preprocessed_ecgs is a list or a stacked array.
raw_signal = multi_lead_ecgs[compare_idx]
processed_signal = preprocessed_ecgs[compare_idx].T

# Time axes derived from the sampling rates (sample k sits at k / fs seconds)
t_raw = np.arange(raw_signal.shape[1]) / raw_fs
t_processed = np.arange(processed_signal.shape[1]) / proc_fs

# Focus on first second only (for better detail)
max_time = 1.0
raw_mask = t_raw <= max_time
proc_mask = t_processed <= max_time

plt.figure(figsize=(15, 8))

for k, lead_idx in enumerate(lead_indices):
    plt.subplot(len(lead_indices), 1, k + 1)

    plt.plot(t_raw[raw_mask], raw_signal[lead_idx][raw_mask], label='Raw (1000Hz)', alpha=0.6)
    plt.plot(t_processed[proc_mask], processed_signal[lead_idx][proc_mask], label='Preprocessed (250Hz)', alpha=0.9)

    plt.title(f"Lead {ecg_leads[lead_idx]}")
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()

In [102]:
# Map SOO to chamber (Left / Right /OTHER)
# Both lookup tables are read from the same Excel workbook (sheets Hoja1 and Hoja2)
labels_path = "data/labels_FontiersUnsupervised.xlsx"
map_hoja1 = pd.read_excel(labels_path, sheet_name="Hoja1")
map_hoja2 = pd.read_excel(labels_path, sheet_name="Hoja2")

# SOO -> chamber dictionaries (the chamber column is capitalised differently on each sheet)
map_1 = dict(zip(map_hoja1["SOO"], map_hoja1["SOO_Chamber"]))
map_2 = dict(zip(map_hoja2["SOO"], map_hoja2["SOO_chamber"]))

simplified_chambers = []
for entry in full_data["SOO"]:
    is_text = isinstance(entry, str)

    # Hoja1 is consulted first; anything it does not cover starts out as "OTHER"
    chamber = map_1[entry] if is_text and entry in map_1 else "OTHER"

    # Hoja2 is only used to refine entries that are still "OTHER"
    if chamber == "OTHER" and is_text and entry in map_2:
        chamber = map_2[entry]

    simplified_chambers.append(chamber)

In [None]:
# Show all unique chamber names after Hoja1 + Hoja2 mapping
unique_chambers = list(set(simplified_chambers))
unique_chambers.sort()

print("Unique chamber labels found:", len(unique_chambers))
for chamber_name in unique_chambers:
    print(f"- {chamber_name}")

In [104]:
def normalize_chamber(label):
    """
    Collapse a detailed chamber name into one of three classes.

    Returns 'Right' or 'Left' for the names listed below, and 'OTHER' for
    anything else. Matching is exact (case- and whitespace-sensitive).
    """
    right_sided = ("RVOT", "Right ventricle", "Tricuspid annulus", "Coronary sinus")
    left_sided = ("LVOT", "Left ventricle", "Mitral annulus")

    # Right-sided names are checked first, then left-sided ones
    for side, names in (("Right", right_sided), ("Left", left_sided)):
        if label in names:
            return side
    return "OTHER"

# One normalized class ('Left' / 'Right' / 'OTHER') per entry of simplified_chambers, in the same order
final_chambers_normalized = list(map(normalize_chamber, simplified_chambers))

In [None]:
# Plot Histogram for Left vs Right Distribution

# Keep only the two classes of interest ('OTHER' is dropped)
filtered_labels = [label for label in final_chambers_normalized if label != "OTHER"]

# Tally each class
left_count = sum(1 for label in filtered_labels if label == "Left")
right_count = sum(1 for label in filtered_labels if label == "Right")

# Bar chart of the two class counts
fig_dist, ax_dist = plt.subplots(figsize=(8, 6))
ax_dist.bar(
    ["Left", "Right"],
    [left_count, right_count],
    color=['tab:red', 'tab:blue'],
    edgecolor='black',
)
ax_dist.set_title("Class Distribution: Left vs Right Ventricle")
ax_dist.set_xlabel("Chamber")
ax_dist.set_ylabel("Number of Samples")
ax_dist.grid(axis='y')

fig_dist.tight_layout()
plt.show()

# Print counts of each class
print("Left (0):", left_count)
print("Right (1):", right_count)

With only 40 samples for Left and 140 for Right, our model might favor the Right class, leading to biased predictions. We will augment the Left class with two data augmentation techniques: Gaussian noise addition and time shifting. Hence, for each Left sample, we will generate 2 augmented versions (one via noise and one via time shifting), keeping the original.

In [None]:
# Data augmentation for Left class
import random

def augment_ecg(ecg, noise_level=0.01, shift_range=10, rng=None):
    """
    Augment an ECG signal by adding noise and shifting the signal.

    Parameters:
    - ecg: The original ECG signal (numpy array); shifting is applied along axis 0.
    - noise_level: The standard deviation of Gaussian noise to add.
    - shift_range: The range within which to shift the signal (in samples).
    - rng: Optional numpy.random.Generator. When given, all randomness is drawn
      from it, so passing a seeded generator makes the result reproducible.
      When None, the global `np.random` / `random` state is used.

    Returns:
    - augmented_versions: A list of two arrays with the same shape as `ecg`:
      [noisy copy, circularly shifted copy].

    Note: the shift is drawn from [-shift_range, shift_range] inclusive, so a
    shift of 0 (an unchanged copy of the input) is possible.
    """
    if rng is None:
        noise_term = np.random.normal(0, noise_level, ecg.shape)
        shift = random.randint(-shift_range, shift_range)
    else:
        noise_term = rng.normal(0, noise_level, ecg.shape)
        # Generator.integers excludes the upper bound, hence the +1
        shift = int(rng.integers(-shift_range, shift_range + 1))

    augmented_versions = []

    # 1. Add Gaussian (white) noise to simulate sensor or environment noise
    augmented_versions.append(ecg + noise_term)

    # 2. Time shift (circular roll): samples pushed past one end wrap around to the other
    augmented_versions.append(np.roll(ecg, shift, axis=0))

    return augmented_versions

# Apply Data augmentation to Left-class samples

# augment_ecg draws from the global `random` and `np.random` generators; seed both
# so that re-running the notebook produces the same augmented dataset.
AUGMENTATION_SEED = 42
random.seed(AUGMENTATION_SEED)
np.random.seed(AUGMENTATION_SEED)

# Separate original Left and Right class indices
left_indices = [i for i, label in enumerate(final_chambers_normalized) if label == "Left"]
right_indices = [i for i, label in enumerate(final_chambers_normalized) if label == "Right"]

# Create augmented dataset
X_augmented = []  # ECG signals (originals plus augmented copies)
y_augmented = []  # Matching labels (0 for Left, 1 for Right)

# Add all Right (label 1), unchanged
for idx in right_indices:
    X_augmented.append(preprocessed_ecgs[idx])
    y_augmented.append(1)

# Add original and 2x augmented Left (label 0)
# NOTE(review): augmented copies are derived from the originals; if the data is later
# split into train/test sets, keep each original together with its copies (or split
# before augmenting) to avoid leakage.
for idx in left_indices:
    original = preprocessed_ecgs[idx]
    X_augmented.append(original)
    y_augmented.append(0)
    for aug in augment_ecg(original):
        X_augmented.append(aug)
        y_augmented.append(0)

# Convert lists to numpy arrays for further processing
X_augmented = np.stack(X_augmented)  # Shape: (n_samples, timepoints, leads)
y_augmented = np.array(y_augmented)  # Shape: (n_samples,)

# Report balance and augmented dataset size
print("\nData Augmentation Complete")
print("X_augmented shape:", X_augmented.shape)
print("Left (0):", np.sum(y_augmented == 0))  # Count how many "Left" samples
print("Right (1):", np.sum(y_augmented == 1))  # Count how many "Right" samples

In [None]:
# Visualize a Left-class ECG and its augmentations for lead V2 (index 7)

def plot_ecg_comparison(original, augmented, lead=7, fs=250):
    """
    Overlay one lead of an ECG with the same lead of each augmented copy.

    Parameters:
    - original: 2D array indexed as [time, lead].
    - augmented: iterable of arrays indexed the same way as `original`.
    - lead: column index of the lead to draw (its name is looked up in `ecg_leads`).
    - fs: sampling rate in Hz, used to convert sample positions to seconds.
    """
    time_axis = np.arange(original.shape[0]) / fs

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(time_axis, original[:, lead], label="Original", linewidth=2)
    # Augmented copies are numbered from 1 and drawn dashed
    for number, variant in enumerate(augmented, start=1):
        ax.plot(time_axis, variant[:, lead], label=f"Augmented {number}", linestyle='--')

    ax.set_title(f"ECG Lead {ecg_leads[lead]} — Original vs Augmented")
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Amplitude")
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()

# Take the first Left-class recording, generate a fresh pair of augmentations, and plot them together
left_example_idx = left_indices[0]
original_signal = preprocessed_ecgs[left_example_idx]
augmented_signals = augment_ecg(original_signal)
plot_ecg_comparison(original=original_signal, augmented=augmented_signals)

In [None]:
# Include the final label (Left, Right, OTHER) as a column in our metadata DataFrame
df_meta["normalized_label"] = final_chambers_normalized

# DataFrame.info() prints its report itself and returns None, so it is called
# directly; wrapping it in print() would add a stray "None" line to the output.
df_meta.info()
df_meta.head(10)

In [None]:
# Keep only rows labelled Left or Right; .copy() gives an independent frame
# so later edits to df_clean do not touch df_meta
is_left_or_right = df_meta["normalized_label"].isin(["Left", "Right"])
df_clean = df_meta.loc[is_left_or_right].copy()
df_clean.info()

Note that we have removed the sample labeled as OTHER.

In [None]:
# Distinct class labels remaining after filtering
print(df_clean["normalized_label"].unique())

There are many missing values in different columns, so we must decide how to handle them. Since no columns have more than 40% missing values, we decided to impute them.

In [None]:
# Impute numerical columns with each column's median
numeric_columns = ["Age", "Height", "Weight", "BMI", "CLINICAL_SCORE"]
df_clean[numeric_columns] = df_clean[numeric_columns].fillna(df_clean[numeric_columns].median())

# Impute categorical columns with each column's most frequent value (mode)
categorical_columns = ["Sex", "PVC_transition", "HTA", "DM", "DLP", "Smoker", "COPD", "Sleep_apnea", "OTorigin"]


def _fill_with_mode(column):
    """Fill missing entries with the column's most frequent value.

    A column with no observed values has no mode; it is returned unchanged
    instead of raising on the empty mode() result.
    """
    modes = column.mode()
    if modes.empty:
        return column
    return column.fillna(modes.iloc[0])


df_clean[categorical_columns] = df_clean[categorical_columns].apply(_fill_with_mode)

# NOTE(review): medians/modes are computed over all rows of df_clean; if a train/test
# split is made later, consider fitting these statistics on the training rows only.

# Check remaining missing values
print(df_clean.isnull().sum())