In [2]:
#Cell 0 — Install + Mount
# Install deps (run once per runtime)
!pip -q install wfdb==4.1.2 numpy pandas scipy matplotlib pyarrow

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
# ===== EDIT THESE FILESYSTEM PATHS (not https) =====
WAVEFORM_ROOT = "/content/drive/MyDrive/PhysioNet_Data/mimic3wdb_data"  # e.g., .../mimic3wdb_data or .../mimic3wdb_data/30
OUTPUT_DIR    = "/content/drive/MyDrive/PhysioNet_Data"
# ===================================================

import os, re, pandas as pd, wfdb

os.makedirs(OUTPUT_DIR, exist_ok=True)
print("Root exists?", os.path.exists(WAVEFORM_ROOT))

def is_suffixed_header(fname: str) -> bool:
    # Keep only headers like 3000126_0014.hea; skip 3000126.hea and *_layout.hea
    return fname.endswith(".hea") and re.search(r"_\d+\.hea$", fname) and not fname.endswith("_layout.hea")

def parse_subject_id_from_path(path: str):
    # Prefer the subject folder name; else take leading digits of filename
    folder = os.path.basename(os.path.dirname(path))
    if folder.isdigit(): return int(folder)
    m = re.match(r"(\d+)", os.path.basename(path))
    return int(m.group(1)) if m else None

def guess_signal_from_header(base_no_ext: str) -> str:
    try:
        h = wfdb.rdheader(base_no_ext)
        names = [n.lower() for n in (h.sig_name or [])]
        if any(k in n for n in names for k in ("pleth","ppg","spo2","oxim","ppleth","pulse")): return "PPG"
        if any(("ecg" in n) or n.startswith(("ii","v1","v2","v3","v4","v5","v6")) for n in names): return "ECG"
    except Exception:
        pass
    b = os.path.basename(base_no_ext).lower()
    if any(k in b for k in ("pleth","ppg","spo2","oxim")): return "PPG"
    return "ECG"

rows, cnt = [], 0
for root, _, files in os.walk(WAVEFORM_ROOT):
    for f in files:
        if not is_suffixed_header(f):
            continue
        cnt += 1
        base_no_ext = os.path.join(root, f[:-4])  # drop .hea
        rows.append({
            "record_path": base_no_ext,
            "subject_id": parse_subject_id_from_path(base_no_ext),
            "signal": guess_signal_from_header(base_no_ext),
            "subject_dir": os.path.dirname(base_no_ext)  # for stitch mode
        })

print("Suffixed headers found:", cnt)
manifest = pd.DataFrame(rows).drop_duplicates().sort_values(["subject_id","record_path"])
MANIFEST_CSV = os.path.join(OUTPUT_DIR, "waveform_manifest.csv")
manifest.to_csv(MANIFEST_CSV, index=False)
print("Saved manifest ➜", MANIFEST_CSV)
manifest.head(10)


Root exists? True
Suffixed headers found: 371
Saved manifest ➜ /content/drive/MyDrive/PhysioNet_Data/waveform_manifest.csv


Unnamed: 0,record_path,subject_id,signal,subject_dir
0,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...,3000003,ECG,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...
1,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...,3000003,ECG,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...
2,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...,3000003,ECG,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...
3,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...,3000003,ECG,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...
4,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...,3000003,ECG,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...
5,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...,3000003,ECG,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...
6,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...,3000003,ECG,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...
7,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...,3000003,ECG,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...
8,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...,3000003,ECG,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...
9,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...,3000003,ECG,/content/drive/MyDrive/PhysioNet_Data/mimic3wd...


In [6]:
import os, math, numpy as np, pandas as pd, wfdb
from scipy.signal import butter, filtfilt, resample

# ===== EDIT THESE PATHS =====
OUTPUT_DIR    = "/content/drive/MyDrive/PhysioNet_Data"
MANIFEST_CSV  = "/content/drive/MyDrive/PhysioNet_Data/waveform_manifest.csv"

# If you have clinical CSVs, put them here; else leave as-is & proceed (HF=-1)
CLINICAL_DIR  = "/content/drive/MyDrive/PhysioNet_Data/clinical"
PATIENTS_CSV  = f"{CLINICAL_DIR}/PATIENTS.csv"
ADMISSIONS_CSV= f"{CLINICAL_DIR}/ADMISSIONS.csv"
DIAG_ICD_CSV  = f"{CLINICAL_DIR}/DIAGNOSES_ICD.csv"
# ============================

