In [11]:
import torch
import pickle
import numpy as np
import pandas as pd
import skimage.util
import math
import sak

In [12]:
# Load the pretrained SAK segmentation ensemble (data/modelos/model.1 ... model.N_MODELS).
# These checkpoints are whole pickled models (they are called directly later), so
# torch.load must be allowed to unpickle arbitrary objects: weights_only=False is
# required on torch >= 2.6 (where the default became True) and silences the
# FutureWarning printed by earlier 2.x releases. Only load checkpoints you trust:
# unpickling can execute arbitrary code. The `sak` import in the first cell is
# presumably what makes the pickled model classes resolvable — keep it.
# map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
model_dir = "data/modelos"
N_MODELS = 5
models = []
for i in range(N_MODELS):
    net = torch.load(f"{model_dir}/model.{i+1}", map_location="cpu", weights_only=False)
    # Inference only: eval mode changes layers such as dropout / batch norm, if present.
    net.eval()
    models.append(net)
print("Loaded SAK models.")

Loaded SAK models.


  models = [torch.load(f"{model_dir}/model.{i+1}") for i in range(5)]


In [13]:
# Load the preprocessed recordings table produced by an earlier preprocessing step.
# NOTE: pickle.load can execute arbitrary code — only load files you produced/trust.
signals_pkl_path = "processed_data_big_dataset/df_signals_preprocessed.pkl"
with open(signals_pkl_path, "rb") as signals_file:
    df_signals = pickle.load(signals_file)
print(f"Loaded df_signals with shape: {df_signals.shape}")

Loaded df_signals with shape: (29153, 18)


In [14]:
# Load the preprocessed ECG arrays (one entry per recording).
# NOTE(review): presumably aligned row-by-row with df_signals — confirm upstream.
# NOTE: pickle.load can execute arbitrary code — only load files you produced/trust.
ecgs_pkl_path = "processed_data_big_dataset/preprocessed_ecgs.pkl"
with open(ecgs_pkl_path, "rb") as ecgs_file:
    ecg_signals_all = pickle.load(ecgs_file)

