In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import pandas as pd
import numpy as np

#ecg_afib_df=pd.read_csv('/content/drive/MyDrive/MIMIC_IV_Pipeline/data_cohorts/42713_Afib_ECG_adm_cohort.csv')
#ecg_afib_df.head()


#Signals embeddings

In [3]:
!pip -q install wfdb scipy pandas numpy tqdm torch_ecg --progress-bar off

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for sgmllib3k (setup.py) ... [?25l[?25hdone
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.3 which is incompatible.[0m[31m
[0m

In [4]:
import os, numpy as np, pandas as pd, wfdb, torch, torchvision
from tqdm import tqdm
from scipy.signal import resample_poly
from torch import nn
from PIL import Image
import torchvision.transforms as T

In [5]:
WAVE_ROOT = "/content/drive/MyDrive/MIMIC_IV_Pipeline/ecg_signals/ecg_500_subset"     #path that contains signal files'
META_CSV  = "/content/drive/MyDrive/MIMIC_IV_Pipeline/data_cohorts/42713_Afib_ECG_adm_cohort.csv" #metadata for the specific cohort
OUT_CSV   = "/content/drive/MyDrive/MIMIC_IV_Pipeline/data_cohorts/ecg_embeddings_500_sampling.csv"#adding the output path

df = pd.read_csv(META_CSV).copy()

In [6]:
#def record_id_from_waveform_path(wp: str) -> str:
#    return os.path.basename(str(wp).strip("/"))
df["record_id"] = df["waveform_path"].astype(str).str.strip("/").str.split("/").str[-1]

In [7]:
def resolve_abs_base_from_record_id(rid: str):
    base = os.path.join(WAVE_ROOT, rid)
    if os.path.exists(base + ".hea") and os.path.exists(base + ".dat"):
        return base
    if os.path.exists(base + ".hea.gz") and os.path.exists(base + ".dat.gz"):
        return base
    return None

df["abs_base"] = df["record_id"].apply(resolve_abs_base_from_record_id)
print("Found local files for", df["abs_base"].notna().sum(), "of", len(df), "rows")

Found local files for 500 of 176153 rows


In [8]:
df_match = df[df["abs_base"].notna()].copy().reset_index(drop=True)

In [9]:
#take only with those that match with the downloaded signals
df_match = df[df["abs_base"].notna()].copy().reset_index(drop=True)#this abs_base is storing the absolute base path for the local ecg files (both- hea and dat)
if len(df_match) == 0:
    raise RuntimeError("No matching WFDB files were found. Check WAVE_ROOT and filenames.")


In [10]:
def load_wfdb_record(abs_base):
    rec = wfdb.rdrecord(abs_base)            #.hea adds .dat
    sig = rec.p_signal.astype(np.float32)    # (n_samples, n_leads)
    fs  = float(rec.fs)
    return sig, fs

def to_12_leads(sig):
    n, L = sig.shape
    if L == 12: return sig
    out = np.zeros((n, 12), dtype=np.float32)
    out[:, :min(L,12)] = sig[:, :min(L,12)]
    return out

def fix_length_resample(sig, fs, target_fs=500, target_sec=10):
    if abs(fs - target_fs) > 1e-3:
        sig = resample_poly(sig, up=int(target_fs), down=int(fs), axis=0)
    n_target = int(target_fs * target_sec)
    if sig.shape[0] >= n_target:
        s = (sig.shape[0] - n_target) // 2
        sig = sig[s:s+n_target, :]
    else:
        pad = n_target - sig.shape[0]
        left, right = pad // 2, pad - pad // 2
        sig = np.pad(sig, ((left, right), (0, 0)), mode="constant")
    return sig  # (5000, 12)

def ecg_to_image(sig_txL):
    # z-score per lead, clip, min-max to 0..255
    x = (sig_txL - sig_txL.mean(axis=0, keepdims=True)) / (sig_txL.std(axis=0, keepdims=True) + 1e-9)
    x = np.clip(x, -5, 5)
    x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0) #Handle potential NaNs or Infs
    x = (x - x.min()) / (x.max() - x.min() + 1e-9)
    x = (x * 255.0).astype(np.uint8)         # (5000, 12)
    img = x.T                                # (12, 5000) → HxW
    pil = Image.fromarray(img).resize((224, 224), Image.BILINEAR).convert("RGB")
    return pil

#DenseNet feature extraction
device = "cuda" if torch.cuda.is_available() else "cpu"
backbone = torchvision.models.densenet121(weights=torchvision.models.DenseNet121_Weights.IMAGENET1K_V1)
feature_extractor = nn.Sequential(
    backbone.features,
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d((1,1)),
    nn.Flatten(),                            # -> 1024-D
).to(device).eval()

img_tf = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std =[0.229, 0.224, 0.225]),
])

# ------------------ RESNET50 ------------------
resnet = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2
)
resnet_feature = nn.Sequential(
    *(list(resnet.children())[:-1]),   # remove FC → output (2048,1,1)
    nn.Flatten(),                      # → (2048,)
).to(device).eval()

