# In [1]:  (Jupyter cell marker left over from the notebook export; kept as a comment so the file parses)
# ================================================================
# TabGraphSyn (GCN → VAE → latent DDPM) — ACTG175 (AIDS)
# Source: UCI via ucimlrepo ONLY (id=890) — no local CSV
#
# Goal: Lower marginal distribution error (KS%) without overfitting
# Changes:
#   • VAE training: KL warm-up + Marginal Distribution Loss (batch-quantile L1)
#     + Moment loss (mean/std alignment) to better match 1D marginals
#   • Post-decode calibration (holdout-aware):
#       - Binary cols: top-k rounding with prevalence shrinkage
#       - Continuous cols: partial rank-preserving quantile mapping (α_marg)
#   • Metrics now computed AFTER all calibrations (fixes earlier time-metric drift)
#
# Run:
#   pip install -U ucimlrepo numpy pandas torch scikit-learn scipy matplotlib umap-learn statsmodels
#   python tabgraphsyn_aids_ucirepo.py
#   (or) python tabgraphsyn_aids_ucirepo.py --uci_id 890
# ================================================================

import os, math, random, argparse, warnings
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, accuracy_score
from scipy.stats import ks_2samp
import matplotlib.pyplot as plt
try:
    from umap import UMAP
except Exception:
    import umap.umap_ as _umap
    UMAP = _umap.UMAP

# Survival (NO lifelines): statsmodels
from statsmodels.duration.survfunc import SurvfuncRight
from statsmodels.duration.hazard_regression import PHReg
import statsmodels.api as sm

warnings.filterwarnings("ignore", category=UserWarning)

# ---------------------------
# Reproducibility & device
# ---------------------------
SEED = 123
# Seed the Python, NumPy and PyTorch global RNGs from the same value.
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
# Prefer CUDA when PyTorch reports it as available, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------
# Output dir
# ---------------------------
# Absolute path (resolved against the current working directory) where the
# plot helpers below write their PNG files; created at import time.
OUTDIR = os.path.abspath("outputs_aids")
os.makedirs(OUTDIR, exist_ok=True)

# ---------------------------
# TUNING KNOBS (KS% ↓)
# ---------------------------
# NOTE(review): most of these constants are consumed by pipeline code that is
# not visible in this part of the file; descriptions marked "presumably"
# should be confirmed against their use sites.
# VAE training
VAE_Z        = 40         # same value as the VAE class default z_dim
VAE_H        = 320        # same value as the VAE class default hidden width
VAE_LR       = 2e-3       # presumably optimizer learning rate
VAE_WD       = 5e-4       # presumably optimizer weight decay
VAE_EPOCHS   = 160        # ↑ a bit for better marginals
VAE_BS       = 128        # presumably mini-batch size
BETA_MAX     = 0.004      # final KL weight; warm-up from 0 → BETA_MAX
MDE_W        = 0.08       # weight for marginal-quantile loss (per-feature)
MOMENT_W     = 0.12       # weight for mean/std alignment

# GCN
GCN_EPOCHS   = 120        # presumably number of GCN training epochs
GCN_LR       = 1e-3       # presumably optimizer learning rate
GCN_WD       = 1e-3       # presumably optimizer weight decay
KNN_K        = 10         # neighbours per node in the kNN graph (queried as KNN_K+1, self dropped)

# DDPM
TSTEPS       = 450        # same value as the T=450 defaults of t_embedding / EpsNet.forward
EPS_LR       = 2e-3       # presumably learning rate of the noise-prediction network
EPS_WD       = 1e-4       # presumably weight decay of the noise-prediction network
DDPM_EPOCHS  = 160        # presumably number of diffusion training epochs
COND_NOISE   = 0.02       # NOTE(review): looks like noise added to the conditioning vector — confirm
DROP_UNCOND  = 0.10       # NOTE(review): looks like the condition-dropout probability — confirm
CFG_W        = 2.2        # NOTE(review): looks like a classifier-free-guidance weight — confirm
TAU_NOISE    = 0.95       # NOTE(review): looks like a sampling-noise temperature — confirm

# Post-decoding alignment
CORAL_REG    = 2e-3       # presumably a regulariser for the CORAL covariance step — confirm
CORAL_BLEND  = 0.85       # 0=no CORAL, 1=full CORAL, blend in standardized space
JITTER_SD    = 0.003      # presumably std-dev of a small additive jitter — confirm

# Binary calibration (after inverse-transform, excluding event/time)
BIN_SHRINK   = 0.65       # convex blend toward real prevalence (0=no change, 1=match real)
# Marginal quantile blend (non-binary, excluding time/event)
ALPHA_MARG   = 0.55       # 0=no mapping, 1=full quantile map to real calib
HOLDOUT_MARG = 0.20       # portion of real data NOT used for marginal calibr.

# Survival calibration
LAMBDA_RATE  = 0.85       # shrink event rate toward real
ALPHA_TIME   = 0.85       # partial quantile map on times
HOLDOUT_SURV = 0.10       # bigger holdout reduces overfitting

# ---------------------------
# Models
# ---------------------------
class GCN(nn.Module):
    """Two-layer graph network that returns class logits and node embeddings.

    Each layer multiplies the features by the (pre-normalised) adjacency
    matrix, applies a linear map, dropout and LayerNorm, and optionally a ReLU.
    """

    def __init__(self, in_dim, hidden=96, emb=48, classes=2, dropout=0.15):
        super().__init__()
        # Linear maps: input -> hidden -> embedding -> class logits.
        self.fc1 = nn.Linear(in_dim, hidden)
        self.fc2 = nn.Linear(hidden, emb)
        self.cls = nn.Linear(emb, classes)
        # One dropout module is shared by both propagation layers.
        self.do = nn.Dropout(dropout)
        self.n1 = nn.LayerNorm(hidden)
        self.n2 = nn.LayerNorm(emb)

    def layer(self, A, X, W, N, act=True):
        """Run one propagation step: ``N(dropout(W(A @ X)))`` plus optional ReLU."""
        normed = N(self.do(W(A @ X)))
        if act:
            return F.relu(normed)
        return normed

    def forward(self, A, X):
        """Return ``(logits, embedding)`` for adjacency ``A`` and features ``X``."""
        hidden_repr = self.layer(A, X, self.fc1, self.n1, True)
        embedding = self.layer(A, hidden_repr, self.fc2, self.n2, True)
        logits = self.cls(embedding)
        return logits, embedding

