In [None]:
# Load the autoreload extension and reload imported modules before each cell
# runs, so edits to local modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Show the working directory: the data paths configured below are relative to it.
!pwd

In [None]:
# Environment check only: print GPU status. NOTE(review): none of the cells
# visible in this notebook use a GPU — confirm whether this cell is still needed.
!nvidia-smi

# Imports

In [18]:
import csv
import os
import itertools
import json
import numpy as np
import random
from scipy import signal

import wfdb
import neurokit2 as nk

from utils import build_subject_drug_map

# Config

In [12]:
# ==== DATA PATHS ====
# All paths are relative and therefore resolve against the notebook's working directory.
# ECGRDVQ_ANNOTATION is handed to utils.build_subject_drug_map; raw records are read
# from RAW_ROOT/<subject>/<egrefid> (WFDB base name, no extension).
ECGRDVQ_ANNOTATION = 'data/ECGRDVQ/SCR-002.Clinical.Data.csv'
RAW_ROOT            = 'data/ECGRDVQ/raw'

# Outputs: two completely separate roots (each also receives its own manifest.csv):
MEDIAN_ROOT_FULL    = 'data/ECGRDVQ_medians/method_A_full'     # one median per replicate
MEDIAN_ROOT_SUBSET  = 'data/ECGRDVQ_medians/method_B_subset'   # augmented medians per replicate
os.makedirs(MEDIAN_ROOT_FULL, exist_ok=True)
os.makedirs(MEDIAN_ROOT_SUBSET, exist_ok=True)

# ==== SAMPLE SELECTION ====
# Passed to utils.build_subject_drug_map together with the annotation CSV.
# NOTE(review): presumably the placebo time point to select — confirm against utils.
PLACEBO_TPT = 3.0

# ==== MEDIAN WAVEFORM SETTINGS ====
LEAD_INDEX            = 0           # Lead I
BEAT_PRE_SEC          = 0.25        # seconds before R
BEAT_POST_SEC         = 0.45        # seconds after R
MEDIAN_BEAT_LEN       = 512         # samples for each beat median (model I/O length)

# Guardrails (median derivation must use ≥ MIN_BEATS_HARD, prefer ≥ MIN_BEATS_PREF)
# Only MIN_BEATS_HARD is enforced by the derivation functions in this notebook;
# recordings with a beat count between HARD and PREF are still processed.
MIN_BEATS_HARD        = 3
MIN_BEATS_PREF        = 5

# METHOD B (subset augmentation)
AUG_MODE              = 'exclude'   # 'include' or 'exclude'
INCLUDE_Y             = 7           # if mode == 'include': use Y beats (if available), else all
EXCLUDE_X             = 2          # if mode == 'exclude': leave out X beats
MAX_COMBINATIONS      = 64          # cap to avoid explosion; will randomly sample if exceeded
AUGMENT_SEED          = 42          # for reproducible sub-sampling of combinations

# Seed both global generators. Python's `random` is the one used to shuffle subset
# combinations when the MAX_COMBINATIONS cap applies.
random.seed(AUGMENT_SEED)
np.random.seed(AUGMENT_SEED)


# 1. Core signal utilities (read signal → segment beats → resample → median)

In [13]:
def read_signal(filepath, lead_index=0):
    """
    Load one channel of a WFDB record.

    filepath: WFDB base name, i.e. the record path WITHOUT a file extension.
    lead_index: index of the single channel requested from wfdb.rdsamp.

    Returns (samples, fs): a flattened float32 numpy array and the header's
    sampling frequency rounded to the nearest integer.
    """
    samples, header = wfdb.rdsamp(filepath, channels=[lead_index])
    sampling_rate = int(round(header['fs']))
    lead = samples.flatten().astype(np.float32)
    return lead, sampling_rate


