In [41]:
# Mount Google Drive in this Colab runtime; the input/output paths used by the
# later cells all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [42]:
import pandas as pd
import numpy as np

#ecg_afib_df=pd.read_csv('/content/drive/MyDrive/MIMIC_IV_Pipeline/data_cohorts/42713_Afib_ECG_adm_cohort.csv')
#ecg_afib_df.head()


# Signal embeddings

In [43]:
!pip -q install wfdb scipy pandas numpy tqdm torch_ecg --progress-bar off

In [44]:
import os, numpy as np, pandas as pd, wfdb, torch, torchvision
from tqdm import tqdm
from scipy.signal import resample_poly
from torch import nn
from PIL import Image
import torchvision.transforms as T

In [45]:
WAVE_ROOT = "/content/drive/MyDrive/MIMIC_IV_Pipeline/ecg_signals/ecg_500_subset"     # directory that contains the WFDB signal files
META_CSV  = "/content/drive/MyDrive/MIMIC_IV_Pipeline/data_cohorts/42713_Afib_ECG_adm_cohort.csv" # metadata for the specific cohort
OUT_CSV   = "/content/drive/MyDrive/MIMIC_IV_Pipeline/data_cohorts/ecg_embeddings_500_sampling.csv"# destination CSV for the computed embeddings

# Cohort metadata table; later cells read its "waveform_path" column.
df = pd.read_csv(META_CSV).copy()

In [46]:
# The record id is the final path component of "waveform_path" once any
# leading/trailing "/" has been stripped.
df["record_id"] = (
    df["waveform_path"]
    .astype(str)
    .str.strip("/")
    .str.rsplit("/", n=1)
    .str[-1]
)

In [47]:
def resolve_abs_base_from_record_id(rid: str):
    """Return the extension-less path of record ``rid`` under ``WAVE_ROOT``, or ``None``.

    A record counts as present only when its header and data files exist as a
    matching pair: ``.hea`` + ``.dat``, or ``.hea.gz`` + ``.dat.gz``.
    """
    base = os.path.join(WAVE_ROOT, rid)
    # Plain files are checked first, then the gzip-suffixed pair.
    # NOTE(review): confirm the WFDB reader used later can open the .gz pair.
    for suffix in ("", ".gz"):
        if all(os.path.exists(f"{base}{ext}{suffix}") for ext in (".hea", ".dat")):
            return base
    return None

# Resolve every record id to its local base path (None when the files are absent),
# then report how many cohort rows have signals on disk.
df["abs_base"] = [resolve_abs_base_from_record_id(rid) for rid in df["record_id"]]
print("Found local files for", df["abs_base"].count(), "of", len(df), "rows")

Found local files for 500 of 176153 rows


In [48]:
# Cohort rows whose signal files were found locally (non-null abs_base), re-indexed from 0.
# NOTE(review): the next cell recomputes this same frame, so one of the two is redundant.
df_match = df[df["abs_base"].notna()].copy().reset_index(drop=True)

In [49]:
# Restrict the cohort to rows whose signal files were located on disk.
# `abs_base` is the absolute, extension-less path shared by a record's
# header (.hea) and data (.dat) files.
df_match = df.loc[df["abs_base"].notna()].copy().reset_index(drop=True)
if df_match.empty:
    raise RuntimeError("No matching WFDB files were found. Check WAVE_ROOT and filenames.")


In [50]:
def load_wfdb_record(abs_base):
    """Read one WFDB record and return ``(signal, sampling_rate)``.

    ``abs_base`` is the record path without extension. The signal is the
    record's ``p_signal`` cast to float32, indexed as (sample, lead); the
    sampling rate is the record's ``fs`` as a Python float.
    """
    record = wfdb.rdrecord(abs_base)
    signal = record.p_signal.astype(np.float32)
    sampling_rate = float(record.fs)
    return signal, sampling_rate

def to_12_leads(sig):
    """Force a (n_samples, n_leads) array to exactly 12 lead columns.

    A 12-column input is returned unchanged (same object). Otherwise a new
    float32 array is built: the first ``min(n_leads, 12)`` columns are copied
    and any remaining columns stay zero, so extra leads are dropped and
    missing leads are zero-filled.
    """
    n_samples, n_leads = sig.shape
    if n_leads == 12:
        return sig

    keep = min(n_leads, 12)
    padded = np.zeros((n_samples, 12), dtype=np.float32)
    padded[:, :keep] = sig[:, :keep]
    return padded