# ===== Processing params (balanced for short records) =====
BANDS = {"ECG": (0.5, 40.0), "PPG": (0.3, 10.0)}
FILTER_ORDER_PRIMARY = 2
FILTER_ORDER_FALLBACK = 1
FS_TARGET = 125              # unify sampling for stitching & windowing
WINDOW_SEC = 5               # 3–5s works well for many short clips
OVERLAP = 0.5
NORMALIZE_PER_WINDOW = True
MIN_DURATION_SEC = 3.0       # ignore super tiny files before stitching
MAX_SECONDS_PER_SUBJECT = 180  # cap concat length per subject/signal (optional)
# ==========================================================

os.makedirs(OUTPUT_DIR, exist_ok=True)

# ---- Clinical: optional ----
def try_load_clinical(patients_csv, admissions_csv, diag_csv):
    try:
        pts = pd.read_csv(patients_csv); pts.columns = [c.lower() for c in pts.columns]
        adm = pd.read_csv(admissions_csv); adm.columns = [c.lower() for c in adm.columns]
        dg  = pd.read_csv(diag_csv);      dg.columns  = [c.lower() for c in dg.columns]
        return pts, adm, dg
    except Exception as e:
        print("[INFO] Clinical CSVs not found or unreadable -> proceeding without labels.", e)
        return None, None, None

def build_demo(pts, adm, dg):
    if pts is None:
        return pd.DataFrame(columns=["subject_id","gender","age","ethnicity","HF"])
    # dates → age at first admit
    for df in (pts, adm):
        for col in ("dob","admittime"):
            if col in df.columns:
                df[col] = pd.to_datetime(df[col], errors="coerce")
    fa  = adm.sort_values("admittime").groupby("subject_id").first().reset_index()
    age = fa.merge(pts[["subject_id","dob"]], on="subject_id", how="left")
    age["age"] = ((age["admittime"] - age["dob"]).dt.days/365.25).clip(lower=0)
    age = age[["subject_id","age"]]
    eth = adm.groupby(["subject_id","ethnicity"]).size().reset_index(name="n")
    idx = eth.groupby("subject_id")["n"].idxmax()
    eth = eth.loc[idx, ["subject_id","ethnicity"]]
    dg = dg.copy(); dg["icd9_code"] = dg["icd9_code"].astype(str)
    dg["is_hf"] = dg["icd9_code"].str.startswith("428").astype(int)
    hf = dg.groupby("subject_id")["is_hf"].max().rename("HF").reset_index()
    demo = (pts[["subject_id","gender"]]
            .merge(age, on="subject_id", how="left")
            .merge(eth, on="subject_id", how="left")
            .merge(hf, on="subject_id", how="left"))
    demo["HF"] = demo["HF"].fillna(0).astype(int)
    return demo

pts, adm, dg = try_load_clinical(PATIENTS_CSV, ADMISSIONS_CSV, DIAG_ICD_CSV)
demo = build_demo(pts, adm, dg)

# ---- helpers ----
def safe_bandpass(x, fs, low, high, order_primary=2, order_fallback=1):
    def _try(x, fs, low, high, order):
        nyq = 0.5*fs
        b, a = butter(order, [low/nyq, high/nyq], btype="band")
        padlen = 3 * (max(len(a), len(b)) - 1)
        if len(x) <= padlen + 1: return None
        return filtfilt(b, a, x)
    y = _try(x, fs, low, high, order_primary)
    if y is not None: return y
    y = _try(x, fs, low, high, order_fallback)
    if y is not None: return y
    return x  # last resort: unfiltered

def resample_to(x, fs, fs_target):
    if fs_target is None or abs(fs - fs_target) < 1e-6:
        return x.astype(np.float32), float(fs)
    n_out = int(round(len(x) * (fs_target/fs)))
    if n_out < 2: return x.astype(np.float32), float(fs)
    return resample(x, n_out).astype(np.float32), float(fs_target)

def segment(x, fs, win_sec, overlap):
    win  = int(win_sec*fs)
    step = int(win*(1-overlap)) or win
    segs, starts = [], []
    for st in range(0, max(len(x)-win+1,0), step):
        s = x[st:st+win]
        if len(s)==win:
            segs.append(s); starts.append(st)
    return (np.stack(segs), starts) if segs else (np.empty((0,0)), [])

def zscore(arr, axis=None, eps=1e-8):
    m  = np.mean(arr, axis=axis, keepdims=True)
    sd = np.std(arr, axis=axis, keepdims=True)
    return (arr - m) / (sd + eps)

def load_header(base):
    h  = wfdb.rdheader(base)
    fs = float(h.fs)
    ln = int(h.sig_len)
    ch = h.sig_name[0] if (hasattr(h,"sig_name") and h.sig_name) else "ch0"
    return fs, ln, ch