In [15]:
def predict_ecg(ecg, fs=250, model=None, window_size=2048, stride=256, threshold_ensemble=0.5,
                thr_dice=0.9, ptg_voting=0.5, batch_size=16):
    """Segment an ECG into 3 wave channels (P, QRS, T) with a sliding-window model ensemble.

    The signal is cut into overlapping windows of ``window_size`` samples, ``stride``
    samples apart. Every model scores every window. A sample counts as a model's
    "vote" when its sigmoid output exceeds ``thr_dice``. It is kept for that window
    when at least ``threshold_ensemble`` (a fraction) of the models voted for it.
    Overlapping window decisions are then averaged per sample and kept when the
    average exceeds ``ptg_voting``.

    Parameters:
        ecg (np.ndarray): Signal of shape (T, leads), (leads, T) or (T,). A 2-D array
            whose first axis is shorter than its second is treated as (leads, T) and
            transposed. Only the first 12 leads are used. The caller's array is not
            modified.
        fs (int): Sampling frequency in Hz. Unused; kept for API compatibility.
        model (sequence of torch.nn.Module): Ensemble members. Each is called as
            ``m({"x": tensor})`` with a float32 tensor of shape
            (batch, leads, window_size) and must return a dict whose ``"sigmoid"``
            entry broadcasts into (batch, 3, window_size). Each model is moved to the
            available device and switched to eval mode, in place.
        window_size (int): Window length in samples.
        stride (int): Hop between consecutive window starts, in samples.
        threshold_ensemble (float): Fraction of models that must vote for a sample.
        thr_dice (float): Sigmoid threshold that turns a model output into a vote.
        ptg_voting (float): Fraction of covering windows needed to keep a sample.
        batch_size (int): Windows per forward pass.

    Returns:
        np.ndarray: Boolean mask of shape (3, T) for the original length T.

    Raises:
        ValueError: If no models are given, the input is not 1-D/2-D, or the signal
            has fewer than 50 samples after orientation.
    """
    if model is None or len(model) == 0:
        raise ValueError("predict_ecg needs at least one model")

    # Orient to [time, leads] BEFORE the length check, so a (leads, T) input
    # is not mistaken for a very short signal.
    ecg = np.copy(ecg)
    if ecg.ndim == 1:
        ecg = ecg[:, np.newaxis]
    elif ecg.ndim == 2 and ecg.shape[0] < ecg.shape[1]:
        ecg = ecg.T
    if ecg.ndim != 2:
        raise ValueError(f"Expected a 1-D or 2-D signal, got shape {ecg.shape}")
    ecg = ecg[:, :12]  # Only first 12 leads

    n_samples = ecg.shape[0]
    if n_samples < 50:  # Skip very short signals
        raise ValueError(f"Signal too short for segmentation: {ecg.shape}")

    # Edge-pad to the smallest length of the form window_size + k * stride that is
    # >= n_samples. This makes the signal at least one window long and makes the
    # last window end exactly at the padded end, so every sample is covered.
    n_hops = math.ceil(max(n_samples - window_size, 0) / stride)
    padded_len = window_size + n_hops * stride
    if padded_len > n_samples:
        ecg = np.pad(ecg, ((0, padded_len - n_samples), (0, 0)), mode="edge")

    # Windows start at 0, stride, 2*stride, ...; shape (n_windows, leads, window_size).
    windows = np.lib.stride_tricks.sliding_window_view(ecg, window_size, axis=0)[::stride]
    n_windows = windows.shape[0]

    # Per-window vote counts for the 3 output channels (P, QRS, T).
    votes = np.zeros((n_windows, 3, window_size), dtype=int)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.no_grad():
        for m in model:
            m = m.to(device)
            m.eval()  # inference only: no dropout / batch-norm statistic updates
            for start in range(0, n_windows, batch_size):
                # np.array copies the read-only strided view into a writable float32 block.
                batch = np.array(windows[start:start + batch_size], dtype=np.float32)
                inputs = {"x": torch.from_numpy(batch).to(device)}
                probs = m(inputs)["sigmoid"].cpu().numpy()
                votes[start:start + batch_size] += probs > thr_dice
    # Keep only regions where enough models agreed.
    window_masks = votes >= len(model) * threshold_ensemble

    # Stitch overlapping window decisions back onto the padded timeline.
    coverage_sum = np.zeros((3, padded_len))
    coverage_count = np.zeros(padded_len)
    for w in range(n_windows):
        start = w * stride
        coverage_sum[:, start:start + window_size] += window_masks[w]
        coverage_count[start:start + window_size] += 1
    # where= guards samples no window covers (only possible if stride > window_size).
    agreement = np.divide(coverage_sum, coverage_count,
                          out=np.zeros_like(coverage_sum), where=coverage_count > 0)
    full_mask = agreement > ptg_voting
    return full_mask[:, :n_samples]  # Trim padding back to the original length

In [16]:
def _mean_segment_duration_ms(channel, fs):
    """Mean onset-to-offset duration (ms) of the contiguous True runs in a 1-D mask.

    Each run's duration is (last_index - first_index) / fs. Returns 0 when the mask
    has no True samples.
    """
    active = np.asarray(channel, dtype=bool).astype(np.int8)
    edges = np.diff(np.concatenate(([0], active, [0])))
    onsets = np.flatnonzero(edges == 1)
    offsets = np.flatnonzero(edges == -1) - 1  # last sample of each run (inclusive)
    if onsets.size == 0:
        return 0
    return float(np.mean(offsets - onsets)) / fs * 1000