def fix_length_resample(sig, fs, target_fs=500, target_sec=10):
    """Resample ``sig`` to ``target_fs`` and centre-crop / zero-pad it to a fixed duration.

    Args:
        sig: 2-D array indexed as (sample, lead); resampling, cropping and
            padding all act on axis 0.
        fs: Sampling rate of ``sig`` in Hz.
        target_fs: Output sampling rate in Hz.
        target_sec: Output duration in seconds.

    Returns:
        Array with ``round(target_fs * target_sec)`` rows and the same number
        of columns as ``sig``.

    Raises:
        ValueError: If ``fs`` or ``target_fs`` is not positive.
    """
    from fractions import Fraction

    if fs <= 0 or target_fs <= 0:
        raise ValueError(
            f"Sampling rates must be positive, got fs={fs}, target_fs={target_fs}"
        )

    if abs(fs - target_fs) > 1e-3:
        # Use the exact rational ratio rather than int(fs)/int(target_fs):
        # truncating a fractional rate (e.g. 250.5 -> 250, 0.5 -> 0) gives a wrong
        # or invalid resampling factor. Integer rates reduce to the same up/down
        # pair resample_poly would derive itself, so their output is unchanged.
        ratio = Fraction(target_fs / fs).limit_denominator(1000)
        sig = resample_poly(sig, up=ratio.numerator, down=ratio.denominator, axis=0)

    # round() guards against float products such as 0.29 * 100 == 28.999...
    n_target = int(round(target_fs * target_sec))
    n_have = sig.shape[0]

    if n_have >= n_target:
        # Too long: keep the central n_target samples.
        start = (n_have - n_target) // 2
        sig = sig[start:start + n_target, :]
    else:
        # Too short: zero-pad both ends, the extra sample (if any) on the right.
        pad = n_target - n_have
        left = pad // 2
        right = pad - left
        sig = np.pad(sig, ((left, right), (0, 0)), mode="constant")
    return sig

def ecg_to_image(sig_txL):
    """Render a (n_samples, n_leads) ECG array as a 224x224 RGB PIL image.

    Steps: z-score each lead, clip to [-5, 5], zero any remaining NaN/Inf,
    min-max scale the whole array to 0..255 (uint8), transpose so leads are
    image rows and samples are image columns, then resize bilinearly and
    convert to RGB.

    Non-finite samples are excluded from each lead's mean/std. Previously a
    single NaN/Inf sample made the lead's statistics NaN, so the entire lead
    was zeroed by ``nan_to_num`` and rendered as a flat line. Fully finite
    input takes the original code path and produces the same image as before.
    """
    finite = np.isfinite(sig_txL)
    if finite.all():
        sig = sig_txL
        mu = sig.mean(axis=0, keepdims=True)
        sd = sig.std(axis=0, keepdims=True)
    else:
        # Mask bad samples as NaN so the NaN-aware statistics ignore them. An
        # all-bad lead still yields NaN stats (NumPy emits a RuntimeWarning) and
        # is zeroed by nan_to_num below.
        sig = np.where(finite, sig_txL, np.nan)
        mu = np.nanmean(sig, axis=0, keepdims=True)
        sd = np.nanstd(sig, axis=0, keepdims=True)

    # z-score per lead; the epsilon avoids division by zero on constant leads
    x = (sig - mu) / (sd + 1e-9)
    x = np.clip(x, -5, 5)
    x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)  # masked samples -> 0 (the lead mean)

    # global min-max to 0..255
    x = (x - x.min()) / (x.max() - x.min() + 1e-9)
    x = (x * 255.0).astype(np.uint8)

    img = x.T  # (n_leads, n_samples): rows = leads, columns = time
    pil = Image.fromarray(img).resize((224, 224), Image.BILINEAR).convert("RGB")
    return pil

# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# DenseNet-121 (ImageNet weights) as a feature extractor: convolutional trunk,
# ReLU, global average pooling, flatten -> 1024-D vector.
backbone = torchvision.models.densenet121(
    weights=torchvision.models.DenseNet121_Weights.IMAGENET1K_V1
)
feature_extractor = nn.Sequential(
    backbone.features,
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
)
feature_extractor = feature_extractor.to(device)
feature_extractor.eval()

# PIL image -> float tensor, normalised with the mean/std values below.
img_tf = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ResNet-50: every child module except the final FC layer, then flatten -> (2048,)
resnet = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2
)
resnet_feature = nn.Sequential(*list(resnet.children())[:-1], nn.Flatten())
resnet_feature = resnet_feature.to(device)
resnet_feature.eval()

# EfficientNet-B0: feature trunk, global average pool, flatten -> (1280,)
effnet = torchvision.models.efficientnet_b0(
    weights=torchvision.models.EfficientNet_B0_Weights.IMAGENET1K_V1
)
effnet_feature = nn.Sequential(
    effnet.features,
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
)
effnet_feature = effnet_feature.to(device)
effnet_feature.eval()

# ConvNeXt-Tiny: feature trunk, global average pool, flatten -> (768,)
convnext = torchvision.models.convnext_tiny(
    weights=torchvision.models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1
)
convnext_feature = nn.Sequential(
    convnext.features,
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
)
convnext_feature = convnext_feature.to(device)
convnext_feature.eval()

from transformers import AutoImageProcessor, AutoModel
import torch
import timm

# ================================
# SOTA MODELS
# ================================
# `device` is already set at the top of this cell, so it is not recomputed here.
print("Using device:", device)

# DINOv2 ViT-Small. The processor and model are loaded exactly once; the
# previous version of this cell instantiated both twice in a row.
dino_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-small")
dino_model = AutoModel.from_pretrained("facebook/dinov2-small").to(device).eval()

# The timm models use num_classes=0, so calling them returns a feature vector
# instead of class logits.

# Swin Transformer Small (768-D)
swin = timm.create_model(
    "swin_small_patch4_window7_224",
    pretrained=True,
    num_classes=0,
).to(device).eval()

# ConvNeXt Base (1024-D)
convnext_base = timm.create_model(
    "convnext_base",
    pretrained=True,
    num_classes=0,
).to(device).eval()

# EfficientNet-V2 Small (1280-D)
# NOTE(review): the captured cell output shows a warning raised from timm's
# create_model; presumably this model name is deprecated in the installed timm.
# Confirm and switch to the current name if so.
effnetv2s = timm.create_model(
    "tf_efficientnetv2_s_in21k",
    pretrained=True,
    num_classes=0,
).to(device).eval()


def embed_with_all_models(pil_img):
    """Embed one PIL image with all eight backbones.

    Returns a dict mapping model key -> L2-normalised 1-D NumPy vector, with
    keys inserted in this order: ``densenet``, ``resnet50``,
    ``efficientnet_b0``, ``convnext_tiny``, ``dinov2``, ``swin_small``,
    ``convnext_base``, ``effnetv2_s``.

    Every model except DINOv2 receives the ``img_tf``-transformed tensor.
    DINOv2 gets the PIL image through its own processor, and its embedding is
    the first token of ``last_hidden_state``.
    """

    def _unit_vector(features):
        # Drop singleton dims, L2-normalise, and move to host memory as NumPy.
        flat = features.squeeze()
        return torch.nn.functional.normalize(flat, p=2, dim=0).cpu().numpy()

    # (key, model) pairs evaluated before and after DINOv2, preserving key order.
    before_dino = (
        ("densenet", feature_extractor),
        ("resnet50", resnet_feature),
        ("efficientnet_b0", effnet_feature),
        ("convnext_tiny", convnext_feature),
    )
    after_dino = (
        ("swin_small", swin),
        ("convnext_base", convnext_base),
        ("effnetv2_s", effnetv2s),
    )

    batch = img_tf(pil_img).unsqueeze(0).to(device)  # add a batch dimension of 1
    embeddings = {}

    with torch.no_grad():
        for key, model in before_dino:
            embeddings[key] = _unit_vector(model(batch))

        dino_inputs = dino_processor(images=pil_img, return_tensors="pt").to(device)
        first_token = dino_model(**dino_inputs).last_hidden_state[:, 0, :]
        embeddings["dinov2"] = _unit_vector(first_token)

        for key, model in after_dino:
            embeddings[key] = _unit_vector(model(batch))

    return embeddings


