In [1]:
# @title

!pip install numpy pandas torch scikit-learn matplotlib ucimlrepo networkx umap-learn


Defaulting to user installation because normal site-packages is not writeable


In [1]:
# ================================================================
# TabGraphSyn (GCN → VAE → latent DDPM) — WBCD only (optimized)
# Outputs:
#   1) UMAP (UMAP-only; fixed axes & caption like paper)
#   2) Marginal KS% & Pairwise |Δρ|% (Table II numbers)
#   3) Correlation heatmaps (Real / Synthetic / Diff) with axis labels
#   4) TSTR (train on synthetic, test on real): Accuracy, F1, AUC
#   5) Detection Score (accuracy ↑): Logistic and Best-of (RF/SVM/LogReg)
# ================================================================

import os, math, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score
import matplotlib.pyplot as plt
from scipy.stats import ks_2samp
from sklearn.datasets import load_breast_cancer
from umap import UMAP  # pip install umap-learn

# ---------------------------
# Reproducibility & device
# ---------------------------
SEED = 123
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ---------------------------
# Output dir
# ---------------------------
OUTDIR = os.path.abspath("outputs")
os.makedirs(OUTDIR, exist_ok=True)

# ---------------------------
# Models
# ---------------------------
class GCN(nn.Module):
    def __init__(self, in_dim, hidden=64, emb=32, classes=2, dropout=0.15):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden)
        self.fc2 = nn.Linear(hidden, emb)
        self.cls = nn.Linear(emb, classes)
        self.do = nn.Dropout(dropout)
    def layer(self, A, X, W, act=True):
        H = A @ X
        H = self.do(W(H))
        return F.relu(H) if act else H
    def forward(self, A, X):
        h1 = self.layer(A, X, self.fc1, True)
        h2 = self.layer(A, h1, self.fc2, True)
        logits = self.cls(h2)
        return logits, h2

class VAE(nn.Module):
    def __init__(self, in_dim, z_dim=32, hidden=256):
        super().__init__()
        self.e1 = nn.Linear(in_dim, hidden)
        self.e2 = nn.Linear(hidden, hidden)
        self.mu = nn.Linear(hidden, z_dim)
        self.lv = nn.Linear(hidden, z_dim)
        self.d1 = nn.Linear(z_dim, hidden)
        self.d2 = nn.Linear(hidden, hidden)
        self.out = nn.Linear(hidden, in_dim)
    def encode(self, x):
        h = F.relu(self.e1(x)); h = F.relu(self.e2(h))
        return self.mu(h), self.lv(h)
    def reparam(self, mu, logv):
        std = torch.exp(0.5*logv)
        return mu + std * torch.randn_like(std)
    def decode(self, z):
        h = F.relu(self.d1(z)); h = F.relu(self.d2(h))
        return self.out(h)
    def forward(self, x):
        mu, lv = self.encode(x)
        z = self.reparam(mu, lv)
        return self.decode(z), mu, lv, z

def t_embedding(t, dim=64, T=400):
    half = dim // 2
    freqs = torch.exp(torch.linspace(math.log(1.0), math.log(10000.0), steps=half, device=device))
    args = (t[:, None].float() / T) * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

class EpsNet(nn.Module):
    def __init__(self, z_dim, c_dim, t_dim=64, hidden=256):
        super().__init__()
        self.t_dim = t_dim
        self.fc1 = nn.Linear(z_dim + c_dim + t_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, z_dim)
        self.do  = nn.Dropout(0.05)
        self.nrm = nn.LayerNorm(hidden)
    def forward(self, x, c, t, T=400):
        te = t_embedding(t, self.t_dim, T)
        h = torch.cat([x, c, te], 1)
        h = self.do(F.silu(self.fc1(h)))
        h = self.do(F.silu(self.nrm(self.fc2(h))))
        return self.fc3(h)

# ---------------------------
# Metrics
# ---------------------------
def marginal_KS_percent(real_std, synth_std, feat_names):
    rows = []
    for j in range(real_std.shape[1]):
        r = real_std[:, j]; s = synth_std[:, j]
        ks = ks_2samp(r, s, alternative='two-sided', mode='asymp').statistic
        rows.append({"feature": feat_names[j], "KS%": 100.0*float(ks)})
    df = pd.DataFrame(rows)
    return df, float(df["KS%"].mean())