def segment_beats(sig, r_inds, fs, pre_sec=0.25, post_sec=0.45):
    """
    Cut a fixed window around every R-peak.

    sig: 1-D signal.
    r_inds: R-peak sample indices into `sig`.
    fs: sampling rate used to convert pre_sec/post_sec into sample counts.

    Returns (beats, keep_idx): the list of extracted windows and, for each one,
    the position of its R-peak within `r_inds`. Peaks whose window would be
    clipped by either end of the signal are dropped.
    """
    n_pre = int(round(pre_sec * fs))
    n_post = int(round(post_sec * fs))
    window_len = n_pre + n_post
    n_samples = len(sig)

    beats, keep_idx = [], []
    for pos, peak in enumerate(r_inds):
        lo = max(0, peak - n_pre)
        hi = min(n_samples, peak + n_post)
        if hi - lo != window_len:
            # Window was clipped at a signal edge: skip to avoid biased beats.
            continue
        beats.append(sig[lo:hi])
        keep_idx.append(pos)
    return beats, keep_idx


def resample_beats_to_fixed(beats, out_len):
    """
    Resample every beat to exactly `out_len` samples.

    Each beat goes through scipy.signal.resample (a Fourier-domain method, not
    linear interpolation). Returns a float32 array of shape [n_beats, out_len];
    an empty input yields an empty (0, out_len) array.
    """
    if len(beats) == 0:
        return np.empty((0, out_len), dtype=np.float32)
    resampled = [signal.resample(beat, out_len) for beat in beats]
    stacked = np.stack(resampled)   # [n_beats, out_len]
    return stacked.astype(np.float32)


def median_of_beats(beats_mat):
    """
    Collapse a stack of beats into one median beat.

    beats_mat: [n_beats, L]
    returns: [L] float32, the sample-wise median taken across beats (axis 0).
    """
    per_sample_median = np.median(beats_mat, axis=0)
    return per_sample_median.astype(np.float32)

# 2. Median waveform derivation method

In [14]:

def method_A_full_median(sig, fs):
    """
    Clean → R-peaks → all beats → resample → ONE median (if enough beats).

    sig: 1-D ECG signal.
    fs: sampling rate of `sig` in Hz; used for cleaning, peak detection and
        for converting the beat window from seconds to samples.

    Returns (median, meta):
      - median: float32 array of length MEDIAN_BEAT_LEN, or None when fewer than
        MIN_BEATS_HARD complete beats were found.
      - meta: dict with 'n_beats' (complete beats found), 'used' (beats that
        entered the median) and 'note' ('full' or 'too_few_beats').

    MIN_BEATS_PREF is only a preference: recordings with a beat count between
    MIN_BEATS_HARD and MIN_BEATS_PREF are still processed.
    """
    # Pass the real sampling rate: without it nk.ecg_clean assumes its default
    # of 1000 Hz and would apply mis-tuned filters to any other recording.
    den = nk.ecg_clean(sig, sampling_rate=fs)
    r = nk.ecg_findpeaks(den, sampling_rate=fs, method='neurokit', show=False)['ECG_R_Peaks']

    beats, _ = segment_beats(den, r, fs, BEAT_PRE_SEC, BEAT_POST_SEC)
    n_beats = len(beats)
    if n_beats < MIN_BEATS_HARD:
        return None, {'n_beats': n_beats, 'used': 0, 'note': 'too_few_beats'}

    mat = resample_beats_to_fixed(beats, MEDIAN_BEAT_LEN)
    med = median_of_beats(mat)
    meta = {'n_beats': n_beats, 'used': n_beats, 'note': 'full'}
    return med, meta