def load_record(base):
    rec = wfdb.rdrecord(base)
    fs  = float(rec.fs)
    dat = rec.p_signal
    x   = dat[:,0] if dat.ndim==2 else dat
    ch  = rec.sig_name[0] if hasattr(rec,"sig_name") and rec.sig_name else "ch0"
    return x.astype(np.float32), fs, ch

# ---- read manifest & STITCH per subject + signal ----
man = pd.read_csv(MANIFEST_CSV)
need = {"record_path","subject_id","signal","subject_dir"}
if not need.issubset(man.columns):
    raise ValueError(f"Manifest missing columns: {need - set(man.columns)}")

# group by (subject_id, signal)
group_cols = ["subject_id","signal","subject_dir"]
groups = man.groupby(group_cols)

all_X, all_y, meta_rows = [], [], []
stats = {"groups": len(groups), "short_skipped":0, "concat_seconds":[], "ok_groups":0}

for (subj, sig, subj_dir), g in groups:
    # sort records by numeric suffix to preserve order
    def _suffix(s):
        m = re.search(r"_(\d+)$", os.path.basename(s))
        return int(m.group(1)) if m else 0
    bases = sorted(g["record_path"].tolist(), key=_suffix)

    x_concat = []
    fs_used  = FS_TARGET  # we will resample each file to FS_TARGET before concat
    ch_name  = None
    total_sec= 0.0

    for base in bases:
        try:
            fs_h, ln_h, ch_h = load_header(base)
        except Exception:
            continue
        dur = ln_h / max(fs_h,1e-6)
        if dur < MIN_DURATION_SEC:
            stats["short_skipped"] += 1
            continue
        try:
            x, fs, ch = load_record(base)
        except Exception:
            continue
        ch_name = ch_name or ch
        x, fs_new = resample_to(x, fs, FS_TARGET)
        x_concat.append(x)
        total_sec += len(x)/fs_new
        if MAX_SECONDS_PER_SUBJECT and total_sec >= MAX_SECONDS_PER_SUBJECT:
            break

    if not x_concat:
        continue

    x_cat = np.concatenate(x_concat, axis=0)
    x_cat = safe_bandpass(x_cat, FS_TARGET, *BANDS.get(sig.upper(),"ECG"),
                          order_primary=FILTER_ORDER_PRIMARY, order_fallback=FILTER_ORDER_FALLBACK)

    Xseg, starts = segment(x_cat, FS_TARGET, WINDOW_SEC, OVERLAP)
    if Xseg.size == 0:
        continue
    Xseg = zscore(Xseg, axis=1) if NORMALIZE_PER_WINDOW else zscore(Xseg, axis=None)

    # labels/demographics
    if (subj is not None) and (not demo.empty) and ((demo["subject_id"]==subj).any()):
        drow = demo.loc[demo["subject_id"]==subj].iloc[0]
        HF_label = int(drow["HF"])
        gender   = drow.get("gender", None)
        age      = float(drow["age"]) if pd.notna(drow.get("age", np.nan)) else None
        ethnicity= drow.get("ethnicity", None)
    else:
        HF_label = -1; gender=age=ethnicity=None

    all_X.append(Xseg)
    all_y.append(np.full((Xseg.shape[0],), HF_label, dtype=np.int8))
    for k, st in enumerate(starts):
        meta_rows.append({
            "subject_id": subj, "signal": sig, "channel": ch_name, "subject_dir": subj_dir,
            "segment_index": k, "start_sample": st, "fs": FS_TARGET,
            "gender": gender, "age": age, "ethnicity": ethnicity
        })

    stats["concat_seconds"].append(total_sec)
    stats["ok_groups"] += 1

# ---- save ----
if not all_X:
    raise RuntimeError("No segments produced. If many groups are still short, try WINDOW_SEC=3 and MIN_DURATION_SEC=2.")

X = np.concatenate(all_X, axis=0)
y = np.concatenate(all_y, axis=0)
meta = pd.DataFrame(meta_rows)

np.save(os.path.join(OUTPUT_DIR, "X_windows.npy"), X)
np.save(os.path.join(OUTPUT_DIR, "y_labels.npy"), y)
meta.to_csv(os.path.join(OUTPUT_DIR, "metadata.csv"), index=False)

print("Saved:")
print("  X_windows.npy", X.shape)
print("  y_labels.npy ", y.shape, "(HF=-1 means unlabeled/clinical CSVs missing)")
print("  metadata.csv ", len(meta))
print("Groups processed:", stats["ok_groups"], "/", stats["groups"])
if stats["concat_seconds"]:
    print("Median seconds per subject-signal (after stitching):",
          float(np.median(stats["concat_seconds"])))