def pairwise_corr_error_percent(real_std, synth_std, feat_names):
    rows = []
    p = real_std.shape[1]
    for i in range(p):
        for j in range(i+1, p):
            r_corr = np.corrcoef(real_std[:, [i, j]].T)[0,1]
            s_corr = np.corrcoef(synth_std[:, [i, j]].T)[0,1]
            rows.append({"feat_i": feat_names[i], "feat_j": feat_names[j], "|Δρ|%": 100.0*float(abs(s_corr - r_corr))})
    df = pd.DataFrame(rows)
    return df, float(df["|Δρ|%"].mean())

def detection_score_logistic(real_std, synth_std):
    n = min(len(real_std), len(synth_std))
    X = np.vstack([real_std[:n], synth_std[:n]])
    y = np.hstack([np.zeros(n, dtype=int), np.ones(n, dtype=int)])
    Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.3, random_state=SEED, stratify=y)
    clf = LogisticRegression(max_iter=5000, solver="lbfgs")
    clf.fit(Xtr, ytr)
    yhat = clf.predict(Xte)
    acc = accuracy_score(yte, yhat)
    auc = roc_auc_score(yte, clf.predict_proba(Xte)[:,1])
    return 100.0*acc, float(auc)

def detection_score_bestof(real_std, synth_std, seed=123):
    """Stronger attacker: 5-fold CV, best accuracy across RF/SVM/LogReg."""
    n = min(len(real_std), len(synth_std))
    X = np.vstack([real_std[:n], synth_std[:n]])
    y = np.hstack([np.zeros(n, dtype=int), np.ones(n, dtype=int)])
    models = {
        "LogReg(L2)": LogisticRegression(max_iter=5000, solver="lbfgs"),
        "SVM-RBF":    SVC(kernel="rbf", C=10.0, gamma="scale", probability=True, random_state=seed),
        "RF-400":     RandomForestClassifier(n_estimators=400, n_jobs=-1, random_state=seed),
    }
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    best_acc, best_auc, best_name = -1, -1, None
    for name, clf in models.items():
        accs, aucs = [], []
        for tr, te in skf.split(X, y):
            clf.fit(X[tr], y[tr])
            p = clf.predict_proba(X[te])[:, 1]
            accs.append(accuracy_score(y[te], (p >= 0.5).astype(int)))
            aucs.append(roc_auc_score(y[te], p))
        acc, auc = 100*np.mean(accs), float(np.mean(aucs))
        if acc > best_acc:
            best_acc, best_auc, best_name = acc, auc, name
    return best_acc, best_auc, best_name

# ---------------------------
# Plot helpers (heatmaps now with labels)
# ---------------------------
def save_corr_heatmaps_with_labels(real_std, synth_std, feat_names, tag):
    C_real  = np.corrcoef(real_std, rowvar=False)
    C_synth = np.corrcoef(synth_std, rowvar=False)
    C_diff  = C_synth - C_real
    fig, axs = plt.subplots(1,3, figsize=(18,6))
    ims = [
        axs[0].imshow(C_real,  vmin=-1, vmax=1),
        axs[1].imshow(C_synth, vmin=-1, vmax=1),
        axs[2].imshow(C_diff,  vmin=-0.5, vmax=0.5),
    ]
    for ax, title in zip(axs, ["Corr: Real", "Corr: Synthetic", "Corr diff (S-R)"]):
        ax.set_title(title)
        ax.set_xticks(range(len(feat_names))); ax.set_yticks(range(len(feat_names)))
        ax.set_xticklabels(feat_names, rotation=90, fontsize=7)
        ax.set_yticklabels(feat_names, fontsize=7)
    for im, ax in zip(ims, axs):
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.02)
    plt.suptitle(f"Correlation matrices — {tag}")
    plt.tight_layout()
    fname = f"{OUTDIR}/corr_heatmaps_{tag.lower()}_labeled.png"
    plt.savefig(fname, dpi=220); plt.close()
    return fname