def _enumerate_subsets(n, mode='include', include_y=7, exclude_x=1):
    """
    Generate index subsets of range(n) according to `mode`, with guardrails.

    mode == 'include': subsets of exactly `include_y` indices; if fewer than
        `include_y` beats exist, the single subset of all indices is returned.
    any other mode ('exclude'): subsets of n - exclude_x indices. If that would
        drop below MIN_BEATS_HARD the size is raised to max(MIN_BEATS_HARD, n - 1),
        and if it reaches n the single subset of all indices is returned.

    Returns a list of index tuples holding at most MAX_COMBINATIONS entries.
    When more combinations exist, a random selection is returned, drawn from
    the global `random` generator (seed it for reproducibility).

    The number of combinations is counted first (math.comb, Python 3.8+). Up to
    `enumerate_limit` combinations are enumerated and shuffled exactly as before;
    beyond that, distinct subsets are sampled directly so the full combination
    list is never built in memory.
    """
    import math  # local import: only needed for the combination count

    # Largest combination count that is still enumerated in full before capping.
    enumerate_limit = 200_000

    all_idx = np.arange(n)
    if mode == 'include':
        k = include_y
        use_all = n < k          # not enough beats → must use all
    else:
        # exclude X: use n-X beats
        k = max(0, n - exclude_x)
        if k < MIN_BEATS_HARD:
            # keep at least MIN_BEATS_HARD, at most n-1
            k = max(MIN_BEATS_HARD, n - 1)
        use_all = k >= n

    if use_all:
        subsets = [tuple(all_idx)]
    else:
        total = math.comb(n, k)
        if total > enumerate_limit and total >= 2 * MAX_COMBINATIONS:
            # Too many to enumerate: rejection-sample distinct subsets. Because
            # at most half of all combinations are requested, collisions stay rare.
            seen = set()
            sampled = []
            while len(sampled) < MAX_COMBINATIONS:
                candidate = tuple(sorted(random.sample(range(n), k)))
                if candidate not in seen:
                    seen.add(candidate)
                    sampled.append(candidate)
            return sampled
        subsets = list(itertools.combinations(all_idx, k))

    # Ensure we're not exploding
    if len(subsets) > MAX_COMBINATIONS:
        random.shuffle(subsets)
        subsets = subsets[:MAX_COMBINATIONS]
    return subsets


def method_B_subset_medians(sig, fs, mode='include', include_y=7, exclude_x=1):
    """
    Clean → R-peaks → beats → MANY medians from beat subsets.

    sig: 1-D ECG signal.
    fs: sampling rate of `sig` in Hz.
    mode, include_y, exclude_x: forwarded to _enumerate_subsets to choose which
        beats enter each median.

    Returns (outputs, meta):
      - outputs: list of (median, info). median is a float32 array of length
        MEDIAN_BEAT_LEN; info records 'n_beats', 'used', 'subset' (beat indices)
        and the mode parameters.
      - meta: {'n_beats', 'note': 'too_few_beats'} with an empty output list when
        fewer than MIN_BEATS_HARD beats were found, otherwise
        {'n_beats', 'note': 'ok', 'count'}.
    """
    # Pass the real sampling rate: without it nk.ecg_clean assumes its default
    # of 1000 Hz and would apply mis-tuned filters to any other recording.
    den = nk.ecg_clean(sig, sampling_rate=fs)
    r = nk.ecg_findpeaks(den, sampling_rate=fs, method='neurokit', show=False)['ECG_R_Peaks']

    beats, _ = segment_beats(den, r, fs, BEAT_PRE_SEC, BEAT_POST_SEC)
    n = len(beats)
    if n < MIN_BEATS_HARD:
        return [], {'n_beats': n, 'note': 'too_few_beats'}

    # Prefer at least MIN_BEATS_PREF; if not available, proceed with n
    subsets = _enumerate_subsets(n, mode=mode, include_y=include_y, exclude_x=exclude_x)

    outputs = []
    beats_mat = resample_beats_to_fixed(beats, MEDIAN_BEAT_LEN)  # [n, L]
    for sub in subsets:
        sub = np.array(sub, dtype=int)
        # Guardrail: never build a median from fewer than MIN_BEATS_HARD beats
        if len(sub) < MIN_BEATS_HARD:
            continue
        med = median_of_beats(beats_mat[sub])
        info = {'n_beats': n, 'used': int(len(sub)), 'subset': sub.tolist(),
                'mode': mode, 'include_y': include_y, 'exclude_x': exclude_x}
        outputs.append((med, info))

    return outputs, {'n_beats': n, 'note': 'ok', 'count': len(outputs)}