# ---------------- EFFICIENTNET-B0 -------------
effnet = torchvision.models.efficientnet_b0(
    weights=torchvision.models.EfficientNet_B0_Weights.IMAGENET1K_V1
)
effnet_feature = nn.Sequential(
    effnet.features,
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),                      # → (1280,)
).to(device).eval()

# ----------------- CONVNEXT-TINY -------------
convnext = torchvision.models.convnext_tiny(
    weights=torchvision.models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1
)
convnext_feature = nn.Sequential(
    convnext.features,
    nn.AdaptiveAvgPool2d((1,1)),
    nn.Flatten(),                      # → (768,)
).to(device).eval()

# ------------- MAIN EMBEDDING FUNCTION --------
def embed_with_all_models(pil_img):

    x = img_tf(pil_img)[None].to(device)

    out = {}

    with torch.no_grad():
        # DenseNet121
        f_dn = feature_extractor(x).squeeze()
        out["densenet"] = torch.nn.functional.normalize(f_dn, p=2, dim=0).cpu().numpy()

        # ResNet50
        f_rs = resnet_feature(x).squeeze()
        out["resnet50"] = torch.nn.functional.normalize(f_rs, p=2, dim=0).cpu().numpy()

        # EfficientNet-B0
        f_eff = effnet_feature(x).squeeze()
        out["efficientnet"] = torch.nn.functional.normalize(f_eff, p=2, dim=0).cpu().numpy()

        # ConvNeXt-Tiny
        f_cn = convnext_feature(x).squeeze()
        out["convnext"] = torch.nn.functional.normalize(f_cn, p=2, dim=0).cpu().numpy()

    return out


def embed_record(abs_base):
    sig, fs = load_wfdb_record(abs_base)
    sig = to_12_leads(sig)
    sig = fix_length_resample(sig, fs, 500, 10)
    pil = ecg_to_image(sig)

    feats = embed_with_all_models(pil)

    # return dictionary of numpy arrays
    return {
        "densenet":    feats["densenet"].tolist(),
        "resnet50":    feats["resnet50"].tolist(),
        "efficientnet": feats["efficientnet"].tolist(),
        "convnext":    feats["convnext"].tolist(),
    }


#mat->for maching
densenet_emb, resnet_emb, effnet_emb, convnext_emb = [], [], [], []
mat, miss = 0, 0
for _, row in tqdm(df_match.iterrows(), total=len(df_match), desc="ECG → embeddings (4 models)"):
    try:
        vecs = embed_record(row["abs_base"])
        densenet_emb.append(vecs["densenet"])
        resnet_emb.append(vecs["resnet50"])
        effnet_emb.append(vecs["efficientnet"])
        convnext_emb.append(vecs["convnext"])
        mat += 1
    except Exception as e:
        densenet_emb.append(None)
        resnet_emb.append(None)
        effnet_emb.append(None)
        convnext_emb.append(None)
        miss += 1

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth


100%|██████████| 30.8M/30.8M [00:00<00:00, 167MB/s]


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 158MB/s]


Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth


100%|██████████| 20.5M/20.5M [00:00<00:00, 153MB/s]


Downloading: "https://download.pytorch.org/models/convnext_tiny-983f1562.pth" to /root/.cache/torch/hub/checkpoints/convnext_tiny-983f1562.pth


100%|██████████| 109M/109M [00:00<00:00, 175MB/s] 
ECG → embeddings (4 models): 100%|██████████| 500/500 [06:35<00:00,  1.26it/s]


In [11]:
df_match["ecg_emb_densenet"] = densenet_emb
df_match["ecg_emb_resnet50"] = resnet_emb
df_match["ecg_emb_efficientnet"] = effnet_emb
df_match["ecg_emb_convnext"] = convnext_emb

print(f"Embedded {mat}/{len(df_match)}; failed {miss}")

Embedded 500/500; failed 0


In [12]:
#keep only the columns needed for the merge (record_id + embedding)
# removed ecg_map creation and merge as df_match already contains the necessary data.

#Use df_match directly, which now includes the ecg_emb_dl column
out = df_match.copy()

out = out.drop(columns=["abs_base"])

In [13]:
#out = df.merge(df_match[["abs_base","ecg_emb_dl"]], on="abs_base", how="left")
out.to_csv(OUT_CSV, index=False)
print(f"Saved → {OUT_CSV}")

Saved → /content/drive/MyDrive/MIMIC_IV_Pipeline/data_cohorts/ecg_embeddings_500_sampling.csv


In [14]:
out.head()

Unnamed: 0,subject_id,study_id,waveform_path,hadm_id,insurance,race,hospital_expire_flag,record_id,ecg_emb_densenet,ecg_emb_resnet50,ecg_emb_efficientnet,ecg_emb_convnext
0,10070614,45002050,files/p1007/p10070614/s45002050/45002050,21323771,Other,ASIAN - CHINESE,0,45002050,"[9.847339242696762e-06, 0.00019438136951066554...","[0.00018946082855109125, 0.0, 0.08496165275573...","[-0.008540903218090534, -0.010899611748754978,...","[-0.016765514388680458, -0.005612209904938936,..."
1,10073847,47598140,files/p1007/p10073847/s47598140/47598140,22194617,Private,OTHER,1,47598140,"[1.3890318768972065e-05, 0.0003003852616529912...","[0.0, 0.0, 0.017400871962308884, 0.0, 0.0, 0.0...","[0.00966037530452013, -0.01172113511711359, -0...","[-0.00017570760974194854, -0.02962233498692512..."
2,10073847,49452478,files/p1007/p10073847/s49452478/49452478,22194617,Private,OTHER,1,49452478,"[1.3033754839852918e-05, 0.0004190737963654101...","[0.0, 0.0, 0.0071130674332380295, 0.0, 0.0, 0....","[0.04396482929587364, -0.005632673390209675, -...","[-0.01908138208091259, -0.03095528855919838, -..."
3,10082560,41469879,files/p1008/p10082560/s41469879/41469879,23284776,Medicare,BLACK/AFRICAN AMERICAN,1,41469879,"[9.497241990175098e-06, 0.00039537230622954667...","[0.0, 0.0, 0.009498875588178635, 0.0, 0.0, 0.0...","[0.0007222092826850712, -0.012796573340892792,...","[-0.00040991141577251256, -0.01815571449697017..."
4,10104450,45750532,files/p1010/p10104450/s45750532/45750532,23157316,Medicare,WHITE,1,45750532,"[9.847987712419126e-06, 0.00016837257135193795...","[0.0, 0.0, 0.06086995080113411, 0.0, 0.0, 0.0,...","[-0.013817787170410156, -0.012096579186618328,...","[-0.012714711017906666, -0.04421060532331467, ..."