def extract_morph_features(signal, mask, fs=250, t_polarity_thr=0.02):
    """
    Extract morphological features from a single-lead ECG and its segmentation mask.

    Parameters:
        signal (np.ndarray): ECG signal of shape (T, 1) or (T,)
        mask (np.ndarray): Segmentation mask of shape (3, T). Per the segmentation
            step the rows are P, QRS, T; only rows 1 (QRS) and 2 (T) are used here.
        fs (int): Sampling frequency in Hz
        t_polarity_thr (float): Dead-band for T-wave polarity, in the signal's
            amplitude units (units not determinable here). The mean T-wave amplitude
            must exceed +thr / fall below -thr to count as positive / negative.

    Returns:
        dict: Morphological features extracted from the lead:
            r_amp / s_amp: global max / min of the whole lead (not restricted to QRS),
            r_s_ratio: r_amp / |s_amp| (0 when s_amp == 0),
            qrs_dur: mean duration in ms of the individual QRS segments (0 if none),
            t_polarity: 1, -1 or 0 from the mean amplitude over T-wave samples
                (0 when there are 3 or fewer T samples).

    Raises:
        ValueError: If the signal is not (T,) / (T, 1), or the mask is not (3, T).
    """
    features = {}

    # Ensure 1D signal
    if signal.ndim == 2 and signal.shape[1] == 1:
        lead_signal = signal[:, 0]
    elif signal.ndim == 1:
        lead_signal = signal
    else:
        raise ValueError(f"Unexpected signal shape: {signal.shape}")

    # Check mask shape (ndim first, so a 1-D mask raises ValueError, not IndexError)
    mask = np.asarray(mask)
    if mask.ndim != 2 or mask.shape[0] != 3 or mask.shape[1] != lead_signal.shape[0]:
        raise ValueError(f"Signal or mask malformed: signal shape {lead_signal.shape}, mask shape {mask.shape}")

    # R and S amplitudes (global extremes of the lead)
    r_peak = np.max(lead_signal)
    s_trough = np.min(lead_signal)
    r_s_ratio = r_peak / abs(s_trough) if s_trough != 0 else 0

    # QRS duration (ms): averaged per complex. Measuring from the first to the last
    # QRS sample of the whole recording would span every beat in between.
    qrs_dur = _mean_segment_duration_ms(mask[1], fs)

    # T wave polarity
    t_indices = np.where(mask[2])[0]
    if len(t_indices) > 3:
        t_mean = np.mean(lead_signal[t_indices])
        polarity = 1 if t_mean > t_polarity_thr else (-1 if t_mean < -t_polarity_thr else 0)
    else:
        polarity = 0

    features["r_amp"] = r_peak
    features["s_amp"] = s_trough
    features["r_s_ratio"] = r_s_ratio
    features["qrs_dur"] = qrs_dur
    features["t_polarity"] = polarity

    return features

In [17]:
# ECG Feature Extraction After Averaging Per Patient
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# ecg_signals_all[k] corresponds to df_signals.iloc[k] (POSITIONAL alignment).
# Enumerate positions rather than using the labels yielded by iterrows(). Labels
# differ from positions whenever df_signals' index is not a default 0..n-1
# RangeIndex, and iterrows() can also upcast PatientID's dtype.
assert len(ecg_signals_all) == len(df_signals), (
    f"ecg_signals_all has {len(ecg_signals_all)} entries but df_signals has {len(df_signals)} rows"
)

# Group ECGs by PatientID (patients kept in order of first appearance)
patient_ecgs = {}
for pos, pid in enumerate(df_signals["PatientID"].tolist()):
    patient_ecgs.setdefault(pid, []).append(ecg_signals_all[pos])