# 3. Build (subject, drug) → replicates map (baseline vs post)


In [15]:
# see utils.build_subject_drug_map

# 4. Precompute & save medians (both methods) + manifests

In [16]:
def ensure_dir(p):
    """Create directory `p` together with any missing parents; do nothing if it already exists."""
    os.makedirs(name=p, exist_ok=True)

def precompute_for_id(egrefid, subj, drug, phase, out_root, method='A', aug_mode='include', include_y=7, exclude_x=1):
    """
    Compute and save the median beat(s) of one recording.

    egrefid: record name; its WFDB files are expected at
        RAW_ROOT/<subj>/<egrefid>.hea and .dat.
    subj, drug: identify the recording's group; used for the output folder and
        copied into the manifest rows.
    phase: 'baseline' or 'post'
    out_root: root folder under which the .npy files are written.
    method: 'A' → one full median; any other value → Method B subset medians.
    aug_mode, include_y, exclude_x: Method B only, forwarded to
        method_B_subset_medians.

    Saves .npy files and returns the list of manifest rows (dicts). Returns []
    when the record is missing/incomplete or too few beats were found.
    """
    base = os.path.join(RAW_ROOT, str(subj), egrefid)   # wfdb base (no ext)
    # Both the header and the signal file are required. The previous `or` let a
    # record with only one of them through, which then failed while reading.
    if not (os.path.exists(base + '.hea') and os.path.exists(base + '.dat')):
        return []  # skip missing or incomplete record

    sig, fs = read_signal(base, lead_index=LEAD_INDEX)

    def _base_row(median_path, method_tag):
        # Columns shared by both manifests; key order defines the CSV column order.
        return {
            'subject': subj, 'drug': drug, 'phase': phase,
            'egrefid': egrefid, 'median_path': median_path,
            'method': method_tag,
        }

    rows = []

    # METHOD A: a single median stored as <out_root>/<subj>/<drug>/<phase>/<egrefid>_median.npy
    if method == 'A':
        med, info = method_A_full_median(sig, fs)
        if med is None:
            return []
        out_dir = os.path.join(out_root, f"{subj}", drug, phase)
        ensure_dir(out_dir)
        out_path = os.path.join(out_dir, f"{egrefid}_median.npy")
        np.save(out_path, med)
        row = _base_row(out_path, 'A')
        row.update({'n_beats': info['n_beats'], 'beats_used': info['used']})
        rows.append(row)
        return rows

    # METHOD B: many medians stored as <out_root>/<subj>/<drug>/<phase>/<egrefid>/median_NNNN.npy
    outs, _ = method_B_subset_medians(sig, fs, mode=aug_mode, include_y=include_y, exclude_x=exclude_x)
    if not outs:
        return []
    out_dir = os.path.join(out_root, f"{subj}", drug, phase, egrefid)
    ensure_dir(out_dir)
    for k, (med, info) in enumerate(outs):
        out_path = os.path.join(out_dir, f"median_{k:04d}.npy")
        np.save(out_path, med)
        row = _base_row(out_path, 'B')
        row.update({
            'mode': info['mode'],
            'include_y': info.get('include_y', None),
            'exclude_x': info.get('exclude_x', None),
            'n_beats': info['n_beats'], 'beats_used': info['used'],
            'subset_idx': json.dumps(info.get('subset', []))
        })
        rows.append(row)
    return rows