def embed_record(abs_base):
    """Load one WFDB record and return its per-model embedding dict.

    Pipeline: read the signal, force 12 leads, resample/crop/pad to 500 Hz x
    10 s, render as an image, and embed that image with every model.
    """
    raw_signal, sampling_rate = load_wfdb_record(abs_base)
    fixed_signal = fix_length_resample(to_12_leads(raw_signal), sampling_rate, 500, 10)
    return embed_with_all_models(ecg_to_image(fixed_signal))


# One list per model, filled in df_match row order. A record that fails
# contributes None to every list, so all lists stay aligned with df_match.
densenet_emb = []
resnet_emb = []
effb0_emb = []
convnext_tiny_emb = []
dinov2_emb = []
swin_emb = []
convnext_base_emb = []
effv2s_emb = []

# Key returned by embed_record -> list that collects that model's vectors.
emb_lists = {
    "densenet": densenet_emb,
    "resnet50": resnet_emb,
    "efficientnet_b0": effb0_emb,
    "convnext_tiny": convnext_tiny_emb,
    "dinov2": dinov2_emb,
    "swin_small": swin_emb,
    "convnext_base": convnext_base_emb,
    "effnetv2_s": effv2s_emb,
}

mat, miss = 0, 0        # records embedded / records that failed
embed_failures = []     # (abs_base, repr(exception)) for every failed record

for abs_base in tqdm(df_match["abs_base"],
                     total=len(df_match),
                     desc="ECG → embeddings (8 models)"):
    try:
        vecs = embed_record(abs_base)
        mat += 1
    except Exception as e:
        # Best effort: keep going, but report and record why the record failed
        # instead of discarding the error.
        vecs = None
        miss += 1
        embed_failures.append((abs_base, repr(e)))
        tqdm.write(f"Embedding failed for {abs_base}: {e!r}")

    # Appending outside the try block guarantees exactly one entry per list per
    # record, even if something goes wrong part-way through.
    for key, bucket in emb_lists.items():
        bucket.append(None if vecs is None else vecs[key])