[INFO] Clinical CSVs not found or unreadable -> proceeding without labels. [Errno 2] No such file or directory: '/content/drive/MyDrive/PhysioNet_Data/clinical/PATIENTS.csv'
Saved:
  X_windows.npy (97358, 625)
  y_labels.npy  (97358,) (HF=-1 means unlabeled/clinical CSVs missing)
  metadata.csv  97358
Groups processed: 30 / 31
Median seconds per subject-signal (after stitching): 1757.5


In [8]:
import numpy as np, pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader

X_PATH = "/content/drive/MyDrive/PhysioNet_Data/X_windows.npy"
Y_PATH = "/content/drive/MyDrive/PhysioNet_Data/y_labels.npy"
META   = "/content/drive/MyDrive/PhysioNet_Data/metadata.csv"

# Memory-map (no big RAM hit)
X = np.load(X_PATH, mmap_mode="r")     # [N, L]
y = np.load(Y_PATH, mmap_mode="r")     # [N]
meta = pd.read_csv(META)

# Diagnostics
vals, cnts = np.unique(y, return_counts=True)
print("Label distribution:", dict(zip(vals.tolist(), cnts.tolist())))
print("X shape:", X.shape)

# Case A: labeled windows available (0/1)
labeled_mask = (y >= 0)
if labeled_mask.sum() > 0:
    print(f"Labeled windows: {int(labeled_mask.sum())}")

    meta_lab = meta.loc[labeled_mask].reset_index(drop=True)
    subjects = meta_lab["subject_id"].fillna(-1).astype(int).values
    uniq = np.unique(subjects[subjects != -1])

    if uniq.size < 2:
        raise ValueError("Not enough distinct subjects with labels to split. "
                         "Check subject_id parsing and clinical CSV join.")

    train_s, temp = train_test_split(uniq, test_size=0.30, random_state=42)
    val_s, test_s = train_test_split(temp, test_size=0.50, random_state=42)

    lab_idx = np.where(labeled_mask)[0]
    train_idx = lab_idx[np.isin(subjects, train_s)]
    val_idx   = lab_idx[np.isin(subjects, val_s)]
    test_idx  = lab_idx[np.isin(subjects, test_s)]

else:
    # Case B: NO labels yet (all -1) → split by subject_id anyway (for pretraining/testing)
    print("No labeled windows found (y is all -1). "
          "Proceeding with subject-wise split WITHOUT labels.")
    subjects_all = meta["subject_id"].fillna(-1).astype(int).values
    uniq = np.unique(subjects_all[subjects_all != -1])
    if uniq.size < 2:
        raise ValueError("No valid subject_id values in metadata. "
                         "Fix manifest subject_id extraction and/or rerun preprocessing.")

    train_s, temp = train_test_split(uniq, test_size=0.30, random_state=42)
    val_s, test_s = train_test_split(temp, test_size=0.50, random_state=42)

    idx_all = np.arange(len(meta))
    train_idx = idx_all[np.isin(subjects_all, train_s)]
    val_idx   = idx_all[np.isin(subjects_all, val_s)]
    test_idx  = idx_all[np.isin(subjects_all, test_s)]

    # Create dummy labels so DataLoader works (FOR TESTING ONLY)
    y = np.zeros_like(y, dtype=np.int64)

print(f"Split sizes → train: {len(train_idx)}, val: {len(val_idx)}, test: {len(test_idx)}")

class MemmapDataset(Dataset):
    def __init__(self, X_memmap, y_arr, indices):
        self.X = X_memmap
        self.y = y_arr
        self.idx = np.asarray(indices)
    def __len__(self):
        return len(self.idx)
    def __getitem__(self, i):
        j = self.idx[i]
        x = self.X[j][None, :]         # (1, L)  single-channel
        return torch.from_numpy(x).float(), int(self.y[j])

train_loader = DataLoader(MemmapDataset(X, y, train_idx), batch_size=256, shuffle=True,  num_workers=2, pin_memory=True)
val_loader   = DataLoader(MemmapDataset(X, y, val_idx),   batch_size=256, shuffle=False, num_workers=2, pin_memory=True)
test_loader  = DataLoader(MemmapDataset(X, y, test_idx),  batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

print("DataLoaders ready.")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Label distribution: {-1: 97358}
X shape: (97358, 625)
No labeled windows found (y is all -1). Proceeding with subject-wise split WITHOUT labels.
Split sizes → train: 86003, val: 2402, test: 8953
DataLoaders ready.