def precompute_all_medians(mapping, run_method_A=True, run_method_B=True):
    """
    Precompute medians for every record in `mapping` and write the manifests.

    mapping: dict keyed by (subject, drug); each value provides the record IDs
        under the keys 'baseline' and 'post'.
    run_method_A / run_method_B: switch each derivation method on or off.

    Method A rows go to MEDIAN_ROOT_FULL/manifest.csv and Method B rows to
    MEDIAN_ROOT_SUBSET/manifest.csv; a manifest is written only when its method
    produced at least one row. Prints a one-line summary and returns None.
    """
    def _write_manifest(root, rows):
        # Column order follows the key order of the first row.
        with open(os.path.join(root, 'manifest.csv'), 'w', newline='') as fh:
            writer = csv.DictWriter(fh, fieldnames=rows[0].keys())
            writer.writeheader()
            writer.writerows(rows)

    man_A, man_B = [], []
    for (subj, drug), info in mapping.items():
        for phase in ('baseline', 'post'):
            for egrefid in info[phase]:
                if run_method_A:
                    man_A.extend(
                        precompute_for_id(egrefid, subj, drug, phase, MEDIAN_ROOT_FULL, method='A')
                    )
                if run_method_B:
                    man_B.extend(
                        precompute_for_id(egrefid, subj, drug, phase, MEDIAN_ROOT_SUBSET, method='B',
                                          aug_mode=AUG_MODE, include_y=INCLUDE_Y, exclude_x=EXCLUDE_X)
                    )

    # Save manifests
    if man_A:
        _write_manifest(MEDIAN_ROOT_FULL, man_A)
    if man_B:
        _write_manifest(MEDIAN_ROOT_SUBSET, man_B)
    print(f"Method A medians: {len(man_A)} entries; Method B medians: {len(man_B)} entries.")


In [19]:
# Build the (subject, drug) → {'baseline': [...], 'post': [...]} map of record IDs
# from the annotation CSV. NOTE(review): PLACEBO_TPT is interpreted inside
# utils.build_subject_drug_map — confirm its meaning there.
mapping = build_subject_drug_map(ECGRDVQ_ANNOTATION, PLACEBO_TPT)
# Compute and save medians with both methods; this also writes the two manifest.csv files.
precompute_all_medians(mapping, run_method_A=True, run_method_B=True)

Method A medians: 654 entries; Method B medians: 30103 entries.


# 5. Visualize waveforms

In [20]:
# ==== Sanity plots for precomputed medians (Method A vs Method B) ====
import os, random, json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Folder that receives the sanity-check PNGs (relative to the working directory).
# It is also the default `save_dir` of the plotting functions below, so it is
# created up front.
SANITY_PLOT_DIR = "median_sanity_plots"
os.makedirs(SANITY_PLOT_DIR, exist_ok=True)


def _load_manifest(path):
    if not os.path.exists(path):
        print(f"[warn] Manifest missing: {path}")
        return pd.DataFrame()
    try:
        df = pd.read_csv(path)
        return df
    except Exception as e:
        print(f"[warn] Failed to read manifest '{path}': {e}")
        return pd.DataFrame()


def _safe_name(s: str) -> str:
    return "".join(c if c.isalnum() or c in "-_." else "_" for c in str(s))