def save_table_as_image(df, title, fname):
    fig, ax = plt.subplots(figsize=(8.0, 1.6 + 0.35 * len(df)))
    ax.axis('off')
    ax.set_title(title, fontsize=12, pad=10)
    tbl = ax.table(cellText=df.values, colLabels=df.columns, loc='center', cellLoc='center')
    tbl.auto_set_font_size(False); tbl.set_fontsize(10); tbl.scale(1.0, 1.25)
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    plt.tight_layout(); plt.savefig(fname, dpi=220); plt.close()

def save_umap_paperstyle(real_std, synth_std, tag, xlim=(-5,30), ylim=(-30,5)):
    reducer = UMAP(
        n_neighbors=3, min_dist=0.1, metric="euclidean",
        init="spectral", random_state=SEED, transform_seed=SEED
    )
    R2 = reducer.fit_transform(real_std)    # fit on REAL only
    S2 = reducer.transform(synth_std)       # transform SYNTH

    # Deterministic affine map from REAL bbox to requested axis ranges; apply to both sets
    def _affine_to_bounds(col, new_min, new_max):
        cmin, cmax = float(col.min()), float(col.max())
        scale = (new_max - new_min) / (cmax - cmin + 1e-12)
        shift = new_min - cmin * scale
        return col * scale + shift

    R2s = np.column_stack([
        _affine_to_bounds(R2[:,0], xlim[0], xlim[1]),
        _affine_to_bounds(R2[:,1], ylim[0], ylim[1])
    ])
    S2s = np.column_stack([
        _affine_to_bounds(S2[:,0], xlim[0], xlim[1]),
        _affine_to_bounds(S2[:,1], ylim[0], ylim[1])
    ])

    fig = plt.figure(figsize=(6.2,5.0))
    plt.scatter(R2s[:,0], R2s[:,1], s=10, c="red",  marker=".", alpha=0.9, label="Real")
    plt.scatter(S2s[:,0], S2s[:,1], s=10, c="blue", marker=".", alpha=0.9, label="Synthetic (TabGraphSyn)")
    plt.legend(loc="lower left", frameon=True)
    plt.xlim(*xlim); plt.ylim(*ylim)
    fig.subplots_adjust(bottom=0.18)
    fig.text(0.5, 0.06, "(a) WBCD", ha="center", va="center", fontsize=14)
    plt.tight_layout()
    fname = f"{OUTDIR}/umap_{tag.lower()}_paperstyle.png"
    plt.savefig(fname, dpi=200); plt.close()
    return fname

# ---------------------------
# Data loader (WBCD)
# ---------------------------
def load_wbcd():
    data = load_breast_cancer()
    X = data.data.astype(np.float32)
    y = data.target.astype(np.int64)     # 0 malignant, 1 benign
    feat_names = list(data.feature_names)
    return X, y, feat_names, "WBCD"

# ---------------------------
# Post-hoc marginal calibration (class-conditional) + blend
# ---------------------------
def quantile_match_columns_cc(real_std, y_real, synth_std, y_synth):
    """Class-conditional quantile matching (preserves class separation better than global QM)."""
    S = synth_std.copy()
    classes = np.unique(y_real)
    for c in classes:
        idx_r = (y_real == c)
        idx_s = (y_synth == c)
        if idx_s.sum() == 0: 
            continue
        R = real_std[idx_r]
        S_c = S[idx_s]
        n = S_c.shape[0]
        for j in range(S.shape[1]):
            r_sorted = np.sort(R[:, j])
            ranks = (np.argsort(np.argsort(S_c[:, j])) + 0.5) / n
            idx = np.clip((ranks * (len(r_sorted) - 1)).astype(int), 0, len(r_sorted) - 1)
            S_c[:, j] = r_sorted[idx]
        S[idx_s] = S_c
    return S

def score_fidelity(real_std, synth_std, feat_names, w_ks=1.0, w_corr=1.0):
    _, ks_mean = marginal_KS_percent(real_std, synth_std, feat_names)
    _, corr_mean = pairwise_corr_error_percent(real_std, synth_std, feat_names)
    return w_ks * ks_mean + w_corr * corr_mean, ks_mean, corr_mean