class VAE(nn.Module):
    """MLP variational auto-encoder with LayerNorm after every hidden layer.

    ``forward`` returns the reconstruction together with the posterior
    parameters, the sampled latent, and the two base loss terms (MSE
    reconstruction and KL); any additional losses are added by the caller.
    """

    def __init__(self, in_dim, z_dim=40, hidden=320):
        super().__init__()
        # Encoder trunk and the two posterior heads (mean, log-variance).
        self.e1 = nn.Linear(in_dim, hidden)
        self.e2 = nn.Linear(hidden, hidden)
        self.mu = nn.Linear(hidden, z_dim)
        self.lv = nn.Linear(hidden, z_dim)
        # Decoder trunk and output projection back to the input width.
        self.d1 = nn.Linear(z_dim, hidden)
        self.d2 = nn.Linear(hidden, hidden)
        self.out = nn.Linear(hidden, in_dim)
        # Despite the "bn_" prefix these are LayerNorm modules.
        self.bn_e1 = nn.LayerNorm(hidden)
        self.bn_e2 = nn.LayerNorm(hidden)
        self.bn_d1 = nn.LayerNorm(hidden)
        self.bn_d2 = nn.LayerNorm(hidden)
        self.do = nn.Dropout(0.05)

    def encode(self, x):
        """Map ``x`` to the posterior mean and log-variance."""
        first = F.relu(self.bn_e1(self.e1(x)))
        second = self.do(self.e2(first))
        second = F.relu(self.bn_e2(second))
        return self.mu(second), self.lv(second)

    def reparam(self, mu, logv):
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``."""
        sigma = torch.exp(0.5 * logv)
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def decode(self, z):
        """Map a latent ``z`` back to feature space."""
        first = F.relu(self.bn_d1(self.d1(z)))
        second = self.do(self.d2(first))
        second = F.relu(self.bn_d2(second))
        return self.out(second)

    def forward(self, x, beta=1.0):
        """Return ``(rec, mu, logvar, z, recon_loss, kl)``.

        ``beta`` is accepted for interface compatibility but is not used
        here; the KL term is returned unweighted.
        """
        mu, lv = self.encode(x)
        z = self.reparam(mu, lv)
        rec = self.decode(z)
        # MSE reconstruction, averaged over every element.
        rl = F.mse_loss(rec, x, reduction='mean')
        # KL to a standard normal, averaged over batch and latent dims.
        kl_terms = 1 + lv - mu.pow(2) - lv.exp()
        kld = -0.5 * torch.mean(kl_terms)
        return rec, mu, lv, z, rl, kld

def t_embedding(t, dim=64, T=450):
    """Sinusoidal embedding of diffusion timesteps.

    Args:
        t: 1-D tensor of timesteps (any numeric dtype); one row is produced
            per entry.
        dim: Width of the returned embedding.
        T: Normaliser; timesteps are divided by ``T`` before being scaled by
            the frequencies.

    Returns:
        Float tensor of shape ``(len(t), dim)`` laid out as
        ``[sin(args), cos(args)]``; for odd ``dim`` a trailing zero column is
        appended so the width always equals ``dim``.
    """
    half = dim // 2
    # Build the frequencies on the same device as ``t`` so the function does
    # not depend on a module-level device and works for CPU and GPU inputs.
    freqs = torch.exp(
        torch.linspace(math.log(1.0), math.log(10000.0), steps=half, device=t.device)
    )
    args = (t[:, None].float() / T) * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2 == 1:
        # 2 * (dim // 2) == dim - 1 for odd dim; pad so callers that size
        # their layers with ``dim`` receive exactly ``dim`` columns.
        emb = F.pad(emb, (0, 1))
    return emb

class EpsNet(nn.Module):
    """MLP noise predictor that takes a latent, a condition and a timestep.

    The input to the first layer is the concatenation of the latent ``x``,
    the conditioning vector ``c`` and a ``t_dim``-wide timestep embedding.
    """

    def __init__(self, z_dim, c_dim, t_dim=64, hidden=288):
        super().__init__()
        self.t_dim = t_dim
        in_width = z_dim + c_dim + t_dim
        self.fc1 = nn.Linear(in_width, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, z_dim)
        self.do  = nn.Dropout(0.05)
        self.nrm = nn.LayerNorm(hidden)

    def forward(self, x, c, t, T=450):
        """Return a tensor with ``z_dim`` columns for latent ``x`` at step ``t``."""
        time_feats = t_embedding(t, self.t_dim, T)
        joined = torch.cat([x, c, time_feats], 1)
        # First layer: linear -> SiLU -> dropout.
        hidden_act = self.do(F.silu(self.fc1(joined)))
        # Second layer: linear -> LayerNorm -> SiLU -> dropout.
        hidden_act = self.fc2(hidden_act)
        hidden_act = self.do(F.silu(self.nrm(hidden_act)))
        return self.fc3(hidden_act)

# ---------------------------
# Metrics
# ---------------------------
def marginal_KS_percent(real_std, synth_std, feat_names):
    """Per-feature two-sample KS statistic between real and synthetic data.

    Returns a DataFrame with columns ``feature`` and ``KS%`` (the KS
    statistic times 100) and the mean of ``KS%`` over all features.
    """
    records = [
        {
            "feature": feat_names[col],
            "KS%": 100.0 * float(
                ks_2samp(
                    real_std[:, col], synth_std[:, col],
                    alternative='two-sided', mode='asymp',
                ).statistic
            ),
        }
        for col in range(real_std.shape[1])
    ]
    table = pd.DataFrame(records)
    mean_ks = float(table["KS%"].mean())
    return table, mean_ks

def pairwise_corr_error_percent(real_std, synth_std, feat_names):
    """Absolute Pearson-correlation gap for every unordered feature pair.

    Returns a DataFrame with columns ``feat_i``, ``feat_j`` and ``|Δρ|%``
    (absolute difference of the correlations times 100) and the mean of
    ``|Δρ|%`` over all pairs.
    """
    n_feat = real_std.shape[1]
    # Upper-triangle indices come out in the order (0,1), (0,2), ..., (1,2), ...
    first_idx, second_idx = np.triu_indices(n_feat, k=1)
    records = []
    for a, b in zip(first_idx, second_idx):
        pair = [a, b]
        rho_real = np.corrcoef(real_std[:, pair].T)[0, 1]
        rho_syn = np.corrcoef(synth_std[:, pair].T)[0, 1]
        gap = 100.0 * float(abs(rho_syn - rho_real))
        records.append({"feat_i": feat_names[a], "feat_j": feat_names[b], "|Δρ|%": gap})
    table = pd.DataFrame(records)
    return table, float(table["|Δρ|%"].mean())

def detection_score_logistic(real_std, synth_std):
    """Real-vs-synthetic detection with a single logistic regression.

    The first ``n = min(len(real), len(synth))`` rows of each set are
    labelled 0 (real) / 1 (synthetic), split 70/30 with stratification, and
    a logistic regression is scored on the 30% part.

    Returns ``(accuracy in percent, ROC AUC)``.
    """
    n_each = min(len(real_std), len(synth_std))
    features = np.vstack([real_std[:n_each], synth_std[:n_each]])
    labels = np.hstack([np.zeros(n_each, dtype=int), np.ones(n_each, dtype=int)])

    split = train_test_split(
        features, labels, test_size=0.3, random_state=SEED, stratify=labels
    )
    X_fit, X_eval, y_fit, y_eval = split

    model = LogisticRegression(max_iter=5000, solver="lbfgs")
    model.fit(X_fit, y_fit)

    accuracy = accuracy_score(y_eval, model.predict(X_eval))
    auc = roc_auc_score(y_eval, model.predict_proba(X_eval)[:, 1])
    return 100.0 * accuracy, float(auc)

def detection_score_bestof(real_std, synth_std, seed=123):
    """Real-vs-synthetic detection, keeping the most accurate of three models.

    Logistic regression, an RBF SVM and a 400-tree random forest are each
    scored with 5-fold stratified cross-validation on the balanced
    real (0) / synthetic (1) set; the model with the highest mean accuracy
    wins (the first one wins ties).

    Returns ``(accuracy in percent, mean ROC AUC, model name)``.
    """
    n_each = min(len(real_std), len(synth_std))
    X = np.vstack([real_std[:n_each], synth_std[:n_each]])
    y = np.hstack([np.zeros(n_each, dtype=int), np.ones(n_each, dtype=int)])

    candidates = {
        "LogReg(L2)": LogisticRegression(max_iter=5000, solver="lbfgs"),
        "SVM-RBF":    SVC(kernel="rbf", C=10.0, gamma="scale", probability=True, random_state=seed),
        "RF-400":     RandomForestClassifier(n_estimators=400, n_jobs=-1, random_state=seed),
    }
    folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)

    winner = (-1, -1, None)
    for label, model in candidates.items():
        fold_acc = []
        fold_auc = []
        for fit_idx, eval_idx in folds.split(X, y):
            model.fit(X[fit_idx], y[fit_idx])
            prob = model.predict_proba(X[eval_idx])[:, 1]
            hard = (prob >= 0.5).astype(int)
            fold_acc.append(accuracy_score(y[eval_idx], hard))
            fold_auc.append(roc_auc_score(y[eval_idx], prob))
        mean_acc = 100 * np.mean(fold_acc)
        mean_auc = float(np.mean(fold_auc))
        if mean_acc > winner[0]:
            winner = (mean_acc, mean_auc, label)
    return winner

# ---------------------------
# Plot helpers
# ---------------------------
def save_corr_heatmaps_with_labels(real_std, synth_std, feat_names, tag):
    """Save a 3-panel figure: real, synthetic and difference correlation maps.

    The PNG is written to ``OUTDIR`` and its path is returned.
    """
    corr_real = np.corrcoef(real_std, rowvar=False)
    corr_syn = np.corrcoef(synth_std, rowvar=False)
    # (matrix, symmetric colour limit, panel title) for each of the 3 panels.
    panels = [
        (corr_real, 1, "Corr: Real"),
        (corr_syn, 1, "Corr: Synthetic"),
        (corr_syn - corr_real, 0.5, "Corr diff (S-R)"),
    ]

    fig, axs = plt.subplots(1, 3, figsize=(18, 6))
    images = [
        ax.imshow(matrix, vmin=-limit, vmax=limit)
        for ax, (matrix, limit, _) in zip(axs, panels)
    ]

    tick_positions = range(len(feat_names))
    for ax, (_, _, title) in zip(axs, panels):
        ax.set_title(title)
        ax.set_xticks(tick_positions)
        ax.set_yticks(tick_positions)
        ax.set_xticklabels(feat_names, rotation=90, fontsize=7)
        ax.set_yticklabels(feat_names, fontsize=7)

    for image, ax in zip(images, axs):
        plt.colorbar(image, ax=ax, fraction=0.046, pad=0.02)

    plt.suptitle(f"Correlation matrices — {tag}")
    plt.tight_layout()
    fname = f"{OUTDIR}/corr_heatmaps_{tag.lower()}_labeled.png"
    plt.savefig(fname, dpi=220)
    plt.close()
    return fname

def save_table_as_image(df, title, fname):
    """Render a DataFrame as a matplotlib table and save it as an image.

    Args:
        df: Table to render; its values and column labels are drawn as-is.
        title: Text placed above the table.
        fname: Output path. Its parent directory is created when the path
            has one; a bare file name is written to the current directory.
    """
    # Figure height grows with the number of rows so cells keep their size.
    fig, ax = plt.subplots(figsize=(8.0, 1.6 + 0.35 * len(df)))
    ax.axis('off')
    ax.set_title(title, fontsize=12, pad=10)
    tbl = ax.table(cellText=df.values, colLabels=df.columns, loc='center', cellLoc='center')
    tbl.auto_set_font_size(False); tbl.set_fontsize(10); tbl.scale(1.0, 1.25)
    # os.path.dirname returns "" for a bare file name and os.makedirs("")
    # raises, so only create the directory when there is one to create.
    out_dir = os.path.dirname(fname)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.tight_layout(); plt.savefig(fname, dpi=220); plt.close()

def save_umap_paperstyle(real_std, synth_std, tag, panel_label="(b) AIDS", xlim=(-10,15), ylim=(-5,15)):
    """Save a 2-D UMAP scatter of real vs synthetic rows and return its path.

    The UMAP reducer is fitted on ``real_std`` only; ``synth_std`` is then
    projected with the fitted reducer. Each embedded coordinate is rescaled
    to ``xlim`` / ``ylim`` before plotting. The PNG is written to ``OUTDIR``.

    Args:
        real_std: Real feature matrix passed to ``fit_transform``.
        synth_std: Synthetic feature matrix passed to ``transform``.
        tag: Lower-cased and inserted into the output file name.
        panel_label: Caption drawn centred under the axes.
        xlim, ylim: Axis limits, also used as rescaling targets.
    """
    reducer = UMAP(
        n_neighbors=5, min_dist=0.15, metric="euclidean",
        init="spectral", random_state=SEED, transform_seed=SEED
    )
    R2 = reducer.fit_transform(real_std)    # fit on REAL only
    S2 = reducer.transform(synth_std)       # transform SYNTH

    def _affine_to_bounds(col, new_min, new_max):
        # Min-max rescale `col` so its own min/max land on new_min/new_max;
        # the 1e-12 term avoids division by zero for a constant column.
        cmin, cmax = float(col.min()), float(col.max())
        scale = (new_max - new_min) / (cmax - cmin + 1e-12)
        shift = new_min - cmin * scale
        return col * scale + shift

    # NOTE(review): the real and synthetic embeddings are rescaled
    # independently (each by its own min/max), so both clouds always fill the
    # same box and their relative position/extent in UMAP space is not
    # preserved in the figure. Confirm this is intended; otherwise the
    # scale/shift derived from R2 should be applied to S2 as well.
    R2s = np.column_stack([
        _affine_to_bounds(R2[:,0], xlim[0], xlim[1]),
        _affine_to_bounds(R2[:,1], ylim[0], ylim[1])
    ])
    S2s = np.column_stack([
        _affine_to_bounds(S2[:,0], xlim[0], xlim[1]),
        _affine_to_bounds(S2[:,1], ylim[0], ylim[1])
    ])

    fig = plt.figure(figsize=(6.2,5.0))
    plt.scatter(R2s[:,0], R2s[:,1], s=10, marker=".", alpha=0.9, label="Real")
    plt.scatter(S2s[:,0], S2s[:,1], s=10, marker=".", alpha=0.9, label="Synthetic (TabGraphSyn)")
    plt.legend(loc="lower left", frameon=True)
    plt.xlim(*xlim); plt.ylim(*ylim)
    ax = plt.gca()
    # Six evenly spaced x ticks across the requested x range.
    ax.set_xticks(np.linspace(xlim[0], xlim[1], 6))
    # Leave room below the axes for the panel caption.
    fig.subplots_adjust(bottom=0.18)
    fig.text(0.5, 0.06, panel_label, ha="center", va="center", fontsize=14)
    # NOTE(review): tight_layout() runs after subplots_adjust() and may
    # override the bottom margin set above — verify the caption is not clipped.
    plt.tight_layout()
    fname = f"{OUTDIR}/umap_{tag.lower()}_paperstyle.png"
    plt.savefig(fname, dpi=200); plt.close()
    return fname

# ---------------------------
# Column aliasing & UCI loader
# ---------------------------
# Column names removed from the raw table by the loader before any further
# processing.
DROP_ID_LIKE = {"pidnum"}
# Candidate source-column names for each canonical role, in priority order;
# the loader uses the first candidate found in the data.
ALIASES = {
    "time":  ["time","days","t","futime","survt"],
    "event": ["cid","cens","event","status","death","fail","delta","died"],
    "treat": ["trt","treat","arm","rx"],
}

def _clean_key(s: str) -> str:
    s = s.lower().strip()
    return "".join(ch for ch in s if ch.isalnum() or ch == "_")

def _pick_first_present(cols, candidates):
    for n in candidates:
        if n in cols: return n
    simp_map = {_clean_key(c): c for c in cols}
    for n in candidates:
        k = _clean_key(n)
        if k in simp_map: return simp_map[k]
    return None

def _to_numeric_df(df: pd.DataFrame) -> pd.DataFrame:
    out = df.copy()
    for c in out.columns:
        if pd.api.types.is_numeric_dtype(out[c]):
            continue
        try:
            out[c] = out[c].astype("category").cat.codes
        except Exception:
            out[c] = pd.to_numeric(out[c], errors="coerce")
    return out

def load_actg175_from_ucirepo(uci_id: int = 890):
    """Fetch a dataset through ``ucimlrepo`` and shape it for the pipeline.

    Features and targets are concatenated, ID-like and duplicated columns are
    dropped, everything is made numeric, and canonical ``time`` / ``event`` /
    ``treat`` columns are derived from the first matching alias.

    Args:
        uci_id: Dataset id passed to ``fetch_ucirepo``.

    Returns:
        An 8-tuple ``(df, X, T, E, feat_names, binaries, roles, tag)``:
        ``df`` is the cleaned table; ``X`` is ``df`` without ``event``;
        ``T`` / ``E`` are the time (float) and event (int) arrays;
        ``feat_names`` lists the columns of ``X``; ``binaries`` lists columns
        whose values are all in {0, 1} (``treat`` and ``event`` are always
        included); ``roles`` maps role names to column names; ``tag`` is the
        string ``"AIDS"``.

    Raises:
        KeyError: If no time, event or treatment column can be identified.
    """
    # Imported lazily so the module can be imported without ucimlrepo.
    from ucimlrepo import fetch_ucirepo
    ds = fetch_ucirepo(id=uci_id)

    # Informational printout only; any failure here is deliberately ignored.
    try:
        print("=== UCI metadata (short) ===")
        print({"name": ds.metadata.get("name")})
        print("=== Variables (head) ===")
        if hasattr(ds, "variables"):
            print(ds.variables.head())
    except Exception:
        pass

    Xdf = ds.data.features.copy()
    Ydf = ds.data.targets.copy() if hasattr(ds.data, "targets") and ds.data.targets is not None else pd.DataFrame()
    # Side-by-side concat; indices are reset so rows align by position.
    raw = pd.concat([Xdf.reset_index(drop=True), Ydf.reset_index(drop=True)], axis=1)

    raw = raw[[c for c in raw.columns if c not in DROP_ID_LIKE]]
    # Keep only the first occurrence of any duplicated column name.
    raw = raw.loc[:, ~raw.columns.duplicated()]
    raw = _to_numeric_df(raw)

    cols = list(raw.columns)
    time_src  = _pick_first_present(cols, ALIASES["time"])
    event_src = _pick_first_present(cols, ALIASES["event"])
    treat_src = _pick_first_present(cols, ALIASES["treat"])
    if time_src is None:  raise KeyError(f"No time column found in {cols} (tried {ALIASES['time']})")
    if event_src is None: raise KeyError(f"No event column found in {cols} (tried {ALIASES['event']})")
    if treat_src is None: raise KeyError(f"No treat column found in {cols} (tried {ALIASES['treat']})")

    df = raw.copy()
    # Materialise the canonical role columns from their source columns.
    df["time"]  = pd.to_numeric(df[time_src],  errors="coerce")
    df["event"] = pd.to_numeric(df[event_src], errors="coerce")
    df["treat"] = pd.to_numeric(df[treat_src], errors="coerce")

    # Drop every alias-named column except the three canonical names, so the
    # source columns do not remain as duplicate features.
    to_drop = set(ALIASES["time"] + ALIASES["event"] + ALIASES["treat"])
    to_drop.discard("time"); to_drop.discard("event"); to_drop.discard("treat")
    to_drop = [c for c in to_drop if c in df.columns]
    df = df.drop(columns=to_drop, errors="ignore")

    df = df.dropna(subset=["time","event"]).copy()
    # Non-positive times are replaced by the median time (or 1.0 if the
    # median itself is not positive).
    med_t = float(np.nanmedian(df["time"].values))
    df.loc[df["time"] <= 0, "time"] = med_t if med_t > 0 else 1.0
    # Binarise event and treatment: any value > 0 becomes 1. A missing
    # treatment value compares False and therefore becomes 0.
    df["event"] = (df["event"] > 0).astype(int)
    df["treat"] = (df["treat"] > 0).astype(int)

    # Median-impute remaining missing values in numeric columns.
    for c in df.columns:
        if pd.api.types.is_numeric_dtype(df[c]):
            df[c] = df[c].fillna(df[c].median())

    # Put known columns first in a fixed order; any others follow in their
    # existing order.
    preferred_order = [
        "age","wtkg","hemo","homo","drugs","karnof","oprior","z30","zprior","preanti",
        "race","gender","str2","strat","symptom","offtrt","cd40","cd420","cd80","cd820",
        "time","treat","event"
    ]
    ordered = [c for c in preferred_order if c in df.columns]
    remaining = [c for c in df.columns if c not in ordered]
    df = df[ordered + remaining]
    df = df.loc[:, ~df.columns.duplicated()]

    feat_cols = [c for c in df.columns if c != "event"]  # keep time & treat as features
    X = df[feat_cols].copy()
    T = df["time"].astype(float).values
    E = df["event"].astype(int).values
    feat_names = list(feat_cols)

    # A column counts as binary when its distinct (non-NaN) values are a
    # subset of {0, 1}.
    binaries = []
    for c in df.columns:
        arr = df[c].to_numpy()
        if arr.ndim > 1: arr = arr.ravel()
        uniq = np.unique(arr[~np.isnan(arr)]) if issubclass(arr.dtype.type, np.floating) else np.unique(arr)
        if set(uniq.tolist()).issubset({0,1}):
            binaries.append(c)
    if "treat" not in binaries: binaries.append("treat")
    if "event" not in binaries: binaries.append("event")

    return df, X, T, E, feat_names, binaries, {"time":"time","event":"event","treat":"treat"}, "AIDS"

# ---------------------------
# Survival helpers + calibration (shrinkage + holdout)
# ---------------------------
def build_km_by_treatment(df, time_col, event_col, treat_col):
    """Kaplan–Meier step curves per treatment group.

    Returns ``(curves, xmax)`` where ``curves`` maps each integer treatment
    value to ``(times, survival)`` arrays and ``xmax`` is the largest
    observed time in ``df``. Every curve starts at ``(0, 1)`` and is
    extended flat to ``xmax`` when it ends earlier.
    """
    xmax = float(np.nanmax(df[time_col].values))
    treat_as_int = df[treat_col].astype(int)
    curves = {}
    for group in sorted(df[treat_col].unique()):
        group_id = int(group)
        in_group = (treat_as_int == group_id).to_numpy()
        if in_group.sum() == 0:
            continue
        fit = SurvfuncRight(
            df.loc[in_group, time_col].to_numpy(),
            df.loc[in_group, event_col].to_numpy(),
        )
        # Prepend the (0, 1) starting point of the curve.
        times = np.r_[0.0, np.asarray(fit.surv_times, dtype=float)]
        surv = np.r_[1.0, np.asarray(fit.surv_prob, dtype=float)]
        within = times <= xmax
        times = times[within]
        surv = surv[within]
        # Carry the last survival value forward to the common right edge.
        if len(times) and times[-1] < xmax:
            times = np.r_[times, xmax]
            surv = np.r_[surv, surv[-1]]
        curves[group_id] = (times, surv)
    return curves, xmax

def _quantile_map_to_ref(source_vals, ref_vals):
    s = np.asarray(source_vals, float)
    r = np.asarray(ref_vals, float)
    if len(s) == 0:
        return s.copy()
    order = np.argsort(s)
    q = (np.arange(len(s)) + 0.5) / len(s)
    try:
        tgt_sorted = np.quantile(r, q, method="linear")
    except TypeError:
        tgt_sorted = np.quantile(r, q, interpolation="linear")
    out = np.empty_like(s)
    out[order] = tgt_sorted
    return out

def survival_calibrate_times_events_shrinkage(
    real_df, synth_df, time_col="time", event_col="event", treat_col="treat",
    lambda_rate=LAMBDA_RATE, alpha_time=ALPHA_TIME, holdout_frac=HOLDOUT_SURV, seed=SEED
):
    """Calibrate synthetic event indicators and times toward real data.

    Within each treatment group the synthetic event rate is shrunk toward the
    real rate (by flipping randomly chosen rows), and event / censoring times
    are partially quantile-mapped onto a bootstrap sample of the real times
    of the same kind. Only a stratified calibration subset of the real data
    is used; the rest is held out.

    Args:
        real_df: Real table containing ``time_col``, ``event_col``, ``treat_col``.
        synth_df: Synthetic table to calibrate; it is not modified.
        time_col, event_col, treat_col: Column names.
        lambda_rate: Blend weight for the event rate (0 keeps the synthetic
            rate, 1 matches the real calibration rate).
        alpha_time: Blend weight for times (0 keeps synthetic times, 1 uses
            the quantile-mapped times).
        holdout_frac: Fraction of each (treatment, event) stratum of the real
            data that is excluded from calibration.
        seed: Seed for the NumPy generators used for shuffling, bootstrapping
            and the tie-breaking jitter.

    Returns:
        A calibrated copy of ``synth_df`` with the caller's original index,
        binarised ``treat_col``, integer ``event_col`` and times clipped to
        ``[0, max real time]``.
    """
    rng = np.random.default_rng(seed)
    S = synth_df.copy()
    # The group logic below computes *positional* row numbers (np.where) and
    # uses them with label-based .loc. That is only correct on a 0..n-1
    # index, so work on a reset index and restore the caller's index at the end.
    caller_index = S.index
    S = S.reset_index(drop=True)
    S[treat_col] = (S[treat_col] > 0).astype(int)
    real_tmp = real_df.copy()
    real_tmp[treat_col] = (real_tmp[treat_col] > 0).astype(int)

    # Stratified holdout for calibration
    idx = np.arange(len(real_tmp))
    tr_idx, te_idx = [], []
    for g in sorted(real_tmp[treat_col].unique()):
        for e in [0,1]:
            mask = (real_tmp[treat_col]==g) & (real_tmp[event_col]==e)
            ids = idx[mask.to_numpy()]
            rng.shuffle(ids)
            k = int(round((1.0-holdout_frac)*len(ids)))
            tr_idx.extend(ids[:k]); te_idx.extend(ids[k:])
    calib = real_tmp.iloc[tr_idx].reset_index(drop=True)
    # Upper clip bound uses all real rows, not just the calibration subset.
    xmax = float(np.nanmax(real_tmp[time_col].values))

    for g in sorted(S[treat_col].unique()):
        g = int(g)
        m_s = (S[treat_col] == g).to_numpy()
        if m_s.sum() == 0:
            continue

        # Real calibration times of this group, split by event / censored.
        Rg_ev = calib[(calib[treat_col]==g) & (calib[event_col]==1)][time_col].values
        Rg_ce = calib[(calib[treat_col]==g) & (calib[event_col]==0)][time_col].values
        if len(Rg_ev)==0 and len(Rg_ce)==0:
            continue

        syn_rate = float(S.loc[m_s, event_col].mean()) if event_col in S.columns else 0.0
        real_rate = float(calib.loc[calib[treat_col]==g, event_col].mean()) if (calib[treat_col]==g).any() else syn_rate
        # Convex blend of synthetic and real event rates.
        target_rate = (1.0 - lambda_rate) * syn_rate + lambda_rate * real_rate
        n_s = int(m_s.sum())
        n_ev_target = int(round(target_rate * n_s))
        n_ev_target = max(0, min(n_ev_target, n_s))

        idx_grp = np.where(m_s)[0]
        ev_now  = idx_grp[S.loc[idx_grp, event_col].astype(int).to_numpy()==1] if event_col in S.columns else np.array([], dtype=int)
        ce_now  = idx_grp[S.loc[idx_grp, event_col].astype(int).to_numpy()==0] if event_col in S.columns else idx_grp

        # adjust counts toward target
        if len(ev_now) > n_ev_target:
            rng.shuffle(ev_now); flip = ev_now[:len(ev_now)-n_ev_target]
            S.loc[flip, event_col] = 0
            ce_now = np.concatenate([ce_now, flip])
            ev_now = np.setdiff1d(ev_now, flip, assume_unique=False)
        elif len(ev_now) < n_ev_target:
            need = n_ev_target - len(ev_now)
            rng.shuffle(ce_now)
            take = ce_now[:need]
            S.loc[take, event_col] = 1
            ev_now = np.concatenate([ev_now, take])
            ce_now = np.setdiff1d(ce_now, take, assume_unique=False)

        # partial quantile mapping
        if len(ev_now) and len(Rg_ev):
            boot_ev = rng.choice(Rg_ev, size=len(ev_now), replace=True)
            mapped = _quantile_map_to_ref(S.loc[ev_now, time_col].values, boot_ev)
            S.loc[ev_now, time_col] = (1.0 - alpha_time) * S.loc[ev_now, time_col].values + alpha_time * mapped
        if len(ce_now) and len(Rg_ce):
            boot_ce = rng.choice(Rg_ce, size=len(ce_now), replace=True)
            mapped = _quantile_map_to_ref(S.loc[ce_now, time_col].values, boot_ce)
            S.loc[ce_now, time_col] = (1.0 - alpha_time) * S.loc[ce_now, time_col].values + alpha_time * mapped

    S[time_col] = np.clip(S[time_col].astype(float).values, 0.0, xmax)
    S[event_col] = S[event_col].astype(int)
    # tiny jitter to break ties
    jitter = (np.nanpercentile(real_tmp[time_col], 75) - np.nanpercentile(real_tmp[time_col], 25) + 1e-9) * 1e-6
    S[time_col] = S[time_col].values + np.random.default_rng(seed).normal(0.0, jitter, size=len(S))
    S[time_col] = np.clip(S[time_col], 0.0, xmax)
    # Hand back the rows under the caller's original index labels.
    S.index = caller_index
    return S

def _surv_at(times, surv, x):
    if x <= 0: return 1.0
    idx = np.searchsorted(times, x, side="right") - 1
    idx = max(idx, 0)
    return float(surv[idx])

def km_plot_by_treatment(real_df, synth_df, time_col, event_col, treat_col, tag):
    """Plot real vs synthetic Kaplan–Meier curves per treatment and save a PNG.

    Real curves are drawn solid and synthetic curves dashed, in one colour
    per treatment group (0 and 1). Survival values at day 365 and day 1095
    are annotated when those days lie within the real follow-up range.

    Args:
        real_df, synth_df: Tables containing the three named columns.
        time_col, event_col, treat_col: Column names.
        tag: Lower-cased and inserted into the output file name.

    Returns:
        Path of the PNG written to ``OUTDIR``.
    """
    from matplotlib.lines import Line2D
    km_real, xmax = build_km_by_treatment(real_df, time_col, event_col, treat_col)
    # Synthetic times are capped at the real maximum so both sets of curves
    # share the same right edge.
    km_syn,  _    = build_km_by_treatment(
        synth_df.assign(**{time_col: np.minimum(synth_df[time_col].values, xmax)}),
        time_col, event_col, treat_col
    )
    colors = {0: "#1f77b4", 1: "#2ca02c"}  # blue, green

    plt.figure(figsize=(7.8, 5.0))
    for g in [0, 1]:
        # Only groups present in both real and synthetic data are drawn.
        if g not in km_real or g not in km_syn:
            continue
        rt, rp = km_real[g]; st, sp = km_syn[g]
        plt.step(rt, rp, where="post", color=colors[g], linewidth=2.0)
        plt.step(st, sp, where="post", color=colors[g], linewidth=2.0, linestyle="--")
        for d, label in [(365, "1 yr"), (1095, "3 yr")]:
            if d <= xmax:
                r_val = _surv_at(rt, rp, d)
                s_val = _surv_at(st, sp, d)
                # Offset the two groups' labels in opposite directions so
                # they do not overlap.
                voff = 0.03 if g == 1 else -0.06
                # In the label, "R" is the real value and "V" the synthetic one.
                plt.text(d + 10, r_val + voff, f"{label} (R {r_val:.2f}, V {s_val:.2f})",
                         color=colors[g], fontsize=9)

    # Vertical guide lines at the annotated days.
    for d in [365, 1095]:
        if d <= xmax: plt.axvline(d, linestyle=":", color="red", alpha=0.35)

    # NOTE(review): the y-range is hard-coded to [0.6, 1.0]; curves that drop
    # below 0.6 would be cut off — confirm this suits the data.
    plt.xlim(0, xmax * 1.02); plt.ylim(0.6, 1.0)
    plt.xlabel("Time (days)"); plt.ylabel("Survival Probability")
    plt.title("Kaplan–Meier: Real vs Synthetic by treatment")

    treatment_handles = [
        Line2D([0], [0], color=colors[0], lw=3, label="Treat 0"),
        Line2D([0], [0], color=colors[1], lw=3, label="Treat 1"),
    ]
    style_handles = [
        Line2D([0], [0], color="k", lw=2, linestyle="-",  label="Real"),
        Line2D([0], [0], color="k", lw=2, linestyle="--", label="Synthetic (TabGraphSyn)"),
    ]
    # A second plt.legend() call replaces the first legend, so the first one
    # is re-added explicitly as an artist to keep both visible.
    leg1 = plt.legend(handles=treatment_handles, title="Treatment", loc="lower left")
    plt.gca().add_artist(leg1)
    plt.legend(handles=style_handles, title="Data Source", loc="lower right")

    plt.tight_layout()
    fname = f"{OUTDIR}/survival_km_{tag.lower()}.png"
    plt.savefig(fname, dpi=220); plt.close()
    return fname

# ---------------------------
# Diffusion helpers
# ---------------------------
def cosine_beta_schedule(T, s=0.008):
    """Cosine noise schedule: return ``T`` betas clipped to ``[1e-5, 0.999]``.

    The cumulative alpha curve is ``cos^2`` of a shifted, rescaled time grid,
    normalised to 1 at step 0; each beta is one minus the ratio of
    consecutive cumulative alphas. The tensor lives on the module-level
    ``device``.
    """
    grid = torch.linspace(0, T, T + 1, device=device)
    phase = ((grid / T) + s) / (1 + s) * math.pi * 0.5
    abar = torch.cos(phase) ** 2
    abar = abar / abar[0]
    ratios = abar[1:] / abar[:-1]
    betas = 1 - ratios
    return torch.clip(betas, 1e-5, 0.999)

# ---------------------------
# VAE extra losses for marginals
# ---------------------------
def batch_mde_loss_1d(x_rec, x_true):
    """
    Batch-quantile L1 surrogate for the marginal distribution error.

    Each feature column of both ``[B, D]`` inputs is sorted independently
    along the batch axis; the loss is the mean absolute difference between
    the two sorted tensors.
    """
    rec_sorted = torch.sort(x_rec, dim=0).values
    true_sorted = torch.sort(x_true, dim=0).values
    return F.l1_loss(rec_sorted, true_sorted, reduction='mean')

def batch_moment_loss(x_rec, x_true, eps=1e-8):
    """
    Mean/std alignment across batch, averaged over features.

    Args:
        x_rec: Reconstructed batch, shape ``[B, D]``.
        x_true: Target batch, shape ``[B, D]``.
        eps: Small constant added to both standard deviations.

    Returns:
        Scalar tensor: MSE between the per-feature means plus MSE between
        the per-feature standard deviations. For a batch with fewer than two
        rows the sample standard deviation is undefined (NaN), so only the
        mean term is returned in that case.
    """
    mean_term = F.mse_loss(x_rec.mean(dim=0), x_true.mean(dim=0))
    # torch.std uses the unbiased estimator by default, which is NaN for a
    # single sample; skipping the std term keeps the total loss finite
    # (e.g. for a trailing mini-batch of size 1).
    if x_rec.shape[0] < 2 or x_true.shape[0] < 2:
        return mean_term
    s1 = x_rec.std(dim=0) + eps
    s2 = x_true.std(dim=0) + eps
    return mean_term + F.mse_loss(s1, s2)

# ---------------------------
# Marginal calibration (post-decode)
# ---------------------------
def marginal_calibrate_blend(real_df, synth_df, binary_cols, time_col, event_col, treat_col,
                             alpha_marg=ALPHA_MARG, holdout_frac=HOLDOUT_MARG, bin_shrink=BIN_SHRINK, seed=SEED):
    """
    Holdout-aware, feature-wise calibration:
      - For non-binary (excluding time/event): partial rank-preserving quantile map.
      - For binary (excluding event): top-k rounding with shrinkage to real prevalence.

    Args:
        real_df: Real table; only a calibration subset of it is used.
        synth_df: Synthetic table to calibrate; it is not modified.
        binary_cols: Names of columns to treat as binary.
        time_col, event_col, treat_col: Role column names. ``time_col`` and
            ``event_col`` are left untouched; ``treat_col`` is only snapped
            to {0, 1} at its current prevalence.
        alpha_marg: Blend weight for continuous columns (0 = unchanged,
            1 = full quantile map onto the calibration subset).
        holdout_frac: Fraction of real rows excluded from calibration.
        bin_shrink: Blend weight for binary prevalence (0 = keep synthetic
            prevalence, 1 = match the calibration prevalence).
        seed: Random state of the calibration/holdout split.

    Returns:
        A calibrated copy of ``synth_df`` carrying the caller's original index.
    """
    real = real_df.copy()
    # Split calibration / holdout (stratify by treat to avoid drift).
    # The split honours the `seed` argument (it defaults to the global SEED).
    idx = np.arange(len(real))
    tr_idx, _ = train_test_split(idx, test_size=holdout_frac, random_state=seed,
                                 stratify=real[treat_col] if treat_col in real.columns else None)
    calib = real.iloc[tr_idx].reset_index(drop=True)

    S = synth_df.copy()
    # `order` below holds *positional* row numbers from np.argsort that are
    # used with label-based .loc; that is only correct on a 0..n-1 index, so
    # work on a reset index and restore the caller's index before returning.
    caller_index = S.index
    S = S.reset_index(drop=True)
    n_rows = len(S)

    # ----- Binary columns -----
    bin_cols = [c for c in binary_cols if c in S.columns and c not in {event_col}]
    for c in bin_cols:
        # Don't remap 'treat' here; its distribution is enforced by generation
        if c == treat_col:
            # still snap to {0,1} using rank-based top-k at current prevalence
            vals = S[c].to_numpy()
            order = np.argsort(vals)
            k = int(round(np.clip(np.mean(S[c]>0.5)*n_rows, 0, n_rows)))
            S.loc[order[:n_rows-k], c] = 0
            S.loc[order[n_rows-k:], c] = 1
            continue
        p_real = float(np.mean(calib[c].astype(int)))
        p_curr = float(np.mean((S[c].values > 0.5).astype(int)))
        # Convex blend of the current and the real prevalence.
        p_tgt  = (1.0 - bin_shrink) * p_curr + bin_shrink * p_real
        k = int(round(np.clip(p_tgt * n_rows, 0, n_rows)))
        vals = S[c].to_numpy()
        order = np.argsort(vals)  # largest become 1
        S.loc[order[:n_rows-k], c] = 0
        S.loc[order[n_rows-k:], c] = 1
        S[c] = S[c].astype(int)

    # ----- Non-binary columns (exclude time & event) -----
    skip_cols = set(bin_cols) | {time_col, event_col}
    nb_cols = [c for c in S.columns if c not in skip_cols]
    for c in nb_cols:
        r = calib[c].values
        s = S[c].values
        mapped = _quantile_map_to_ref(s, r)
        # Partial move from the synthetic value toward its mapped quantile.
        S[c] = (1.0 - alpha_marg) * s + alpha_marg * mapped

    # Hand back the rows under the caller's original index labels.
    S.index = caller_index
    return S

# ---------------------------
# Main pipeline
# ---------------------------
def run_pipeline_from_df(full_df, time_col="time", event_col="event", treat_col="treat",
                         tag="AIDS", binary_cols=None):
    """Run the TabGraphSyn pipeline on one table and write the evaluation artifacts.

    Stages, in order: standardize features -> symmetric kNN graph -> GCN trained
    on the conditioning label -> VAE (KL warm-up + marginal/moment losses) ->
    classifier-free-guided DDPM in the VAE latent space -> decode -> CORAL
    blend + jitter -> marginal calibration -> survival calibration -> fidelity,
    UMAP, correlation, survival and detection reports.

    Args:
        full_df: DataFrame with the feature columns plus ``event_col``. Every
            column except ``event_col`` (so ``time_col`` and ``treat_col`` too)
            is cast to float32 and modelled as a feature.
        time_col: Name of the follow-up time column.
        event_col: Name of the event-indicator column. It is excluded from the
            generative model; a zero placeholder is added to the synthetic
            frame before the calibration helpers run.
        treat_col: Name of the treatment column, used as the conditioning label
            and for stratified splits.
        tag: Label forwarded to the plotting helpers and used in the printed
            Kaplan-Meier path.
        binary_cols: Column names forwarded to ``marginal_calibrate_blend`` as
            binary columns; ``None`` is treated as an empty list.

    Returns:
        None. Results are written under ``OUTDIR`` (CSV + PNG files) and a
        metric summary is printed.
    """
    # 1) Prepare matrices
    feat_cols = [c for c in full_df.columns if c != event_col]
    X = full_df[feat_cols].copy()
    scaler = StandardScaler()
    Xs = scaler.fit_transform(X.values.astype(np.float32)).astype(np.float32)
    feat_names = list(feat_cols)

    # Conditional label = treatment (fallback to KMeans if constant)
    # NOTE(review): the `else` branch below is unreachable as written, because
    # `full_df[treat_col]` raises KeyError first when the column is missing.
    # Left unchanged: later stages index `treat_col` directly anyway.
    tcol = full_df[treat_col]
    if isinstance(tcol, pd.DataFrame): tcol = tcol.iloc[:,0]  # duplicated column name -> keep first
    treat = tcol.values.astype(int) if treat_col in full_df.columns else np.zeros(len(full_df), dtype=int)

    X_train, X_val, treat_train, treat_val = train_test_split(
        Xs, treat, test_size=0.2,
        stratify=treat if len(np.unique(treat))>1 else None,
        random_state=SEED
    )

    X_all_t   = torch.tensor(Xs, device=device)
    n_nodes, n_features = Xs.shape

    # 2) kNN graph (symmetrized, self-loops added, D^-1/2 (A+I) D^-1/2 normalization)
    nbrs = NearestNeighbors(n_neighbors=KNN_K+1).fit(Xs)
    _, idxs = nbrs.kneighbors(Xs)
    A = np.zeros((n_nodes, n_nodes), dtype=np.float32)
    for i in range(n_nodes):
        # first returned neighbour is skipped (the query point itself for distinct rows)
        for j in idxs[i][1:]:
            A[i, j] = 1.0; A[j, i] = 1.0
    I = np.eye(n_nodes, dtype=np.float32)
    A_hat = A + I
    D = np.sum(A_hat, axis=1)
    D_inv_sqrt = 1.0 / np.sqrt(D + 1e-8)
    A_norm = D_inv_sqrt[:, None] * A_hat * D_inv_sqrt[None, :]
    A_norm_t = torch.tensor(A_norm, device=device)

    # 3) GCN conditioning
    y = treat.copy()
    if len(np.unique(y)) < 2:
        # A constant label cannot be stratified or classified: use 2 KMeans clusters instead.
        from sklearn.cluster import KMeans
        y = KMeans(n_clusters=2, random_state=SEED).fit_predict(Xs)
    y_all_t = torch.tensor(y, device=device)

    idx_all = np.arange(n_nodes)
    train_idx, val_idx = train_test_split(idx_all, test_size=0.2, stratify=y, random_state=SEED)
    train_mask = np.zeros(n_nodes, dtype=bool); train_mask[train_idx] = True
    val_mask   = np.zeros(n_nodes, dtype=bool); val_mask[val_idx]   = True
    train_mask_t = torch.tensor(train_mask, device=device)
    val_mask_t   = torch.tensor(val_mask,   device=device)

    gcn = GCN(n_features).to(device)
    opt_gcn = torch.optim.AdamW(gcn.parameters(), lr=GCN_LR, weight_decay=GCN_WD)
    for ep in range(GCN_EPOCHS):
        gcn.train(); opt_gcn.zero_grad()
        logits, _ = gcn(A_norm_t, X_all_t)
        # Full-graph forward pass; the loss only sees training nodes.
        loss = F.cross_entropy(logits[train_mask_t], y_all_t[train_mask_t])
        loss.backward(); torch.nn.utils.clip_grad_norm_(gcn.parameters(), 1.0)
        opt_gcn.step()
    with torch.no_grad():
        # Second GCN output is used as the per-node conditioning vector.
        gcn.eval(); _, cond_all = gcn(A_norm_t, X_all_t)
    cond_dim = cond_all.shape[1]

    # class prototypes: mean conditioning vector per label, in np.unique(y) order
    cond_class = []
    classes = np.unique(y)
    for c in classes:
        mask = (y_all_t == int(c))
        cond_class.append(cond_all[mask].mean(dim=0, keepdim=True))
    cond_class = torch.cat(cond_class, dim=0)

    # 4) VAE with KL warm-up + MDE/Moment losses
    def vae_step_losses(rec, xb, mu, lv, beta):
        """Return (total loss tensor, recon, KL, MDE, moment) for one mini-batch."""
        rl = F.mse_loss(rec, xb, reduction='mean')
        kld = -0.5 * torch.mean(1 + lv - mu.pow(2) - lv.exp())
        mde = batch_mde_loss_1d(rec, xb)
        mom = batch_moment_loss(rec, xb)
        return rl + beta*kld + MDE_W*mde + MOMENT_W*mom, rl.item(), kld.item(), mde.item(), mom.item()

    vae = VAE(in_dim=n_features, z_dim=VAE_Z, hidden=VAE_H).to(device)
    opt_vae = torch.optim.AdamW(vae.parameters(), lr=VAE_LR, weight_decay=VAE_WD)
    X_train_t = torch.tensor(X_train, device=device)
    bs = VAE_BS
    for ep in range(VAE_EPOCHS):
        vae.train()
        beta = BETA_MAX * min(1.0, ep / max(1, int(0.4*VAE_EPOCHS)))  # warm-up over 40% epochs
        idx = torch.randperm(X_train_t.shape[0], device=device)
        for i in range(0, len(idx), bs):
            xb = X_train_t[idx[i:i+bs]]
            opt_vae.zero_grad(); rec, mu, lv, z, _, _ = vae(xb, beta=beta)
            loss, rl_v, kld_v, mde_v, mom_v = vae_step_losses(rec, xb, mu, lv, beta)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
            opt_vae.step()
    # Freeze the VAE: the diffusion model is trained on fixed latents.
    for p in vae.parameters(): p.requires_grad = False
    vae.eval()
    with torch.no_grad():
        # Posterior means (no sampling) serve as the latent codes.
        mu_all, lv_all = vae.encode(torch.tensor(Xs, device=device)); z_all = mu_all
    z_train = z_all[train_mask_t]

    # 5) DDPM
    betas = cosine_beta_schedule(TSTEPS)
    alphas = 1.0 - betas
    acp = torch.cumprod(alphas, 0)
    sqrt_acp    = torch.sqrt(acp)
    sqrt_1m_acp = torch.sqrt(1.0 - acp)
    def q_sample(x0, t, noise):
        """Forward diffusion: noise `x0` to step `t` in closed form."""
        return sqrt_acp[t].unsqueeze(1)*x0 + sqrt_1m_acp[t].unsqueeze(1)*noise

    eps_net = EpsNet(z_dim=z_all.shape[1], c_dim=cond_dim).to(device)
    opt_eps = torch.optim.AdamW(eps_net.parameters(), lr=EPS_LR, weight_decay=EPS_WD)
    ema_net = EpsNet(z_dim=z_all.shape[1], c_dim=cond_dim).to(device)
    ema_net.load_state_dict(eps_net.state_dict())
    def update_ema(student, teacher, decay=0.9995):
        """Move `teacher` parameters toward `student` with exponential decay."""
        for p_s, p_t in zip(student.parameters(), teacher.parameters()):
            p_t.data.mul_(decay).add_(p_s.data, alpha=1.0 - decay)

    # Conditioning vectors of the training nodes, aligned row-for-row with z_train.
    # Materialized once instead of re-indexing cond_all for every mini-batch.
    cond_train = cond_all[train_mask_t]

    for ep in range(DDPM_EPOCHS):
        eps_net.train(); perm = torch.randperm(z_train.shape[0], device=device)
        for i in range(0, len(perm), 128):
            idx = perm[i:i+128]
            x0 = z_train[idx]
            c_clean = cond_train[idx]
            c  = c_clean + COND_NOISE * torch.randn_like(c_clean)
            t  = torch.randint(low=0, high=TSTEPS, size=(x0.shape[0],), device=device, dtype=torch.long)
            noise = torch.randn_like(x0)
            x_t = q_sample(x0, t, noise)
            # Classifier-free guidance training: zero the condition for a random subset of rows.
            drop = (torch.rand(x0.size(0), device=device) < DROP_UNCOND).float().unsqueeze(1)
            c_step = c * (1.0 - drop)
            opt_eps.zero_grad(); pred = eps_net(x_t, c_step, t, TSTEPS)
            loss = F.mse_loss(pred, noise); loss.backward()
            torch.nn.utils.clip_grad_norm_(eps_net.parameters(), 1.0)
            opt_eps.step(); update_ema(eps_net, ema_net)

    @torch.no_grad()
    def p_mean_var_cfg(x_t, t, c, w=CFG_W):
        """Reverse-step mean/variance using the EMA network with guidance weight `w`."""
        beta_t = betas[t].unsqueeze(1)
        sqrt_recip_alpha = (1.0/torch.sqrt(alphas[t])).unsqueeze(1)
        sqrt_1m = sqrt_1m_acp[t].unsqueeze(1)
        net = ema_net
        eps_c = net(x_t, c, t, TSTEPS)
        eps_u = net(x_t, torch.zeros_like(c), t, TSTEPS)
        # Extrapolate from the unconditional toward the conditional prediction.
        eps_hat = (1.0 + w) * eps_c - w * eps_u
        mean = sqrt_recip_alpha * (x_t - (beta_t / sqrt_1m) * eps_hat)
        var  = betas[t].unsqueeze(1)
        return mean, var

    @torch.no_grad()
    def sample_latents_cond(n_per_class, w=CFG_W, tau=TAU_NOISE):
        """Sample `n_per_class[k]` latents per class prototype k; return (latents, class indices)."""
        xs=[]; ys=[]
        for cls, n_ in enumerate(n_per_class):
            if n_ <= 0: continue
            x = torch.randn(n_, z_all.shape[1], device=device)
            c = cond_class[cls].expand(n_, -1)
            for t in reversed(range(TSTEPS)):
                t_b = torch.full((n_,), t, device=device, dtype=torch.long)
                mean, var = p_mean_var_cfg(x, t_b, c, w=w)
                # `tau` scales the injected noise; the final step is deterministic.
                if t>0: x = mean + torch.sqrt(var) * tau * torch.randn_like(x)
                else:   x = mean
            xs.append(x); ys.append(np.full(n_, cls, dtype=int))
        return (torch.cat(xs, 0) if xs else torch.empty(0, z_all.shape[1], device=device),
                np.concatenate(ys) if ys else np.array([], dtype=int))

    # Generate ~3x synthetic, class-matched by treatment proportions
    unique, counts = np.unique(y, return_counts=True)
    prop = counts / counts.sum()
    n_samples = len(y) * 3
    n_per_class = (prop * n_samples).astype(int)
    # Truncation can leave a shortfall; give the remainder to the largest class.
    while n_per_class.sum() < n_samples:
        n_per_class[np.argmax(prop)] += 1

    z_synth, y_synth = sample_latents_cond(n_per_class)
    with torch.no_grad():
        x_synth_std0 = vae.decode(z_synth).cpu().numpy().astype(np.float32)

    # CORAL (standardized space) + blend + jitter
    cov_reg = CORAL_REG
    def sym_eig(mat, eps=1e-8):
        """Eigendecompose the symmetrized matrix, flooring eigenvalues at `eps`."""
        w, v = np.linalg.eigh((mat + mat.T) * 0.5)
        w = np.clip(w, eps, None)
        return w, v
    mu_r = Xs.mean(0); Cr = np.cov(Xs, rowvar=False) + cov_reg*np.eye(n_features)
    mu_s = x_synth_std0.mean(0); Cs = np.cov(x_synth_std0, rowvar=False) + cov_reg*np.eye(n_features)
    ws, Vs = sym_eig(Cs); wt, Vt = sym_eig(Cr)
    Cs_inv_sqrt = Vs @ np.diag(1.0/np.sqrt(ws)) @ Vs.T
    Cr_sqrt     = Vt @ np.diag(np.sqrt(wt))     @ Vt.T
    # Whiten with the synthetic covariance, re-color with the real one, shift to the real mean.
    Xs0 = x_synth_std0 - mu_s
    Xs_whiten  = Xs0 @ Cs_inv_sqrt
    Xs_coral   = Xs_whiten @ Cr_sqrt + mu_r
    Xs_blended = (1.0 - CORAL_BLEND) * x_synth_std0 + CORAL_BLEND * Xs_coral
    Xs_blended += JITTER_SD * np.random.randn(*Xs_blended.shape).astype(Xs_blended.dtype)
    synth_std_stage = Xs_blended.astype(np.float32)

    # Inverse-scale → original feature space
    synth_inv = scaler.inverse_transform(synth_std_stage)
    synth_df = pd.DataFrame(synth_inv, columns=feat_names)

    # Ensure 'treat' follows generated class, snap to {0,1}; placeholder for event
    if treat_col in synth_df.columns:
        synth_df[treat_col] = (y_synth[:len(synth_df)] if len(y_synth)>=len(synth_df)
                               else np.rint(synth_df[treat_col]).astype(int)).astype(int)
    # The placeholder must be named `event_col`: the calibration and Cox steps
    # address the event column by that name, and a differently named column
    # would be treated as an ordinary feature.
    if event_col not in synth_df.columns:
        synth_df[event_col] = 0

    # ----- NEW: Marginal calibration (binary + non-binary; excludes time/event) -----
    real_for_marg = full_df.copy()
    synth_df = marginal_calibrate_blend(
        real_for_marg, synth_df, binary_cols=(binary_cols or []),
        time_col=time_col, event_col=event_col, treat_col=treat_col,
        alpha_marg=ALPHA_MARG, holdout_frac=HOLDOUT_MARG, bin_shrink=BIN_SHRINK, seed=SEED
    )

    # --- Survival calibration (times + event) ---
    real_surv_df = full_df[[time_col, event_col, treat_col]].copy()
    synth_df = survival_calibrate_times_events_shrinkage(
        real_surv_df, synth_df,
        time_col=time_col, event_col=event_col, treat_col=treat_col,
        lambda_rate=LAMBDA_RATE, alpha_time=ALPHA_TIME, holdout_frac=HOLDOUT_SURV, seed=SEED
    )

    # Ensure positive times: non-positive values are replaced by max(1, median time)
    synth_df.loc[synth_df[time_col] <= 0, time_col] = max(1.0, float(np.median(synth_df[time_col])))

    # Save calibrated synthetic (original scale)
    csv_out = f"{OUTDIR}/synthetic_aids_with_survival.csv"
    synth_df.to_csv(csv_out, index=False)

    # ----- IMPORTANT: Recompute standardized synthetic AFTER all calibrations -----
    synth_std = scaler.transform(synth_df[feat_cols].values.astype(np.float32)).astype(np.float32)

    # 6.1) Fidelity metrics
    MDE_per_var_df, MDE_mean_pct = marginal_KS_percent(Xs, synth_std, feat_names)
    P_corr_df, P_corr_mean_pct   = pairwise_corr_error_percent(Xs, synth_std, feat_names)
    MDE_per_var_df.to_csv(f"{OUTDIR}/marginal_errors_per_variable_aids.csv", index=False)
    P_corr_df.to_csv(f"{OUTDIR}/pairwise_corr_errors_aids.csv", index=False)
    table2 = pd.DataFrame([["TabGraphSyn (marginal-aware)", round(MDE_mean_pct, 2), round(P_corr_mean_pct, 2)]],
                          columns=["Method", "Marginal Distribution Errors (%)", "Pairwise Correlation Errors (%)"])
    save_table_as_image(table2, "TABLE II: Statistical Fidelity — AIDS", f"{OUTDIR}/table2_aids.png")

    # 6.2) UMAP
    umap_path = save_umap_paperstyle(Xs, synth_std, tag=tag, panel_label="(b) AIDS", xlim=(-10,15), ylim=(-5,15))

    # 6.3) Correlation heatmaps
    corr_path = save_corr_heatmaps_with_labels(Xs, synth_std, feat_names, tag=tag)

    # 6.4) Survival analysis & Cox confusion (statsmodels)
    # Real frame = original-scale features with the event column re-attached.
    real_df = pd.concat([X.reset_index(drop=True), full_df[[event_col]].reset_index(drop=True)], axis=1)
    km_path = km_plot_by_treatment(real_df, synth_df, time_col=time_col, event_col=event_col, treat_col=treat_col, tag=tag)
    cm, sig_real, sig_syn = cox_significance_confusion(real_df, synth_df, time_col=time_col, event_col=event_col)
    save_table_as_image(cm, "TABLE V: CoxPH significant covariates — Confusion Matrix (AIDS)",
                        f"{OUTDIR}/table5_aids_cox_confusion.png")

    # 6.5) Detection score
    det_log_acc, det_log_auc = detection_score_logistic(Xs, synth_std)
    det_best_acc, det_best_auc, det_best_name = detection_score_bestof(Xs, synth_std, seed=SEED)
    table6a = pd.DataFrame([["TabGraphSyn (LogReg)", f"{det_log_acc:.2f}%", f"{det_log_auc:.3f}"]], columns=["Model", "Detection Acc (%)", "AUC"])
    save_table_as_image(table6a, "TABLE VI-A: Detection (Logistic) — AIDS", f"{OUTDIR}/table6_aids_detection_logreg.png")
    table6b = pd.DataFrame([[f"Best-of ({det_best_name})", f"{det_best_acc:.2f}%", f"{det_best_auc:.3f}"]], columns=["Model", "Detection Acc (%)", "AUC"])
    save_table_as_image(table6b, "TABLE VI-B: Detection (Best-of) — AIDS", f"{OUTDIR}/table6_aids_detection_bestof.png")

    # Per-variable marginal error table image
    save_table_as_image(MDE_per_var_df.round({"KS%":2}),
                        "Marginal Errors per Variable — AIDS",
                        f"{OUTDIR}/marginal_errors_per_variable_aids.png")

    # Console summary
    print("\n=== AIDS (ACTG175) — FINAL METRICS (marginal-aware) ===")
    print(pd.Series({
        "Marginal KS mean %": MDE_mean_pct,
        "Pairwise |Δρ| mean %": P_corr_mean_pct,
        "Detection Acc (%) — Logistic": det_log_acc,
        "Detection AUC — Logistic": det_log_auc,
        "Detection Acc (%) — Best-of": det_best_acc,
        "Detection AUC — Best-of": det_best_auc,
        "Best-of attacker": det_best_name,
    }))
    print("Saved:")
    print(" • Synthetic CSV:", csv_out)
    print(" • UMAP:", umap_path)
    print(" • Heatmaps:", corr_path)
    print(" • Table II:", f"{OUTDIR}/table2_aids.png")
    print(" • Marginal per-variable (CSV):", f"{OUTDIR}/marginal_errors_per_variable_aids.csv")
    print(" • Pairwise errors (CSV):", f"{OUTDIR}/pairwise_corr_errors_aids.csv")
    print(" • Survival KM:", f"{OUTDIR}/survival_km_{tag.lower()}.png")
    print(" • Cox confusion:", f"{OUTDIR}/table5_aids_cox_confusion.png")
    print(" • Detection Logistic:", f"{OUTDIR}/table6_aids_detection_logreg.png")
    print(" • Detection Best-of:", f"{OUTDIR}/table6_aids_detection_bestof.png")

def cox_significance_confusion(real_df, synth_df, time_col, event_col):
    """Compare which covariates are Cox-significant (p < 0.05) in real vs. synthetic data.

    A proportional-hazards model is fitted separately on each frame using every
    column other than ``time_col`` / ``event_col`` as a covariate. The real fit
    is treated as ground truth and the synthetic fit as the prediction.

    Args:
        real_df: Real data containing ``time_col``, ``event_col`` and covariates.
        synth_df: Synthetic data with the same role columns.
        time_col: Name of the duration column.
        event_col: Name of the event/status column passed as ``status``.

    Returns:
        Tuple ``(cm, sig_real, sig_syn)``: ``cm`` is a two-column DataFrame
        ("Metric", "TabGraphSyn") with TP/FP/FN/TN counts and precision,
        recall and F1 formatted to three decimals; ``sig_real`` / ``sig_syn``
        are the sets of significant covariate names for each frame.
    """
    def _fit(df):
        # No intercept column is added: in a Cox model a constant cancels out
        # of the partial likelihood, so it is unidentified and only makes the
        # Hessian (near-)singular. It would also be counted as a spurious
        # "feature" in the true-negative total below.
        exog = df.drop(columns=[time_col, event_col])
        res = PHReg(df[time_col].values, exog.values, status=df[event_col].values).fit()
        names = list(exog.columns)
        pvals = np.asarray(res.pvalues).ravel()
        summ = pd.DataFrame({"feature": names, "p": pvals})
        # NaN p-values compare False and are therefore treated as not significant.
        sig = set(summ.loc[summ['p'] < 0.05, 'feature'].tolist())
        return sig, summ

    sig_real, summ_real = _fit(real_df)
    sig_syn,  summ_syn  = _fit(synth_df)
    # Universe of covariates seen in either fit.
    features = set(summ_real['feature']) | set(summ_syn['feature'])
    tp = len(sig_real & sig_syn)
    fp = len(sig_syn - sig_real)
    fn = len(sig_real - sig_syn)
    tn = len(features - (sig_real | sig_syn))
    # Small epsilon keeps the ratios defined when a denominator is zero.
    prec = tp / (tp + fp + 1e-12)
    rec  = tp / (tp + fn + 1e-12)
    f1   = 2*prec*rec / (prec + rec + 1e-12)
    cm = pd.DataFrame({
        "Metric":["True Pos.","False Pos.","False Neg.","True Neg.","Precision","Recall","F1 Score"],
        "TabGraphSyn":[tp, fp, fn, tn, f"{prec:.3f}", f"{rec:.3f}", f"{f1:.3f}"]
    })
    return cm, sig_real, sig_syn

def run_pipeline_from_ucirepo(uci_id: int = 890):
    """Load a dataset through ``load_actg175_from_ucirepo`` and run the pipeline on it.

    Args:
        uci_id: UCI repository dataset id forwarded to the loader.

    Returns:
        Whatever ``run_pipeline_from_df`` returns for the loaded frame.
    """
    loaded = load_actg175_from_ucirepo(uci_id=uci_id)
    # Only the full frame, the binary-column list, the role-name mapping and
    # the tag are needed here; the other loader outputs are ignored.
    frame, _X, _T, _E, _names, binary_columns, roles, dataset_tag = loaded
    pipeline_kwargs = {
        "time_col": roles["time"],
        "event_col": roles["event"],
        "treat_col": roles["treat"],
        "tag": dataset_tag,
        "binary_cols": binary_columns,
    }
    return run_pipeline_from_df(frame, **pipeline_kwargs)

# ---------------------------
# CLI
# ---------------------------
if __name__ == "__main__":
    # Script entry point: read the dataset id from the command line and run.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--uci_id",
        default=890,
        type=int,
        help="UCI dataset id (default: 890 for ACTG175).",
    )
    # parse_known_args (rather than parse_args) ignores unrecognized arguments
    # instead of exiting on them.
    args, _unknown = parser.parse_known_args()
    run_pipeline_from_ucirepo(uci_id=args.uci_id)


=== UCI metadata (short) ===
{'name': 'AIDS Clinical Trials Group Study 175'}
=== Variables (head) ===
     name     role     type demographic  \
0  pidnum       ID  Integer        None   
1     cid   Target   Binary        None   
2    time  Feature  Integer        None   
3     trt  Feature  Integer        None   
4     age  Feature  Integer         Age   

                                         description units missing_values  
0                                         Patient ID  None             no  
1   censoring indicator (1 = failure, 0 = censoring)  None             no  
2                       time to failure or censoring  None             no  
3  treatment indicator (0 = ZDV only; 1 = ZDV + d...  None             no  
4                              age (yrs) at baseline  None             no  


  c /= stddev[:, None]
  c /= stddev[None, :]
  c /= stddev[:, None]
  c /= stddev[None, :]



=== AIDS (ACTG175) — FINAL METRICS (marginal-aware) ===
Marginal KS mean %               5.163132
Pairwise |Δρ| mean %             3.427685
Detection Acc (%) — Logistic    71.495327
Detection AUC — Logistic         0.721616
Detection Acc (%) — Best-of     99.579166
Detection AUC — Best-of          0.999927
Best-of attacker                   RF-400
dtype: object
Saved:
 • Synthetic CSV: /data/user/home/rkhan5/AIDS/outputs_aids/synthetic_aids_with_survival.csv
 • UMAP: /data/user/home/rkhan5/AIDS/outputs_aids/umap_aids_paperstyle.png
 • Heatmaps: /data/user/home/rkhan5/AIDS/outputs_aids/corr_heatmaps_aids_labeled.png
 • Table II: /data/user/home/rkhan5/AIDS/outputs_aids/table2_aids.png
 • Marginal per-variable (CSV): /data/user/home/rkhan5/AIDS/outputs_aids/marginal_errors_per_variable_aids.csv
 • Pairwise errors (CSV): /data/user/home/rkhan5/AIDS/outputs_aids/pairwise_corr_errors_aids.csv
 • Survival KM: /data/user/home/rkhan5/AIDS/outputs_aids/survival_km_aids.png
 • Cox confusion: /d