def plot_methodA_vs_methodB_examples(
    manA, manB, n_examples=4, max_B_curves=6, seed=42, save_dir=SANITY_PLOT_DIR
):
    """
    Pick random keys present in BOTH manifests and overlay Method A median
    with several Method B medians for the same (subject, drug, phase, EGREFID).

    manA, manB: manifest DataFrames with the columns subject, drug, phase,
        egrefid and median_path. They are not modified.
    n_examples: maximum number of keys (figures) to draw.
    max_B_curves: maximum number of Method B medians overlaid per figure.
    seed: seed of a private random generator, so the picks are reproducible and
        the global `random` state is left untouched.
    save_dir: folder for the PNG files; created if it does not exist.

    Writes one PNG per picked key and prints a status line; returns None.
    """
    if manA.empty or manB.empty:
        print("[info] One of the manifests is empty — skipping A-vs-B overlay.")
        return

    def _fit_length(wave):
        # Length-mismatch guard: linearly interpolate onto MEDIAN_BEAT_LEN samples.
        if wave.shape[0] == MEDIAN_BEAT_LEN:
            return wave
        return np.interp(np.linspace(0, len(wave)-1, MEDIAN_BEAT_LEN), np.arange(len(wave)), wave)

    # A private generator seeded like random.seed(seed) yields the same picks
    # without reseeding the module-level generator used elsewhere.
    rng = random.Random(seed)
    os.makedirs(save_dir, exist_ok=True)

    # Build comparable keys as standalone Series (no helper column is added to
    # the caller's DataFrames).
    key_cols = ["subject", "drug", "phase", "egrefid"]
    keys_A = manA[key_cols].astype(str).agg("|".join, axis=1)
    keys_B = manB[key_cols].astype(str).agg("|".join, axis=1)

    common_keys = sorted(set(keys_A).intersection(set(keys_B)))
    if not common_keys:
        print("[info] No common (subject, drug, phase, EGREFID) between Method A and B.")
        return

    rng.shuffle(common_keys)
    picked = common_keys[:n_examples]

    # Time axis in seconds for a beat window (resampled)
    t = np.linspace(-BEAT_PRE_SEC, BEAT_POST_SEC, MEDIAN_BEAT_LEN, endpoint=False)

    for idx, k in enumerate(picked):
        a_rows = manA[keys_A == k]
        b_rows = manB[keys_B == k]
        if a_rows.empty or b_rows.empty:
            continue

        # There should be a single Method A median per key; if multiple, take the first
        a_path = a_rows.iloc[0]["median_path"]
        try:
            a_wave = np.load(a_path).astype(np.float32)
        except Exception as e:
            print(f"[warn] Could not load Method A median '{a_path}': {e}")
            continue

        # Method B: many medians under the same key. Sample a few.
        b_paths = b_rows["median_path"].tolist()
        rng.shuffle(b_paths)
        b_paths = b_paths[:max_B_curves]

        b_waves = []
        for pth in b_paths:
            try:
                b_waves.append(np.load(pth).astype(np.float32))
            except Exception as e:
                print(f"[warn] Could not load Method B median '{pth}': {e}")

        subj, drug, phase, egref = a_rows.iloc[0][["subject","drug","phase","egrefid"]].tolist()

        fig, ax = plt.subplots(figsize=(10, 4))
        # Plot Method B medians (thin)
        for w in b_waves:
            ax.plot(t, _fit_length(w), linewidth=1.0, alpha=0.6)
        # Plot Method A median (bold)
        ax.plot(t, _fit_length(a_wave), linewidth=2.4, color="black", label="Method A (full)")
        ax.set_title(f"(A vs B) subj {subj} • {drug} • {phase} • EGREFID {egref}")
        ax.set_xlabel("Time (s, aligned to R-peak)")
        ax.set_ylabel("Amplitude (arb. units)")
        ax.legend(loc="upper right")
        fname = os.path.join(save_dir, f"A_vs_B_subj{subj}_{_safe_name(drug)}_{phase}_EGREFID_{_safe_name(egref)}_{idx}.png")
        fig.savefig(fname, bbox_inches="tight")
        plt.close(fig)

    print(f"[ok] Saved A-vs-B overlay sanity plots to: {save_dir}")