# Average ECGs per patient (with shape validation)
# NOTE(review): this is a sample-by-sample mean of separate recordings. Unless the
# recordings are already time/beat-aligned upstream (not visible here), their beats
# sit at different sample offsets and the mean smears or cancels the waveforms,
# which distorts every morphological feature extracted below. Consider averaging
# beat-aligned templates instead.
patient_avg_ecgs = {}
for pid, ecgs in tqdm(patient_ecgs.items(), desc="Averaging ECGs per patient"):
    try:
        shapes = {ecg.shape for ecg in ecgs}
        if len(shapes) > 1:
            print(f"Skipping patient {pid} due to mismatched ECG shapes: {shapes}")
            continue
        patient_avg_ecgs[pid] = np.mean(np.stack(ecgs, axis=0), axis=0)
    except Exception as e:
        # Best-effort: one malformed patient must not abort the whole run.
        print(f"Skipping patient {pid} due to error: {e}")

# Visualize average vs original ECGs for one patient
def plot_patient_avg_vs_originals(pid, lead_index=6):
    """Overlay a patient's individual recordings and their sample-wise average for one lead.

    Reads the notebook-level dicts ``patient_ecgs`` (pid -> list of (T, leads) arrays)
    and ``patient_avg_ecgs`` (pid -> (T, leads) array). Prints a message and draws
    nothing when the patient is missing from either dict.

    Parameters:
        pid: Patient identifier (a key of the dicts above).
        lead_index (int): Column of the (T, leads) arrays to plot. With the lead order
            used for feature extraction in this notebook, index 6 is V1.
    """
    if pid not in patient_ecgs or pid not in patient_avg_ecgs:
        print(f"Patient {pid} not found or not averaged.")
        return
    fig, ax = plt.subplots(figsize=(12, 6))
    for i, ecg in enumerate(patient_ecgs[pid]):
        ax.plot(ecg[:, lead_index], alpha=0.3, label=f"ECG {i+1}")
    ax.plot(patient_avg_ecgs[pid][:, lead_index], color="black", linewidth=2, label="Averaged")
    ax.set(title=f"Patient {pid} - Lead index {lead_index}",
           xlabel="Sample index", ylabel="Amplitude")
    ax.legend()
    ax.grid(True)
    plt.show()

# Feature extraction per patient average ECG: segment every lead with the SAK
# ensemble, then flatten per-lead morphology into "<lead>_<feature>" columns.
LEAD_NAMES = ["I", "II", "III", "AVR", "AVL", "AVF", "V1", "V2", "V3", "V4", "V5", "V6"]


def _lead_features(signal):
    """Build the flattened feature row for one averaged ECG of shape (T, >=12).

    Raises on the first malformed lead or failed segmentation; the caller then
    skips the whole patient.
    """
    row = {}
    for j, lead_name in enumerate(LEAD_NAMES):
        lead = signal[:, j]
        if not (lead.ndim == 1 and lead.shape[0] >= 50):
            raise ValueError(f"Invalid lead shape: {lead.shape}")
        lead_2d = lead[:, np.newaxis]
        lead_mask = predict_ecg(lead_2d, model=models)
        if lead_mask.ndim == 3:
            lead_mask = lead_mask.squeeze(-1)
        lead_feats = extract_morph_features(lead_2d, lead_mask)
        row.update({f"{lead_name}_{key}": value for key, value in lead_feats.items()})
    return row


features_list = []
for pid, signal in tqdm(patient_avg_ecgs.items(), desc="Extracting features"):
    try:
        feats_all_leads = _lead_features(signal)
    except Exception as e:
        print(f"Skipping patient {pid} due to error: {e}")
        continue
    feats_all_leads["PatientID"] = pid
    features_list.append(feats_all_leads)

# Save results: one row per successfully processed patient.
df_feats = pd.DataFrame(features_list)
df_feats.to_csv("features_per_patient_avg.csv", index=False)
print("Saved features_per_patient_avg.csv with shape:", df_feats.shape)

Averaging ECGs per patient: 100%|██████████| 181/181 [00:00<00:00, 282.43it/s]
Extracting features: 100%|██████████| 181/181 [05:27<00:00,  1.81s/it]

Saved features_per_patient_avg.csv with shape: (181, 61)