Using device: cpu


  model = create_fn(
ECG → embeddings (8 models): 100%|██████████| 500/500 [21:56<00:00,  2.63s/it]


In [51]:
# Attach each model's embedding list to df_match as its own column; the lists
# were filled in df_match row order.
for column, vectors in (
    ("emb_densenet", densenet_emb),
    ("emb_resnet50", resnet_emb),
    ("emb_efficientnet_b0", effb0_emb),
    ("emb_convnext_tiny", convnext_tiny_emb),
    ("emb_dinov2", dinov2_emb),
    ("emb_swin_small", swin_emb),
    ("emb_convnext_base", convnext_base_emb),
    ("emb_effnetv2_s", effv2s_emb),
):
    df_match[column] = vectors

print(f"Embedded {mat}/{len(df_match)}; failed {miss}")

Embedded 500/500; failed 0


In [52]:
# Output table: everything in df_match (cohort metadata, record_id and the
# emb_* columns) except the machine-specific local path column.
out = df_match.drop(columns=["abs_base"])

In [60]:
import json

# Writing ndarray cells straight to CSV stores str(array), which NumPy
# abbreviates with "..." for arrays longer than its print threshold (1000
# elements by default). The 1024/1280/2048-D embeddings would therefore be saved
# truncated and could not be recovered. Serialise every embedding as a JSON list
# instead; read it back with json.loads. `out` itself is left untouched so it
# still holds the arrays.
emb_cols = [c for c in out.columns if str(c).startswith("emb_")]
out_csv_frame = out.copy()
for col in emb_cols:
    out_csv_frame[col] = out_csv_frame[col].map(
        # Failed records hold None (not an ndarray) and are passed through as-is.
        lambda vec: json.dumps(vec.tolist()) if isinstance(vec, np.ndarray) else vec
    )

out_csv_frame.to_csv(OUT_CSV, index=False)
print(f"Saved → {OUT_CSV}")

Saved → /content/drive/MyDrive/MIMIC_IV_Pipeline/data_cohorts/ecg_embeddings_500_sampling.csv


In [61]:
# Preview the first rows of the output table (cohort metadata plus one emb_* column per model).
out.head()

Unnamed: 0,subject_id,study_id,waveform_path,hadm_id,insurance,race,hospital_expire_flag,record_id,emb_densenet,emb_resnet50,emb_efficientnet_b0,emb_convnext_tiny,emb_dinov2,emb_swin_small,emb_convnext_base,emb_effnetv2_s
0,10070614,45002050,files/p1007/p10070614/s45002050/45002050,21323771,Other,ASIAN - CHINESE,0,45002050,"[9.847338e-06, 0.00019438121, 1.414723e-05, 3....","[0.00018946252, 0.0, 0.084961526, 0.0, 0.0, 0....","[-0.008540902, -0.0108996015, -0.013241604, -0...","[-0.016765451, -0.0056122276, 0.018160976, -0....","[-0.08137009, 0.04834323, 0.03149641, -0.13997...","[0.017277956, 0.051369645, 0.0060650217, 0.050...","[-0.03200471, 0.0049697896, 0.037340123, 0.027...","[0.019792778, -0.00013387676, -0.0010038844, -..."
1,10073847,47598140,files/p1007/p10073847/s47598140/47598140,22194617,Private,OTHER,1,47598140,"[1.3890333e-05, 0.00030038515, 9.253712e-05, 4...","[0.0, 0.0, 0.01740081, 0.0, 0.0, 0.0, 0.000697...","[0.009660365, -0.011721137, -0.011255822, -0.0...","[-0.00017560489, -0.029622344, -0.0059531983, ...","[-0.09472636, 0.042605676, 0.033398956, -0.109...","[-0.0476419, 0.024252992, -0.02011218, 0.03354...","[-0.03992805, -0.009766443, 0.037131038, 0.025...","[0.0010481813, -0.0002127511, -0.001067375, -0..."
2,10073847,49452478,files/p1007/p10073847/s49452478/49452478,22194617,Private,OTHER,1,49452478,"[1.3033758e-05, 0.00041907365, 8.77906e-05, 6....","[0.0, 0.0, 0.0071130954, 0.0, 0.0, 0.0, 0.0009...","[0.04396475, -0.0056326697, -0.014598539, -0.0...","[-0.019081289, -0.030955313, -0.024518082, -0....","[-0.071752936, 0.04642795, 0.019282196, -0.084...","[-0.045597203, -0.001738442, -0.04495257, 0.00...","[-0.054610442, -0.0014507282, 0.03595155, 0.03...","[-0.0042709997, -0.00018842636, -0.0010085101,..."
3,10082560,41469879,files/p1008/p10082560/s41469879/41469879,23284776,Medicare,BLACK/AFRICAN AMERICAN,1,41469879,"[9.497244e-06, 0.0003953722, 4.765285e-05, 3.3...","[0.0, 0.0, 0.009498871, 0.0, 0.0, 0.0, 0.0, 0....","[0.0007222163, -0.012796574, -0.009576918, -0....","[-0.00040980947, -0.018155733, -0.011329171, -...","[-0.016805433, 0.061205372, 0.05951823, -0.090...","[-0.06258996, 0.009017544, -0.032825932, -0.00...","[-0.050609592, -0.012924857, 0.02932106, 0.031...","[-0.004848201, -0.00019937099, -0.00071365223,..."
4,10104450,45750532,files/p1010/p10104450/s45750532/45750532,23157316,Medicare,WHITE,1,45750532,"[9.847988e-06, 0.00016837237, 7.4051172e-06, 2...","[0.0, 0.0, 0.060870048, 0.0, 0.0, 0.0, 0.0, 0....","[-0.0138177825, -0.012096585, -0.009613659, -0...","[-0.0127146095, -0.044210553, -0.012751232, -0...","[-0.076176524, 0.02310867, 0.046015885, -0.092...","[-0.070743285, -0.006650425, 0.019672781, 0.01...","[-0.043136716, 0.01594901, 0.04405236, 0.03980...","[0.053601474, -6.9304144e-05, -0.00069896306, ..."