def plot_baseline_vs_post_methodA(
    manA, n_pairs=4, seed=123, save_dir=SANITY_PLOT_DIR
):
    """
    For random (subject, drug) pairs, plot baseline vs post medians (Method A).

    manA: Method A manifest DataFrame with the columns subject, drug, phase and
        median_path.
    n_pairs: maximum number of (subject, drug) pairs (figures) to draw.
    seed: seed of a private random generator, so the picks are reproducible and
        the global `random` state is left untouched.
    save_dir: folder for the PNG files; created if it does not exist.

    Only the first baseline and the first post median of each pair are drawn.
    Writes one PNG per picked pair and prints a status line; returns None.
    """
    if manA.empty:
        print("[info] Method A manifest is empty — skipping baseline vs post.")
        return

    def _fit_length(wave):
        # Length-mismatch guard: linearly interpolate onto MEDIAN_BEAT_LEN samples.
        if wave.shape[0] == MEDIAN_BEAT_LEN:
            return wave
        return np.interp(np.linspace(0, len(wave)-1, MEDIAN_BEAT_LEN), np.arange(len(wave)), wave)

    # A private generator seeded like random.seed(seed) yields the same picks
    # without reseeding the module-level generator used elsewhere.
    rng = random.Random(seed)
    os.makedirs(save_dir, exist_ok=True)

    # (subject, drug, phase) combinations present in the manifest
    group_keys = set(manA.groupby(["subject","drug","phase"]).groups.keys())

    # Find (subject,drug) that have both a baseline and a post entry
    pairs = sorted({
        (s, d) for (s, d, ph) in group_keys
        if ph == "baseline" and (s, d, "post") in group_keys
    })
    if not pairs:
        print("[info] No (subject, drug) pairs with both baseline & post in Method A.")
        return

    rng.shuffle(pairs)
    picks = pairs[:n_pairs]

    t = np.linspace(-BEAT_PRE_SEC, BEAT_POST_SEC, MEDIAN_BEAT_LEN, endpoint=False)

    for idx, (s,d) in enumerate(picks):
        # take first baseline & first post median we find for that (s,d)
        same_pair = (manA["subject"]==s) & (manA["drug"]==d)
        a_base = manA[same_pair & (manA["phase"]=="baseline")]
        a_post = manA[same_pair & (manA["phase"]=="post")]
        if a_base.empty or a_post.empty:
            continue
        base_path = a_base.iloc[0]["median_path"]
        post_path = a_post.iloc[0]["median_path"]
        try:
            w_base = np.load(base_path).astype(np.float32)
            w_post = np.load(post_path).astype(np.float32)
        except Exception as e:
            print(f"[warn] Could not load A medians for subj {s}, drug {d}: {e}")
            continue

        fig, ax = plt.subplots(figsize=(10,4))
        ax.plot(t, _fit_length(w_base), label="Baseline (A)", color="tab:blue")
        ax.plot(t, _fit_length(w_post), label="Post (A)",     color="tab:red", linestyle="--")
        ax.set_title(f"(A) Baseline vs Post • subj {s} • {d}")
        ax.set_xlabel("Time (s, aligned to R-peak)")
        ax.set_ylabel("Amplitude (arb. units)")
        ax.legend(loc="upper right")
        fname = os.path.join(save_dir, f"A_baseline_vs_post_subj{s}_{_safe_name(d)}_{idx}.png")
        fig.savefig(fname, bbox_inches="tight")
        plt.close(fig)

    print(f"[ok] Saved baseline-vs-post (Method A) sanity plots to: {save_dir}")

# ---- Run sanity plots ----
# Each output root holds its own manifest.csv (Method A first, Method B second).
manA_path, manB_path = (
    os.path.join(root, "manifest.csv")
    for root in (MEDIAN_ROOT_FULL, MEDIAN_ROOT_SUBSET)
)
manA, manB = _load_manifest(manA_path), _load_manifest(manB_path)

# Overlay A vs B for 12 random records, then baseline vs post for 12 random pairs.
plot_methodA_vs_methodB_examples(manA, manB, n_examples=12, max_B_curves=6, seed=123)
plot_baseline_vs_post_methodA(manA, n_pairs=12, seed=7)


[ok] Saved A-vs-B overlay sanity plots to: median_sanity_plots
[ok] Saved baseline-vs-post (Method A) sanity plots to: median_sanity_plots