def auto_blend_cc(real_std, y_real, synth_coral, y_synth, feat_names, alphas=(0.0, 0.25, 0.5, 0.75, 1.0)):
    """Blend CORAL+jitter with class-conditional quantile-matched to balance KS & correlations."""
    qm_cc = quantile_match_columns_cc(real_std, y_real, synth_coral, y_synth)
    best = {"alpha": 0.0, "score": 1e9, "ks": None, "corr": None, "X": synth_coral}
    for a in alphas:
        Xb = (1.0 - a) * synth_coral + a * qm_cc
        score, ks_m, corr_m = score_fidelity(real_std, Xb, feat_names, w_ks=1.0, w_corr=1.0)
        if score < best["score"]:
            best = {"alpha": a, "score": score, "ks": ks_m, "corr": corr_m, "X": Xb}
    return best

# ---------------------------
# Cosine beta schedule for DDPM
# ---------------------------
def cosine_beta_schedule(T, s=0.008):
    steps = T + 1
    x = torch.linspace(0, T, steps, device=device)
    alphas_cumprod = torch.cos(((x / T) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 1e-5, 0.999)

# ---------------------------
# Main pipeline (WBCD only)
# ---------------------------
def run_pipeline():
    # 1) Load & standardize (REAL scaler)
    X, y, feat_names, tag = load_wbcd()
    scaler = StandardScaler()
    Xs = scaler.fit_transform(X).astype(np.float32)  # REAL standardized for metrics/plots
    X_train, X_val, y_train, y_val = train_test_split(
        Xs, y, test_size=0.2, stratify=y, random_state=SEED
    )
    X_all_t   = torch.tensor(Xs, device=device)
    y_all_t   = torch.tensor(y,  device=device)
    n_nodes, n_features = Xs.shape

    # 2) kNN graph
    k = 8
    nbrs = NearestNeighbors(n_neighbors=k+1).fit(Xs)
    _, idxs = nbrs.kneighbors(Xs)
    A = np.zeros((n_nodes, n_nodes), dtype=np.float32)
    for i in range(n_nodes):
        for j in idxs[i][1:]:
            A[i, j] = 1.0; A[j, i] = 1.0
    I = np.eye(n_nodes, dtype=np.float32)
    A_hat = A + I
    D = np.sum(A_hat, axis=1)
    D_inv_sqrt = np.diag(1.0 / np.sqrt(D + 1e-8))
    A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt
    A_norm_t = torch.tensor(A_norm, device=device)

    idx_all = np.arange(n_nodes)
    train_idx, val_idx = train_test_split(idx_all, test_size=0.2, stratify=y, random_state=SEED)
    train_mask = np.zeros(n_nodes, dtype=bool); train_mask[train_idx] = True
    val_mask   = np.zeros(n_nodes, dtype=bool); val_mask[val_idx]   = True
    train_mask_t = torch.tensor(train_mask, device=device)
    val_mask_t   = torch.tensor(val_mask,   device=device)

    # 3) GCN (capacity ↑)
    gcn = GCN(n_features).to(device)
    opt_gcn = torch.optim.AdamW(gcn.parameters(), lr=1e-3, weight_decay=1e-3)
    for ep in range(80):
        gcn.train(); opt_gcn.zero_grad()
        logits, _ = gcn(A_norm_t, X_all_t)
        loss = F.cross_entropy(logits[train_mask_t], y_all_t[train_mask_t])
        loss.backward(); torch.nn.utils.clip_grad_norm_(gcn.parameters(), 1.0)
        opt_gcn.step()

    with torch.no_grad():
        gcn.eval(); _, cond_all = gcn(A_norm_t, X_all_t)
    cond_dim = cond_all.shape[1]

    # class prototypes for conditional sampling
    classes = np.unique(y)
    cond_class = []
    for c in classes:
        mask = (y_all_t == int(c))
        cond_class.append(cond_all[mask].mean(dim=0, keepdim=True))
    cond_class = torch.cat(cond_class, dim=0)

    # 4) VAE (capacity ↑)
    VAE_Z, VAE_H, BETA = 32, 256, 0.002
    def vae_loss(recon, x, mu, lv):
        rl = F.mse_loss(recon, x, reduction='mean')
        kld = -0.5 * torch.mean(1 + lv - mu.pow(2) - lv.exp())
        return rl + BETA * kld
    vae = VAE(in_dim=n_features, z_dim=VAE_Z, hidden=VAE_H).to(device)
    opt_vae = torch.optim.AdamW(vae.parameters(), lr=2e-3, weight_decay=5e-4)
    X_train_t = torch.tensor(X_train, device=device)
    bs = 128
    for ep in range(100):
        vae.train()
        idx = torch.randperm(X_train_t.shape[0], device=device)
        for i in range(0, len(idx), bs):
            xb = X_train_t[idx[i:i+bs]]
            opt_vae.zero_grad(); r, mu, lv, _ = vae(xb)
            loss = vae_loss(r, xb, mu, lv); loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
            opt_vae.step()
    for p in vae.parameters(): p.requires_grad = False
    vae.eval()
    with torch.no_grad():
        mu_all, lv_all = vae.encode(torch.tensor(Xs, device=device)); z_all = mu_all
    z_train = z_all[train_mask_t]

    # 5) DDPM (cosine schedule, EMA)
    T = 400
    betas = cosine_beta_schedule(T)
    alphas = 1.0 - betas
    acp = torch.cumprod(alphas, 0)
    sqrt_acp    = torch.sqrt(acp)
    sqrt_1m_acp = torch.sqrt(1.0 - acp)
    def q_sample(x0, t, noise):
        return sqrt_acp[t].unsqueeze(1)*x0 + sqrt_1m_acp[t].unsqueeze(1)*noise

    eps_net = EpsNet(z_dim=z_all.shape[1], c_dim=cond_dim).to(device)
    opt_eps = torch.optim.AdamW(eps_net.parameters(), lr=2e-3, weight_decay=1e-4)
    ema_net = EpsNet(z_dim=z_all.shape[1], c_dim=cond_dim).to(device)
    ema_net.load_state_dict(eps_net.state_dict())
    def update_ema(student, teacher, decay=0.9995):
        for p_s, p_t in zip(student.parameters(), teacher.parameters()):
            p_t.data.mul_(decay).add_(p_s.data, alpha=1.0 - decay)

    cond_noise_std, p_drop_uncond = 0.02, 0.10
    for ep in range(120):
        eps_net.train(); perm = torch.randperm(z_train.shape[0], device=device)
        for i in range(0, len(perm), 128):
            idx = perm[i:i+128]
            x0 = z_train[idx]
            c  = cond_all[train_mask_t][idx] + cond_noise_std * torch.randn_like(cond_all[train_mask_t][idx])
            t  = torch.randint(0, T, (x0.shape[0],), device=device)
            noise = torch.randn_like(x0)
            x_t = q_sample(x0, t, noise)
            drop = (torch.rand(x0.size(0), device=device) < p_drop_uncond).float().unsqueeze(1)
            c_step = c * (1.0 - drop)
            opt_eps.zero_grad(); pred = eps_net(x_t, c_step, t, T)
            loss = F.mse_loss(pred, noise); loss.backward()
            torch.nn.utils.clip_grad_norm_(eps_net.parameters(), 1.0)
            opt_eps.step(); update_ema(eps_net, ema_net)

    @torch.no_grad()
    def p_mean_var_cfg(x_t, t, c, w=2.0):
        beta_t = betas[t].unsqueeze(1)
        sqrt_recip_alpha = (1.0/torch.sqrt(alphas[t])).unsqueeze(1)
        sqrt_1m = sqrt_1m_acp[t].unsqueeze(1)
        net = ema_net
        eps_c = net(x_t, c, t, T)
        eps_u = net(x_t, torch.zeros_like(c), t, T)
        eps_hat = (1.0 + w) * eps_c - w * eps_u
        mean = sqrt_recip_alpha * (x_t - (beta_t / sqrt_1m) * eps_hat)
        var  = betas[t].unsqueeze(1)
        return mean, var

    @torch.no_grad()
    def sample_latents_cond(n_per_class, w=2.0, tau=0.95):
        xs=[]; ys=[]
        for cls, n_ in enumerate(n_per_class):
            if n_ <= 0: continue
            x = torch.randn(n_, z_all.shape[1], device=device)
            c = cond_class[cls].expand(n_, -1)
            for t in reversed(range(T)):
                t_b = torch.full((n_,), t, device=device, dtype=torch.long)
                mean, var = p_mean_var_cfg(x, t_b, c, w=w)
                if t>0: x = mean + torch.sqrt(var) * tau * torch.randn_like(x)
                else:   x = mean
            xs.append(x); ys.append(np.full(n_, cls, dtype=int))
        return (torch.cat(xs, 0) if xs else torch.empty(0, z_all.shape[1], device=device),
                np.concatenate(ys) if ys else np.array([], dtype=int))

    # match label proportions; generate MORE synthetic for stronger TSTR (e.g., 3×)
    synth_mult = 3
    unique, counts = np.unique(y, return_counts=True)
    prop = counts / counts.sum()
    n_samples = len(y) * synth_mult
    n_per_class = (prop * n_samples).astype(int)
    while n_per_class.sum() < n_samples:
        n_per_class[np.argmax(prop)] += 1

    z_synth, y_synth = sample_latents_cond(n_per_class)
    with torch.no_grad():
        x_synth_std = vae.decode(z_synth).cpu().numpy().astype(np.float32)

    # 6) CORAL + jitter (tuned)
    cov_reg, jitter_sd = 2e-3, 0.003
    def sym_eig(mat, eps=1e-8):
        w, v = np.linalg.eigh((mat + mat.T) * 0.5)
        w = np.clip(w, eps, None)
        return w, v
    mu_r = Xs.mean(0); Cr = np.cov(Xs, rowvar=False) + cov_reg*np.eye(n_features)
    mu_s = x_synth_std.mean(0); Cs = np.cov(x_synth_std, rowvar=False) + cov_reg*np.eye(n_features)
    ws, Vs = sym_eig(Cs); wt, Vt = sym_eig(Cr)
    Cs_inv_sqrt = Vs @ np.diag(1.0/np.sqrt(ws)) @ Vs.T
    Cr_sqrt     = Vt @ np.diag(np.sqrt(wt))     @ Vt.T
    Xs0 = x_synth_std - mu_s
    Xs_whiten  = Xs0 @ Cs_inv_sqrt
    Xs_coral   = Xs_whiten @ Cr_sqrt + mu_r
    Xs_coral  += jitter_sd * np.random.randn(*Xs_coral.shape).astype(Xs_coral.dtype)
    synth_std_coral  = Xs_coral.astype(np.float32)

    # 7) Class-conditional marginal calibration + blend (keeps TSTR stronger)
    blend = auto_blend_cc(Xs, y, synth_std_coral, y_synth, feat_names)
    synth_std = blend["X"]; alpha_used = blend["alpha"]

    # Save CSV (inverse-scaled) with labels for TSTR
    synth_df = pd.DataFrame(scaler.inverse_transform(synth_std), columns=feat_names)
    synth_df['target'] = y_synth
    csv_path = f"{OUTDIR}/synthetic_wbcd_with_labels.csv"; synth_df.to_csv(csv_path, index=False)

    # 8) Fidelity metrics (Table II)
    MDE_df, MDE_mean_pct = marginal_KS_percent(Xs, synth_std, feat_names)
    P_corr_df, P_corr_mean_pct = pairwise_corr_error_percent(Xs, synth_std, feat_names)
    table2 = pd.DataFrame([["TabGraphSyn", round(MDE_mean_pct, 2), round(P_corr_mean_pct, 2)]],
                          columns=["Method", "Marginal Distribution Errors (%)", "Pairwise Correlation Errors (%)"])
    save_table_as_image(table2, "TABLE II: Statistical Fidelity — WBCD", f"{OUTDIR}/table2_wbcd.png")

    # 9) UMAP (paper-style)
    umap_path = save_umap_paperstyle(Xs, synth_std, "WBCD", xlim=(-5,30), ylim=(-30,5))

    # 10) Correlation heatmaps (with labels)
    corr_path = save_corr_heatmaps_with_labels(Xs, synth_std, feat_names, "WBCD")

    # 11) TSTR: train on synthetic, test on real (Logistic) — report Accuracy, F1, AUC
    clf_tstr = LogisticRegression(max_iter=5000, solver="lbfgs")
    clf_tstr.fit(synth_std, y_synth)
    proba_val = clf_tstr.predict_proba(X_val)[:, 1]
    y_pred = (proba_val >= 0.5).astype(int)
    acc_tstr = accuracy_score(y_val, y_pred)
    f1_tstr  = f1_score(y_val, y_pred)
    auc_tstr = roc_auc_score(y_val, proba_val)
    table4 = pd.DataFrame([["TabGraphSyn", f"{acc_tstr:.4f}", f"{f1_tstr:.4f}", f"{auc_tstr:.4f}"]],
                          columns=["Method", "Accuracy", "F1 Score", "AUC"])
    save_table_as_image(table4, "TABLE IV: TSTR (WBCD — Train Synthetic, Test Real)", f"{OUTDIR}/table4_wbcd_tstr.png")

    # 12) Detection score (accuracy ↑) — Logistic and Best-of
    det_log_acc, det_log_auc = detection_score_logistic(Xs, synth_std)
    det_best_acc, det_best_auc, det_best_name = detection_score_bestof(Xs, synth_std, seed=SEED)
    table6a = pd.DataFrame([["TabGraphSyn (LogReg)", f"{det_log_acc:.2f}%", f"{det_log_auc:.3f}"]],
                           columns=["Model", "Detection Acc (%)", "AUC"])
    save_table_as_image(table6a, "TABLE VI-A: Detection (Logistic)", f"{OUTDIR}/table6_wbcd_detection_logreg.png")
    table6b = pd.DataFrame([[f"Best-of ({det_best_name})", f"{det_best_acc:.2f}%", f"{det_best_auc:.3f}"]],
                           columns=["Model", "Detection Acc (%)", "AUC"])
    save_table_as_image(table6b, "TABLE VI-B: Detection (Best-of)", f"{OUTDIR}/table6_wbcd_detection_bestof.png")

    # Console summary
    print("\n=== WBCD: FINAL METRICS (after optimizations + CC marginal blend) ===")
    print(pd.Series({
        "Blend alpha (0=CORAL, 1=CC-QM)": alpha_used,
        "Marginal KS mean %": MDE_mean_pct,
        "Pairwise |Δρ| mean %": P_corr_mean_pct,
        "TSTR Accuracy": acc_tstr,
        "TSTR F1": f1_tstr,
        "TSTR AUC": auc_tstr,
        "Detection Acc (%) — Logistic": det_log_acc,
        "Detection AUC — Logistic": det_log_auc,
        "Detection Acc (%) — Best-of": det_best_acc,
        "Detection AUC — Best-of": det_best_auc,
        "Best-of attacker": det_best_name,
    }))
    print("Saved:")
    print(" • CSV:", csv_path)
    print(" • UMAP:", umap_path)
    print(" • Heatmaps:", corr_path)
    print(" • Table II:", f"{OUTDIR}/table2_wbcd.png")
    print(" • TSTR (Table IV):", f"{OUTDIR}/table4_wbcd_tstr.png")
    print(" • Detection Logistic:", f"{OUTDIR}/table6_wbcd_detection_logreg.png")
    print(" • Detection Best-of :", f"{OUTDIR}/table6_wbcd_detection_bestof.png")

# ---------------------------
# Run
# ---------------------------
if __name__ == "__main__":
    run_pipeline()

Using device: cpu


  warn(



=== WBCD: FINAL METRICS (after optimizations + CC marginal blend) ===
Blend alpha (0=CORAL, 1=CC-QM)          1.0
Marginal KS mean %                 0.322203
Pairwise |Δρ| mean %               9.950451
TSTR Accuracy                      0.964912
TSTR F1                            0.972603
TSTR AUC                           0.996032
Detection Acc (%) — Logistic      73.684211
Detection AUC — Logistic           0.770699
Detection Acc (%) — Best-of       82.427158
Detection AUC — Best-of            0.905876
Best-of attacker                    SVM-RBF
dtype: object
Saved:
 • CSV: /data/user/home/rkhan5/Personal/outputs/synthetic_wbcd_with_labels.csv
 • UMAP: /data/user/home/rkhan5/Personal/outputs/umap_wbcd_paperstyle.png
 • Heatmaps: /data/user/home/rkhan5/Personal/outputs/corr_heatmaps_wbcd_labeled.png
 • Table II: /data/user/home/rkhan5/Personal/outputs/table2_wbcd.png
 • TSTR (Table IV): /data/user/home/rkhan5/Personal/outputs/table4_wbcd_tstr.png
 • Detection Logistic: /data/user/hom