In [None]:
# ============================================================
# Reviewer-oriented package (FULL VERSION + PLOT DATA EXPORT):
#  1) KM by Stage Bin (ALL internal + optional external)
#  2) KM by Imaging Risk (ALL internal + optional external, external uses INTERNAL cutoff)
#  3) KM 4-group: Stage Bin × Imaging Risk (ALL internal + optional external)
#  4) Cox UV/MV (ALL internal): Age, pathology, stage_bin, imaging_risk
#
# PLUS:
#  - Save plotting-ready CSVs so you can redraw/modify figures easily.
#
# Representative model: file04/run06
# stage0 coding: 1,2,3,4  -> stage_bin: (1-2) vs (3-4)
#
# UPDATE (2026-02-17):
#  - Cox imaging_risk scaling changed to HR per 0.1 increase (default),
#    instead of "per 1-SD".
#
# UPDATE (2026-02-18):
#  - Export FULL roster (all columns) with imaging_risk, risk_group, cutoff
# ============================================================

import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import joblib
import matplotlib.pyplot as plt

from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test
from lifelines import CoxPHFitter

# -------------------------
# CONFIG
# -------------------------
# Experiment identifiers; they are interpolated into the model / data / output
# paths below.
BASE_GROUP = "beit0"
GROUP = "n7_30_30"

# Representative split (file index) and training run analysed in this cell.
REP_FILE_IDX = 4
REP_RUN_IDX  = 6

# imaging risk is computed from this representative feature set's model
# (the label also selects which checkpoint / transformer files are loaded).
REP_FEATURE_FOR_IMAGING = "Image only"

T0 = 60  # months horizon for risk = 1 - S(T0)

# ✅ NEW: Cox imaging_risk unit (HR per this increase in original risk scale)
RISK_UNIT_FOR_COX = 0.1

DEVICE = torch.device("cpu")
# NOTE(review): set_seed() is defined below but never called in this cell, so
# BASE_SEED is currently unused here.
BASE_SEED = 20250903

# Directory holding best_model_<tag>.pt and ct_<tag>.joblib for this file index.
MODEL_DIR_ROOT = f"./survival_model/mixture_non_fix/models/{BASE_GROUP}/{GROUP}"
MODEL_DIR_I    = os.path.join(MODEL_DIR_ROOT, f"file{REP_FILE_IDX:02d}")

# Internal cohort CSV (file index zero-padded) and optional external cohort CSV
# (file index NOT zero-padded); external analyses run only if that file exists.
DATA_CSV_PATH     = f"./deephit/{BASE_GROUP}/test/dl0/{GROUP}/dh11_run{REP_FILE_IDX:02d}.csv"
EXTERNAL_CSV_PATH = f"./external/external{REP_FILE_IDX}.csv"

# Every figure / CSV of this cell is written under SAVE_ROOT_I, which is
# created as soon as this cell runs.
SAVE_ROOT_BASE = f"./survival_model/mixture_non_fix/non_nest/{BASE_GROUP}/results/reviewer_km_pack/dl0/{GROUP}"
SAVE_ROOT_I    = os.path.join(SAVE_ROOT_BASE, f"file{REP_FILE_IDX:02d}_run{REP_RUN_IDX:02d}")
os.makedirs(SAVE_ROOT_I, exist_ok=True)

# feature columns (as you used)
# They are fed to the saved transformer in the order IMG_COLS + CONT_COLS + CAT_COLS.
IMG_COLS  = ["feat_436", "feat_519"]
CONT_COLS = ["Age"]
CAT_COLS  = ["pathology", "stage0"]

# -------------------------
# Utils
# -------------------------
def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and torch via ``manual_seed``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

def ensure_time_event(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` with a numeric ``time`` and an integer ``event`` column.

    - ``event`` is derived from ``survival`` when ``event`` is absent.
    - ``time`` is derived from ``fu_date`` (as float32) when ``time`` is absent.
    - Non-numeric entries are coerced to NaN; a missing event becomes 0
      (treated as censored).

    Raises:
        KeyError: if neither ``time``/``fu_date`` or neither ``event``/``survival``
            is present.
    """
    df = df.copy()
    if "event" not in df.columns and "survival" in df.columns:
        # Coerce instead of ``astype(int)``: a NaN / non-numeric entry would
        # otherwise raise before the fillna(0) below gets a chance to run.
        df["event"] = pd.to_numeric(df["survival"], errors="coerce")
    if "time" not in df.columns and "fu_date" in df.columns:
        df["time"] = pd.to_numeric(df["fu_date"], errors="coerce").astype(np.float32)

    missing = [c for c in ("time", "event") if c not in df.columns]
    if missing:
        raise KeyError(
            f"Missing survival column(s) {missing}: expected time/fu_date and event/survival"
        )

    df["time"]  = pd.to_numeric(df["time"], errors="coerce")
    df["event"] = pd.to_numeric(df["event"], errors="coerce").fillna(0).astype(int)
    return df

def add_stage_bin_12_vs_34(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` with ``stage_bin`` and ``stage_bin_label`` added.

    stage0 coding: 1,2,3,4
    stage_bin:
      0 => stage0 in {1,2}  (IB–IIIC1)
      1 => stage0 in {3,4}  (IIIC2–IVB)
    Any other or non-numeric stage0 value maps to NaN (``stage0`` itself is
    coerced to numeric in the returned copy).

    If ``stage0`` is absent, both new columns are all-NaN instead of raising,
    so an optional cohort without stage can still flow through the pipeline
    (stage-based KM then reports "<2 groups" and is skipped).
    """
    df = df.copy()
    if "stage0" not in df.columns:
        df["stage_bin"] = np.nan
        df["stage_bin_label"] = np.nan
        return df

    df["stage0"] = pd.to_numeric(df["stage0"], errors="coerce")
    df["stage_bin"] = np.where(df["stage0"].isin([1, 2]), 0,
                        np.where(df["stage0"].isin([3, 4]), 1, np.nan))
    df["stage_bin_label"] = df["stage_bin"].map({
        0: "IB–IIIC1 (stage0 1–2)",
        1: "IIIC2–IVB (stage0 3–4)"
    })
    return df

# -------------------------
# PLOT DATA EXPORT helpers
# -------------------------
def save_km_long_data(df, time_col, event_col, group_col, out_csv):
    """
    Write the (time, event, group) rows used for a KM plot to ``out_csv``.

    Columns are renamed to ``time``, ``event`` and ``group``; rows holding
    NaN or +/-inf in any of them are dropped first.
    """
    renamer = {time_col: "time", event_col: "event", group_col: "group"}
    long_df = (
        df.loc[:, [time_col, event_col, group_col]]
          .replace([np.inf, -np.inf], np.nan)
          .dropna()
          .rename(columns=renamer)
    )
    long_df.to_csv(out_csv, index=False)
    print(f"✅ KM raw(long) saved: {out_csv}")

def save_risk_long_data(df_time_event_risk, out_csv):
    """
    Write the ``time``, ``event`` and ``risk`` columns to ``out_csv``.

    Only rows with NaN/inf in those three columns are dropped; other columns
    are ignored and not written.
    """
    wanted = ["time", "event", "risk"]
    cleaned = df_time_event_risk.replace([np.inf, -np.inf], np.nan)
    cleaned = cleaned.dropna(subset=wanted)
    cleaned.loc[:, wanted].to_csv(out_csv, index=False)
    print(f"✅ Risk raw saved: {out_csv}")

def save_design_matrix_for_cox(df_design, out_csv):
    """
    Write the Cox design matrix (all columns, incl. one-hot) to ``out_csv``,
    dropping any row that contains NaN or +/-inf.
    """
    finite_only = df_design.replace([np.inf, -np.inf], np.nan)
    finite_only.dropna().to_csv(out_csv, index=False)
    print(f"✅ Cox design matrix saved: {out_csv}")

# -------------------------
# Model + load/predict
# -------------------------
class MixtureStretchedExponentialSurvival(nn.Module):
    """
    MLP producing per-sample parameters of a K-component survival mixture.

    forward(x) returns ``(pi, lam, a)``, each of shape (N, num_components):
      - pi:  mixture weights (softmax over components, rows sum to 1)
      - lam: positive rates  (softplus + 1e-3)
      - a:   positive shapes (softplus + 1e-3)
    """

    def __init__(self, input_dim, num_components=2):
        super().__init__()
        width = 64
        self.backbone = nn.Sequential(
            nn.Linear(input_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
        )
        # Attribute names are the state_dict keys of saved checkpoints; keep them.
        self.pi_layer = nn.Linear(width, num_components)
        self.lam_layer = nn.Linear(width, num_components)
        self.alpha_layer = nn.Linear(width, num_components)

    def forward(self, x):
        floor = 1e-3  # keeps rates / shapes strictly positive
        hidden = self.backbone(x)
        weights = F.softmax(self.pi_layer(hidden), dim=1)
        rates = F.softplus(self.lam_layer(hidden)) + floor
        shapes = F.softplus(self.alpha_layer(hidden)) + floor
        return weights, rates, shapes

def load_model_and_ct(model_dir, run_idx, label, device=DEVICE, num_components=2):
    """
    Load the checkpoint ``best_model_run<NN>_<label>.pt`` and the joblib
    transformer ``ct_run<NN>_<label>.joblib`` from ``model_dir``.

    Spaces in ``label`` become underscores in the file tag. The checkpoint
    must contain ``input_dim`` and ``state_dict``.

    Returns:
        (model in eval mode on ``device``, loaded transformer)
    Raises:
        FileNotFoundError: if either file is missing (model checked first).
    """
    tag = f"run{run_idx:02d}_{label.replace(' ', '_')}"
    required = {
        "model": os.path.join(model_dir, f"best_model_{tag}.pt"),
        "ct": os.path.join(model_dir, f"ct_{tag}.joblib"),
    }
    for kind, path in required.items():
        if not os.path.exists(path):
            raise FileNotFoundError(f"Missing {kind}: {path}")

    checkpoint = torch.load(required["model"], map_location=device)
    net = MixtureStretchedExponentialSurvival(
        input_dim=int(checkpoint["input_dim"]),
        num_components=num_components,
    ).to(device)
    net.load_state_dict(checkpoint["state_dict"])
    net.eval()

    return net, joblib.load(required["ct"])

@torch.no_grad()
def predict_survival(model, x, times):
    """
    Evaluate the mixture survival function at each horizon in ``times``.

    S(t | x) = sum_k pi_k * exp(-lam_k * (t + 1e-8) ** a_k), with
    ``(pi, lam, a) = model(x)``.

    Returns:
        np.ndarray of shape (len(times), N), N = number of rows of ``x``.
    """
    model.eval()
    weights, rates, shapes = model(x)

    def _survival_at(horizon):
        t = torch.tensor([horizon], dtype=torch.float32, device=x.device)
        per_component = torch.exp(-rates * torch.pow(t + 1e-8, shapes))
        return (weights * per_component).sum(dim=1).detach().cpu().numpy()

    return np.vstack([_survival_at(h) for h in times])

# -------------------------
# KM plotting helpers
# -------------------------
def km_plot_two_groups(df, time_col, event_col, group_col, group_order, title, out_png, out_summary_csv):
    """
    Plot Kaplan–Meier curves for ``group_order`` and report the log-rank
    p-value between its first two groups.

    Rows with NaN/inf in the three columns, or whose group is not listed in
    ``group_order``, are ignored. With fewer than two groups left, a warning is
    printed and nothing is written. Otherwise a per-group n/events table (in
    ``group_order`` order) goes to ``out_summary_csv`` and the figure to
    ``out_png``.
    """
    data = (df[[time_col, event_col, group_col]]
            .replace([np.inf, -np.inf], np.nan)
            .dropna())
    data = data[data[group_col].isin(group_order)].copy()
    if data[group_col].nunique() < 2:
        print(f"⚠️ Need >=2 groups for KM: {out_png}")
        return

    per_group = data.groupby(group_col).agg(n=(group_col, "size"), events=(event_col, "sum"))
    per_group.reindex(group_order).reset_index().to_csv(out_summary_csv, index=False)

    subsets = [(grp, data[data[group_col] == grp]) for grp in group_order]

    fitter = KaplanMeierFitter()
    plt.figure(figsize=(7, 5))
    for grp, sub in subsets:
        fitter.fit(sub[time_col], event_observed=sub[event_col], label=f"{grp} (n={len(sub)})")
        fitter.plot(ci_show=True)

    ref, other = subsets[0][1], subsets[1][1]
    p_value = logrank_test(
        ref[time_col], other[time_col],
        event_observed_A=ref[event_col], event_observed_B=other[event_col],
    ).p_value

    plt.title(f"{title}\nlog-rank p={p_value:.3g}")
    plt.xlabel("Time (Months)")
    plt.ylabel("Survival probability")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(out_png, dpi=200)
    plt.close()
    print(f"✅ KM saved: {out_png}")

def km_plot_by_risk(df_time_event_risk, title, out_png, out_summary_csv, fixed_cut=None):
    """
    Split samples into High/Low risk at a cutoff and plot their KM curves.

    Args:
        df_time_event_risk: frame with ``time``, ``event`` and ``risk`` columns.
        title: plot title; the cutoff and log-rank p are appended.
        out_png: path of the figure.
        out_summary_csv: path of the per-group n / events / median-risk table.
        fixed_cut: cutoff to apply; when None the median risk of the usable
            rows is used. Samples with risk >= cutoff are "High".

    Returns:
        The cutoff used (float), or None if fewer than 10 usable rows exist.

    If every sample lands on one side of the cutoff (tied risks at the median,
    or an external cohort entirely above/below an internal cutoff) the summary
    is still written and the cutoff returned, but the plot is skipped: a KM
    curve and a log-rank test are undefined for an empty group.
    """
    df = df_time_event_risk.copy()
    df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=["time","event","risk"])
    if len(df) < 10:
        print(f"⚠️ Too few samples for KM risk: {out_png}")
        return None

    cut = float(fixed_cut) if fixed_cut is not None else float(np.nanmedian(df["risk"].values))
    df["risk_group"] = np.where(df["risk"] >= cut, "High", "Low")

    summ = (df.groupby("risk_group")
              .agg(n=("risk_group","size"), events=("event","sum"), median_risk=("risk","median"))
              .reset_index())
    summ["cut_used"] = cut
    summ.to_csv(out_summary_csv, index=False)

    low  = df[df["risk_group"] == "Low"]
    high = df[df["risk_group"] == "High"]
    empty = [name for name, sub in (("Low", low), ("High", high)) if sub.empty]
    if empty:
        print(f"⚠️ Empty risk group(s) {empty} at cutoff={cut:.4f}; KM plot skipped: {out_png}")
        return cut

    kmf = KaplanMeierFitter()
    plt.figure(figsize=(7,5))
    for g, sub in (("Low", low), ("High", high)):
        kmf.fit(sub["time"], event_observed=sub["event"], label=f"{g} (n={len(sub)})")
        kmf.plot(ci_show=True)

    lr = logrank_test(low["time"], high["time"], event_observed_A=low["event"], event_observed_B=high["event"])
    pval = lr.p_value

    plt.title(f"{title}\ncutoff={cut:.4f} | log-rank p={pval:.3g}")
    plt.xlabel("Time (Months)")
    plt.ylabel("Survival probability")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(out_png, dpi=200)
    plt.close()
    print(f"✅ KM saved: {out_png}")
    return cut

def km_plot_four_groups(df, time_col, event_col, group_col, title, out_png, out_summary_csv):
    """
    Plot one Kaplan–Meier curve per distinct value of ``group_col`` (no test).

    Rows with NaN/inf in the three columns are ignored. With fewer than two
    groups, a warning is printed and nothing is written. Otherwise a per-group
    n/events table (sorted by group) goes to ``out_summary_csv`` and the figure
    to ``out_png``; curves are drawn in sorted group order.
    """
    data = (df[[time_col, event_col, group_col]]
            .replace([np.inf, -np.inf], np.nan)
            .dropna())
    if data[group_col].nunique() < 2:
        print(f"⚠️ Need >=2 groups for KM: {out_png}")
        return

    counts = data.groupby(group_col).agg(n=(group_col, "size"), events=(event_col, "sum"))
    counts.reset_index().sort_values(group_col).to_csv(out_summary_csv, index=False)

    fitter = KaplanMeierFitter()
    plt.figure(figsize=(7, 5))
    for grp in sorted(data[group_col].unique()):
        members = data.loc[data[group_col] == grp]
        fitter.fit(members[time_col], event_observed=members[event_col],
                   label=f"{grp} (n={len(members)})")
        fitter.plot(ci_show=True)

    plt.title(title)
    plt.xlabel("Time (Months)")
    plt.ylabel("Survival probability")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(out_png, dpi=200)
    plt.close()
    print(f"✅ KM saved: {out_png}")

# -------------------------
# Cox helpers (ALL internal)
#   ✅ imaging_risk scaled to HR per 0.1 increase
# -------------------------
def build_design_for_cox(df, use_stage="stage_bin", risk_unit=RISK_UNIT_FOR_COX):
    """
    Build a numeric Cox design matrix from ``df``.

    Columns: time, event, Age, the stage column, imaging_risk (rescaled) and
    one-hot pathology dummies (first level dropped, 0/1 ints).

    Args:
        df: frame with time, event, Age, pathology, imaging_risk and the chosen
            stage column.
        use_stage: "stage0" keeps the ordinal stage0 column; anything else uses
            stage_bin.
        risk_unit: imaging_risk is divided by this, so the fitted HR is per
            +risk_unit of the original risk (e.g. 0.1 -> HR per +0.1).
            None, non-positive or non-finite values fall back to 0.1.

    Returns:
        The design DataFrame with NaN/inf rows dropped. The rescaled risk keeps
        the column name ``imaging_risk``.
    """
    d = df.copy()
    d["Age"] = pd.to_numeric(d["Age"], errors="coerce")
    d["imaging_risk"] = pd.to_numeric(d["imaging_risk"], errors="coerce")

    if use_stage == "stage0":
        d["stage0"] = pd.to_numeric(d["stage0"], errors="coerce")
        stage_col = "stage0"
    else:
        d["stage_bin"] = pd.to_numeric(d["stage_bin"], errors="coerce")
        stage_col = "stage_bin"

    ru = float(risk_unit) if risk_unit is not None else 0.1
    # A NaN unit would turn every risk into NaN and dropna() would empty the design.
    if not np.isfinite(ru) or ru <= 0:
        ru = 0.1
    d["imaging_risk_perunit"] = d["imaging_risk"] / ru

    pat = d["pathology"].astype("category")
    # Explicit int dtype: pandas >= 2.0 returns bool dummies by default, which
    # would be exported to the design CSV as True/False.
    pat_oh = pd.get_dummies(pat, prefix="pathology", drop_first=True, dtype=int)

    out = pd.concat(
        [d[["time","event","Age",stage_col,"imaging_risk_perunit"]].rename(columns={"imaging_risk_perunit":"imaging_risk"}),
         pat_oh],
        axis=1
    )
    out = out.replace([np.inf, -np.inf], np.nan).dropna()
    return out

def run_cox_uv_mv(df_design, out_prefix):
    """
    Fit univariable (one covariate at a time) and multivariable Cox models.

    Every column of ``df_design`` other than ``time``/``event`` is a covariate.

    Writes:
        ``{out_prefix}_UV.csv``: one row per covariate with variable, coef, HR,
            p, HR 95% CI and n, or ``variable`` + ``error`` if that fit failed.
        ``{out_prefix}_MV.csv``: the lifelines summary with a ``variable``
            column (same name as in the UV table), or a single ``error`` row.
    """
    time_col, event_col = "time", "event"
    covars = [c for c in df_design.columns if c not in (time_col, event_col)]

    # UV
    uv_rows = []
    for c in covars:
        sub = df_design[[time_col, event_col, c]]
        cph = CoxPHFitter()
        try:
            cph.fit(sub, duration_col=time_col, event_col=event_col)
            s = cph.summary.loc[c]
            uv_rows.append({
                "variable": c,
                "coef": float(s["coef"]),
                "HR": float(s["exp(coef)"]),
                "p": float(s["p"]),
                "HR_95low": float(s["exp(coef) lower 95%"]),
                "HR_95high": float(s["exp(coef) upper 95%"]),
                # Rows fitted; uses the input length instead of the private
                # lifelines attribute ``_n_examples``.
                "n": int(len(sub))
            })
        except Exception as e:
            # Best-effort: record the failure (e.g. non-convergence) and go on.
            uv_rows.append({"variable": c, "error": str(e)})

    uv_path = f"{out_prefix}_UV.csv"
    pd.DataFrame(uv_rows).to_csv(uv_path, index=False)

    # MV
    mv_path = f"{out_prefix}_MV.csv"
    cph = CoxPHFitter()
    try:
        cph.fit(df_design, duration_col=time_col, event_col=event_col)
        # The summary index is named (lifelines uses "covariate"), so a plain
        # reset_index() would not yield an "index" column to rename; name the
        # axis explicitly to get the same "variable" column as the UV table.
        mv = cph.summary.rename_axis("variable").reset_index()
        mv.to_csv(mv_path, index=False)
    except Exception as e:
        pd.DataFrame([{"error": str(e)}]).to_csv(mv_path, index=False)

    print(f"✅ Cox saved: {uv_path} , {mv_path}")

# ============================================================
# MAIN
# ============================================================
_banner_rule = "============================================"
print("\n".join([
    _banner_rule,
    f"[Reviewer KM Pack] file{REP_FILE_IDX:02d} run{REP_RUN_IDX:02d} | T0={T0}",
    f"MODEL_DIR_I = {MODEL_DIR_I}",
    f"DATA_CSV    = {DATA_CSV_PATH}",
    f"EXTERNAL    = {EXTERNAL_CSV_PATH}",
    f"OUT_DIR     = {SAVE_ROOT_I}",
    _banner_rule,
]))

if not os.path.exists(DATA_CSV_PATH):
    raise FileNotFoundError(f"Missing internal data CSV: {DATA_CSV_PATH}")


def _load_cohort(csv_path):
    """Read a cohort CSV, normalise time/event and add the stage-bin columns."""
    return add_stage_bin_12_vs_34(ensure_time_event(pd.read_csv(csv_path)))


# Internal cohort is mandatory; the external cohort is used only if present.
df_all = _load_cohort(DATA_CSV_PATH)
df_ext = _load_cohort(EXTERNAL_CSV_PATH) if os.path.exists(EXTERNAL_CSV_PATH) else None

# -------------------------
# 1) KM by Stage Bin (ALL internal)
# -------------------------
stage_groups = ["IB–IIIC1 (stage0 1–2)", "IIIC2–IVB (stage0 3–4)"]
df_stage_int = df_all[["time","event","stage_bin_label"]].dropna(subset=["stage_bin_label"])

save_km_long_data(
    df_stage_int, "time", "event", "stage_bin_label",
    os.path.join(SAVE_ROOT_I, "DATA_INTERNAL_ALL_stagebin_long.csv")
)

out_png = os.path.join(SAVE_ROOT_I, "KM_INTERNAL_ALL_stagebin.png")
out_sum = os.path.join(SAVE_ROOT_I, "KM_INTERNAL_ALL_stagebin_summary.csv")
km_plot_two_groups(
    df_stage_int, "time","event","stage_bin_label", stage_groups,
    title=f"KM Internal (ALL) by Stage Bin (1–2 vs 3–4) | file{REP_FILE_IDX:02d}",
    out_png=out_png, out_summary_csv=out_sum
)

if df_ext is not None:
    df_stage_ext = df_ext[["time","event","stage_bin_label"]].dropna(subset=["stage_bin_label"])

    save_km_long_data(
        df_stage_ext, "time", "event", "stage_bin_label",
        os.path.join(SAVE_ROOT_I, "DATA_EXTERNAL_stagebin_long.csv")
    )

    # Label the plot with the external file's real stem (e.g. "external4"):
    # EXTERNAL_CSV_PATH does not zero-pad the file index.
    ext_tag = os.path.splitext(os.path.basename(EXTERNAL_CSV_PATH))[0]

    out_png = os.path.join(SAVE_ROOT_I, "KM_EXTERNAL_stagebin.png")
    out_sum = os.path.join(SAVE_ROOT_I, "KM_EXTERNAL_stagebin_summary.csv")
    km_plot_two_groups(
        df_stage_ext, "time","event","stage_bin_label", stage_groups,
        title=f"KM External by Stage Bin (1–2 vs 3–4) | {ext_tag}",
        out_png=out_png, out_summary_csv=out_sum
    )

# -------------------------
# 2) KM by Imaging Risk (ALL internal) using representative model
# -------------------------
model, ct = load_model_and_ct(MODEL_DIR_I, REP_RUN_IDX, REP_FEATURE_FOR_IMAGING, device=DEVICE, num_components=2)

used_cols = IMG_COLS + CONT_COLS + CAT_COLS
missing_int = [c for c in used_cols if c not in df_all.columns]
if missing_int:
    raise ValueError(f"Internal missing required columns for imaging risk: {missing_int}")

# ---- INTERNAL risk
# Apply the saved transformer `ct` (presumably an already-fitted
# ColumnTransformer), then zero-fill any NaN left after the transform.
X_int_df = df_all[used_cols].copy()
X_int = ct.transform(X_int_df)
X_int = pd.DataFrame(X_int).fillna(0).values.astype(np.float32)
X_int_t = torch.tensor(X_int, dtype=torch.float32).to(DEVICE)

# imaging risk = 1 - S(T0), i.e. the predicted probability of an event by T0.
S_int_T0 = predict_survival(model, X_int_t, [T0])[0, :]
risk_int = 1.0 - S_int_T0

df_all = df_all.copy()
df_all["imaging_risk"] = risk_int

# Risk-only plot data
df_risk_int = df_all[["time","event","imaging_risk"]].rename(columns={"imaging_risk":"risk"}).dropna()
save_risk_long_data(
    df_risk_int,
    os.path.join(SAVE_ROOT_I, f"DATA_INTERNAL_ALL_imagingRisk_T{T0}.csv")
)

out_png = os.path.join(SAVE_ROOT_I, f"KM_INTERNAL_ALL_imagingRisk_T{T0}.png")
out_sum = os.path.join(SAVE_ROOT_I, f"KM_INTERNAL_ALL_imagingRisk_T{T0}_summary.csv")
internal_cut = km_plot_by_risk(
    df_risk_int,
    title=f"KM Internal (ALL) by Imaging Risk | {REP_FEATURE_FOR_IMAGING}\nrisk=1-S({T0}) | file{REP_FILE_IDX:02d} run{REP_RUN_IDX:02d}",
    out_png=out_png, out_summary_csv=out_sum, fixed_cut=None
)

# km_plot_by_risk returns None when fewer than 10 usable rows exist. Every
# High/Low split below (roster, 4-group KM, external cohort) needs a numeric
# cutoff, so fail here with a clear message rather than at `>= None` later.
if internal_cut is None:
    raise RuntimeError(
        f"Internal imaging-risk cutoff could not be derived "
        f"({len(df_risk_int)} usable rows; need >= 10)."
    )

# ✅ NEW: add High/Low group + cutoff into df_all, and export FULL roster
df_all["imaging_risk_cutoff"] = internal_cut
df_all["imaging_risk_group"] = np.where(df_all["imaging_risk"] >= internal_cut, "HighRisk", "LowRisk")

# ✅ FULL internal roster export (ALL columns)
out_full_internal = os.path.join(SAVE_ROOT_I, f"INTERNAL_ALL_FULL_with_imagingRisk_T{T0}.csv")
df_all.to_csv(out_full_internal, index=False)
print(f"✅ FULL internal roster saved: {out_full_internal}")

# Save risk_group table for easy plotting (High/Low with cutoff)
df_risk_int2 = df_risk_int.copy()
df_risk_int2["cutoff"] = internal_cut
df_risk_int2["risk_group"] = np.where(df_risk_int2["risk"] >= internal_cut, "High", "Low")
df_risk_int2.to_csv(
    os.path.join(SAVE_ROOT_I, f"DATA_INTERNAL_ALL_imagingRisk_T{T0}_withGroup.csv"),
    index=False
)

# ---- EXTERNAL risk (optional) + FULL roster export
# The external cohort is scored with the same model and saved transformer
# (transform only, no refit) and split at the INTERNAL cutoff. If required
# columns are missing, df_ext gets no "imaging_risk" column, and section 3
# checks for that column before building the external 4-group KM.
if df_ext is not None:
    missing_ext = [c for c in used_cols if c not in df_ext.columns]
    if missing_ext:
        print("⚠️ External missing columns for imaging risk KM:", missing_ext)
    else:
        X_ext_df = df_ext[used_cols].copy()
        X_ext = ct.transform(X_ext_df)
        # Zero-fill NaNs remaining after the transform, as for the internal cohort.
        X_ext = pd.DataFrame(X_ext).fillna(0).values.astype(np.float32)
        X_ext_t = torch.tensor(X_ext, dtype=torch.float32).to(DEVICE)

        # imaging risk = 1 - S(T0)
        S_ext_T0 = predict_survival(model, X_ext_t, [T0])[0, :]
        risk_ext = 1.0 - S_ext_T0

        df_ext = df_ext.copy()
        df_ext["imaging_risk"] = risk_ext

        df_risk_ext = df_ext[["time","event","imaging_risk"]].rename(columns={"imaging_risk":"risk"}).dropna()
        save_risk_long_data(
            df_risk_ext,
            os.path.join(SAVE_ROOT_I, f"DATA_EXTERNAL_imagingRisk_T{T0}.csv")
        )

        out_png = os.path.join(SAVE_ROOT_I, f"KM_EXTERNAL_imagingRisk_T{T0}.png")
        out_sum = os.path.join(SAVE_ROOT_I, f"KM_EXTERNAL_imagingRisk_T{T0}_summary.csv")
        # fixed_cut=internal_cut: no cutoff is re-estimated on the external data.
        km_plot_by_risk(
            df_risk_ext,
            title=f"KM External by Imaging Risk (INTERNAL cutoff) | risk=1-S({T0})",
            out_png=out_png, out_summary_csv=out_sum, fixed_cut=internal_cut
        )

        # ✅ NEW: add group + cutoff + export FULL external roster
        df_ext["imaging_risk_cutoff"] = internal_cut
        df_ext["imaging_risk_group"] = np.where(df_ext["imaging_risk"] >= internal_cut, "HighRisk", "LowRisk")

        out_full_external = os.path.join(SAVE_ROOT_I, f"EXTERNAL_FULL_with_imagingRisk_T{T0}.csv")
        df_ext.to_csv(out_full_external, index=False)
        print(f"✅ FULL external roster saved: {out_full_external}")

        # Plot-ready (time, event, risk, cutoff, High/Low) table for the external cohort.
        df_risk_ext2 = df_risk_ext.copy()
        df_risk_ext2["cutoff"] = internal_cut
        df_risk_ext2["risk_group"] = np.where(df_risk_ext2["risk"] >= internal_cut, "High", "Low")
        df_risk_ext2.to_csv(
            os.path.join(SAVE_ROOT_I, f"DATA_EXTERNAL_imagingRisk_T{T0}_withGroup.csv"),
            index=False
        )

# -------------------------
# 3) 4-group KM: Stage Bin × Imaging Risk (ALL internal)
# -------------------------
def _stage_x_risk_frame(cohort):
    """
    Rows of ``cohort`` with known stage_bin and imaging_risk, labelled
    ``"<stage_bin_label> | HighRisk/LowRisk"`` in ``group4`` using the
    module-level ``internal_cut`` (risk >= cutoff -> HighRisk).
    """
    cols = ["time", "event", "stage_bin_label", "stage_bin", "imaging_risk"]
    frame = cohort[cols].dropna(subset=["stage_bin", "imaging_risk"]).copy()
    frame["risk_group"] = np.where(frame["imaging_risk"] >= internal_cut, "HighRisk", "LowRisk")
    frame["group4"] = frame["stage_bin_label"] + " | " + frame["risk_group"]
    return frame


df4 = _stage_x_risk_frame(df_all)
save_km_long_data(
    df4, "time", "event", "group4",
    os.path.join(SAVE_ROOT_I, f"DATA_INTERNAL_ALL_stagebin_x_imagingRisk_T{T0}_long.csv")
)
out_png = os.path.join(SAVE_ROOT_I, f"KM_INTERNAL_ALL_stagebin_x_imagingRisk_T{T0}.png")
out_sum = os.path.join(SAVE_ROOT_I, f"KM_INTERNAL_ALL_stagebin_x_imagingRisk_T{T0}_summary.csv")
km_plot_four_groups(
    df4, "time", "event", "group4",
    title=f"KM Internal (ALL) | Stage Bin × Imaging Risk\ncutoff(from internal median risk) | T0={T0}",
    out_png=out_png, out_summary_csv=out_sum
)

# External cohort only when imaging risk could be computed for it (section 2).
if df_ext is not None and "imaging_risk" in df_ext.columns:
    df4e = _stage_x_risk_frame(df_ext)
    save_km_long_data(
        df4e, "time", "event", "group4",
        os.path.join(SAVE_ROOT_I, f"DATA_EXTERNAL_stagebin_x_imagingRisk_T{T0}_long.csv")
    )
    out_png = os.path.join(SAVE_ROOT_I, f"KM_EXTERNAL_stagebin_x_imagingRisk_T{T0}.png")
    out_sum = os.path.join(SAVE_ROOT_I, f"KM_EXTERNAL_stagebin_x_imagingRisk_T{T0}_summary.csv")
    km_plot_four_groups(
        df4e, "time", "event", "group4",
        title=f"KM External | Stage Bin × Imaging Risk (internal cutoff) | T0={T0}",
        out_png=out_png, out_summary_csv=out_sum
    )

# -------------------------
# 4) Cox UV/MV (ALL internal): Age, pathology, stage_bin, imaging_risk
# -------------------------
# File-name tag for the imaging_risk scaling, derived from RISK_UNIT_FOR_COX so
# outputs stay correctly labelled if the unit changes (0.1 -> "per0p1").
risk_tag = "per" + f"{RISK_UNIT_FOR_COX:g}".replace(".", "p")

df_cox = df_all[["time","event","Age","pathology","stage0","stage_bin","imaging_risk"]].copy()
df_cox = df_cox.dropna(subset=["time","event","Age","pathology","stage_bin","imaging_risk"])

# (a) stage_bin model
design_bin = build_design_for_cox(df_cox, use_stage="stage_bin", risk_unit=RISK_UNIT_FOR_COX)
save_design_matrix_for_cox(
    design_bin,
    os.path.join(SAVE_ROOT_I, f"DATA_CoxDesign_INTERNAL_T{T0}_stageBIN_imgrisk_{risk_tag}.csv")
)
out_prefix = os.path.join(SAVE_ROOT_I, f"Cox_INTERNAL_ALL_T{T0}_stageBIN_imgrisk_{risk_tag}")
run_cox_uv_mv(design_bin, out_prefix)

# (b) stage0 ordinal (optional)
design_s0 = build_design_for_cox(df_cox, use_stage="stage0", risk_unit=RISK_UNIT_FOR_COX)
save_design_matrix_for_cox(
    design_s0,
    os.path.join(SAVE_ROOT_I, f"DATA_CoxDesign_INTERNAL_T{T0}_stage0_imgrisk_{risk_tag}.csv")
)
out_prefix = os.path.join(SAVE_ROOT_I, f"Cox_INTERNAL_ALL_T{T0}_stage0_imgrisk_{risk_tag}")
run_cox_uv_mv(design_s0, out_prefix)

print("\n✅ DONE. Output directory:")
print(SAVE_ROOT_I)

In [None]:
# ============================================================
# Reviewer-oriented package (FULL VERSION + PLOT DATA EXPORT):
#  1) KM by Stage Bin (ALL internal)
#  2) KM by Imaging Risk (ALL internal)
#  3) KM 4-group: Stage Bin × Imaging Risk (ALL internal)
#  4) Cox UV/MV (ALL internal):
#       - stage : categorical (SAFE, one-hot)
#       - imaging_risk : continuous (HR per 0.1 increase) ✅
#
# NOTE:
#  - This script assumes PFS-like endpoint:
#      event <- recur
#      time  <- recur_date
#
# UPDATE (2026-02-17):
#  - Cox imaging_risk scaling changed from z-score (per 1-SD)
#    to per 0.1 increase by default.
# ============================================================

import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import joblib
import matplotlib.pyplot as plt

from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test
from lifelines import CoxPHFitter

# -------------------------
# CONFIG
# -------------------------
# Experiment identifiers; they are interpolated into the model / data / output
# paths below.
BASE_GROUP = "beit0"
GROUP = "n7_30_30"

# Representative split (file index) and training run analysed in this cell.
REP_FILE_IDX = 4
REP_RUN_IDX  = 6

# Feature-set label of the model used for imaging risk; presumably also the
# label used to locate its checkpoint / transformer files (verify below).
REP_FEATURE_FOR_IMAGING = "Image only"
T0 = 60  # months horizon for risk = 1 - S(T0)

DEVICE = torch.device("cpu")
# Seed value available for set_seed() (defined below).
BASE_SEED = 20250903

# Model directory for this file index.
MODEL_DIR_ROOT = f"./survival_model/mixture_non_fix/models/{BASE_GROUP}/{GROUP}"
MODEL_DIR_I    = os.path.join(MODEL_DIR_ROOT, f"file{REP_FILE_IDX:02d}")

# Internal cohort CSV (this PFS cell defines no external cohort path).
DATA_CSV_PATH     = f"./deephit/{BASE_GROUP}/test/dl0/{GROUP}/dh11_run{REP_FILE_IDX:02d}.csv"

# Output root for the PFS pack, created as soon as this cell runs.
SAVE_ROOT_BASE = f"./survival_model/mixture_non_fix/non_nest/{BASE_GROUP}/results/reviewer_km_pack_pfs/dl0/{GROUP}"
SAVE_ROOT_I    = os.path.join(SAVE_ROOT_BASE, f"file{REP_FILE_IDX:02d}_run{REP_RUN_IDX:02d}")
os.makedirs(SAVE_ROOT_I, exist_ok=True)

# feature columns (as you used)
IMG_COLS  = ["feat_436", "feat_519"]
CONT_COLS = ["Age"]
CAT_COLS  = ["pathology", "stage0"]

# -------------------------
# Utils
# -------------------------
def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and torch via ``manual_seed``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

def ensure_time_event(df: pd.DataFrame) -> pd.DataFrame:
    """
    PFS-like endpoint: return a copy of ``df`` with numeric time/event columns.

      - event <- recur       (only if ``event`` is absent; stored as int)
      - time  <- recur_date  (only if ``time`` is absent; stored as float32)

    Non-numeric entries are coerced to NaN, and a missing event becomes 0
    (treated as censored). Columns that cannot be derived are left absent;
    no error is raised.
    """
    df = df.copy()

    if "event" not in df.columns and "recur" in df.columns:
        # Coerce instead of ``astype(int)``: a NaN / non-numeric entry would
        # otherwise raise before the fillna(0) below gets a chance to run.
        df["event"] = pd.to_numeric(df["recur"], errors="coerce")

    if "time" not in df.columns and "recur_date" in df.columns:
        df["time"] = pd.to_numeric(df["recur_date"], errors="coerce").astype(np.float32)

    if "time" in df.columns:
        df["time"] = pd.to_numeric(df["time"], errors="coerce")

    if "event" in df.columns:
        df["event"] = pd.to_numeric(df["event"], errors="coerce").fillna(0).astype(int)

    return df

def add_stage_bin_12_vs_34(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` with ``stage_bin`` and ``stage_bin_label`` added.

    stage0 coding: 1,2,3,4
    stage_bin:
      0 => stage0 in {1,2}  (IB–IIIC1)
      1 => stage0 in {3,4}  (IIIC2–IVB)
    Other or non-numeric stage0 values give NaN; ``stage0`` itself is coerced
    to numeric. Without a ``stage0`` column both new columns are all-NaN.
    """
    out = df.copy()
    if "stage0" not in out.columns:
        out["stage_bin"] = np.nan
        out["stage_bin_label"] = np.nan
        return out

    stage = pd.to_numeric(out["stage0"], errors="coerce")
    out["stage0"] = stage
    out["stage_bin"] = np.select(
        [stage.isin([1, 2]).to_numpy(), stage.isin([3, 4]).to_numpy()],
        [0, 1],
        default=np.nan,
    )
    labels = {0: "IB–IIIC1 (stage0 1–2)", 1: "IIIC2–IVB (stage0 3–4)"}
    out["stage_bin_label"] = out["stage_bin"].map(labels)
    return out

# -------------------------
# PLOT DATA EXPORT helpers
# -------------------------
def save_km_long_data(df, time_col, event_col, group_col, out_csv):
    """
    Write the (time, event, group) rows used for a KM plot to ``out_csv``.

    Columns are renamed to ``time``, ``event`` and ``group``; rows holding
    NaN or +/-inf in any of them are dropped first.
    """
    renamer = {time_col: "time", event_col: "event", group_col: "group"}
    long_df = (
        df.loc[:, [time_col, event_col, group_col]]
          .replace([np.inf, -np.inf], np.nan)
          .dropna()
          .rename(columns=renamer)
    )
    long_df.to_csv(out_csv, index=False)
    print(f"✅ KM raw(long) saved: {out_csv}")

def save_risk_long_data(df_time_event_risk, out_csv):
    """
    Write the ``time``, ``event`` and ``risk`` columns to ``out_csv``.

    Only rows with NaN/inf in those three columns are dropped; other columns
    are ignored and not written.
    """
    wanted = ["time", "event", "risk"]
    cleaned = df_time_event_risk.replace([np.inf, -np.inf], np.nan)
    cleaned = cleaned.dropna(subset=wanted)
    cleaned.loc[:, wanted].to_csv(out_csv, index=False)
    print(f"✅ Risk raw saved: {out_csv}")

def save_design_matrix_for_cox(df_design, out_csv):
    """
    Write the Cox design matrix (all columns, incl. one-hot) to ``out_csv``,
    dropping any row that contains NaN or +/-inf.
    """
    finite_only = df_design.replace([np.inf, -np.inf], np.nan)
    finite_only.dropna().to_csv(out_csv, index=False)
    print(f"✅ Cox design matrix saved: {out_csv}")

# -------------------------
# Model + load/predict
# -------------------------
class MixtureStretchedExponentialSurvival(nn.Module):
    """
    MLP producing per-sample parameters of a K-component survival mixture.

    forward(x) returns ``(pi, lam, a)``, each of shape (N, num_components):
      - pi:  mixture weights (softmax over components, rows sum to 1)
      - lam: positive rates  (softplus + 1e-3)
      - a:   positive shapes (softplus + 1e-3)
    """

    def __init__(self, input_dim, num_components=2):
        super().__init__()
        width = 64
        self.backbone = nn.Sequential(
            nn.Linear(input_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
        )
        # Attribute names are the state_dict keys of saved checkpoints; keep them.
        self.pi_layer = nn.Linear(width, num_components)
        self.lam_layer = nn.Linear(width, num_components)
        self.alpha_layer = nn.Linear(width, num_components)

    def forward(self, x):
        floor = 1e-3  # keeps rates / shapes strictly positive
        hidden = self.backbone(x)
        weights = F.softmax(self.pi_layer(hidden), dim=1)
        rates = F.softplus(self.lam_layer(hidden)) + floor
        shapes = F.softplus(self.alpha_layer(hidden)) + floor
        return weights, rates, shapes

def load_model_and_ct(model_dir, run_idx, label, device=DEVICE, num_components=2):
    """
    Load the checkpoint ``best_model_run<NN>_<label>.pt`` and the joblib
    transformer ``ct_run<NN>_<label>.joblib`` from ``model_dir``.

    Spaces in ``label`` become underscores in the file tag. The checkpoint
    must contain ``input_dim`` and ``state_dict``.

    Returns:
        (model in eval mode on ``device``, loaded transformer)
    Raises:
        FileNotFoundError: if either file is missing (model checked first).
    """
    tag = f"run{run_idx:02d}_{label.replace(' ', '_')}"
    required = {
        "model": os.path.join(model_dir, f"best_model_{tag}.pt"),
        "ct": os.path.join(model_dir, f"ct_{tag}.joblib"),
    }
    for kind, path in required.items():
        if not os.path.exists(path):
            raise FileNotFoundError(f"Missing {kind}: {path}")

    checkpoint = torch.load(required["model"], map_location=device)
    net = MixtureStretchedExponentialSurvival(
        input_dim=int(checkpoint["input_dim"]),
        num_components=num_components,
    ).to(device)
    net.load_state_dict(checkpoint["state_dict"])
    net.eval()

    return net, joblib.load(required["ct"])

@torch.no_grad()
def predict_survival(model, x, times):
    """
    Evaluate the mixture survival function at each horizon in ``times``.

    S(t | x) = sum_k pi_k * exp(-lam_k * (t + 1e-8) ** a_k), with
    ``(pi, lam, a) = model(x)``.

    Returns:
        np.ndarray of shape (len(times), N), N = number of rows of ``x``.
    """
    model.eval()
    weights, rates, shapes = model(x)

    def _survival_at(horizon):
        t = torch.tensor([horizon], dtype=torch.float32, device=x.device)
        per_component = torch.exp(-rates * torch.pow(t + 1e-8, shapes))
        return (weights * per_component).sum(dim=1).detach().cpu().numpy()

    return np.vstack([_survival_at(h) for h in times])

# -------------------------
# KM plotting helpers
# -------------------------
def km_plot_two_groups(df, time_col, event_col, group_col, group_order, title, out_png, out_summary_csv):
    """
    Plot Kaplan–Meier curves for ``group_order`` and report the log-rank
    p-value between its first two groups.

    Rows with NaN/inf in the three columns, or whose group is not listed in
    ``group_order``, are ignored. With fewer than two groups left, a warning is
    printed and nothing is written. Otherwise a per-group n/events table (in
    ``group_order`` order) goes to ``out_summary_csv`` and the figure to
    ``out_png``.
    """
    data = (df[[time_col, event_col, group_col]]
            .replace([np.inf, -np.inf], np.nan)
            .dropna())
    data = data[data[group_col].isin(group_order)].copy()
    if data[group_col].nunique() < 2:
        print(f"⚠️ Need >=2 groups for KM: {out_png}")
        return

    per_group = data.groupby(group_col).agg(n=(group_col, "size"), events=(event_col, "sum"))
    per_group.reindex(group_order).reset_index().to_csv(out_summary_csv, index=False)

    subsets = [(grp, data[data[group_col] == grp]) for grp in group_order]

    fitter = KaplanMeierFitter()
    plt.figure(figsize=(7, 5))
    for grp, sub in subsets:
        fitter.fit(sub[time_col], event_observed=sub[event_col], label=f"{grp} (n={len(sub)})")
        fitter.plot(ci_show=True)

    ref, other = subsets[0][1], subsets[1][1]
    p_value = logrank_test(
        ref[time_col], other[time_col],
        event_observed_A=ref[event_col], event_observed_B=other[event_col],
    ).p_value

    plt.title(f"{title}\nlog-rank p={p_value:.3g}")
    plt.xlabel("Time (Months)")
    plt.ylabel("Survival probability")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(out_png, dpi=200)
    plt.close()
    print(f"✅ KM saved: {out_png}")

def km_plot_by_risk(df_time_event_risk, title, out_png, out_summary_csv, fixed_cut=None):
    """
    Split samples into High/Low risk at a cutoff and plot their KM curves.

    Args:
        df_time_event_risk: frame with ``time``, ``event`` and ``risk`` columns.
        title: plot title; the cutoff and log-rank p are appended.
        out_png: path of the figure.
        out_summary_csv: path of the per-group n / events / median-risk table.
        fixed_cut: cutoff to apply; when None the median risk of the usable
            rows is used. Samples with risk >= cutoff are "High".

    Returns:
        The cutoff used (float), or None if fewer than 10 usable rows exist.

    If every sample lands on one side of the cutoff (tied risks at the median,
    or a cohort entirely above/below an imposed cutoff) the summary is still
    written and the cutoff returned, but the plot is skipped: a KM curve and a
    log-rank test are undefined for an empty group.
    """
    df = df_time_event_risk.copy()
    df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=["time","event","risk"])
    if len(df) < 10:
        print(f"⚠️ Too few samples for KM risk: {out_png}")
        return None

    cut = float(fixed_cut) if fixed_cut is not None else float(np.nanmedian(df["risk"].values))
    df["risk_group"] = np.where(df["risk"] >= cut, "High", "Low")

    summ = (df.groupby("risk_group")
              .agg(n=("risk_group","size"), events=("event","sum"), median_risk=("risk","median"))
              .reset_index())
    summ["cut_used"] = cut
    summ.to_csv(out_summary_csv, index=False)

    low  = df[df["risk_group"] == "Low"]
    high = df[df["risk_group"] == "High"]
    empty = [name for name, sub in (("Low", low), ("High", high)) if sub.empty]
    if empty:
        print(f"⚠️ Empty risk group(s) {empty} at cutoff={cut:.4f}; KM plot skipped: {out_png}")
        return cut

    kmf = KaplanMeierFitter()
    plt.figure(figsize=(7,5))
    for g, sub in (("Low", low), ("High", high)):
        kmf.fit(sub["time"], event_observed=sub["event"], label=f"{g} (n={len(sub)})")
        kmf.plot(ci_show=True)

    lr = logrank_test(low["time"], high["time"], event_observed_A=low["event"], event_observed_B=high["event"])
    pval = lr.p_value

    plt.title(f"{title}\ncutoff={cut:.4f} | log-rank p={pval:.3g}")
    plt.xlabel("Time (Months)")
    plt.ylabel("Survival probability")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(out_png, dpi=200)
    plt.close()
    print(f"✅ KM saved: {out_png}")
    return cut

def km_plot_four_groups(df, time_col, event_col, group_col, title, out_png, out_summary_csv):
    """
    Draw one Kaplan–Meier curve per level of `group_col` on a single figure
    (used for the Stage Bin × Imaging Risk 4-group plot).

    Rows with NaN/±inf in the time, event or group column are dropped first.
    If fewer than two distinct groups remain, a warning is printed and nothing
    is written. Otherwise a per-group summary CSV (n, events; sorted by group
    label) is saved, followed by the PNG with curves in sorted label order.
    No log-rank test is computed here. Returns None.
    """
    wanted = [time_col, event_col, group_col]
    clean = df[wanted].copy().replace([np.inf, -np.inf], np.nan).dropna()

    # Guard: a KM comparison needs at least two groups.
    if clean[group_col].nunique() < 2:
        print(f"⚠️ Need >=2 groups for KM: {out_png}")
        return

    counts = clean.groupby(group_col).agg(
        n=(group_col, "size"),
        events=(event_col, "sum"),
    )
    counts = counts.reset_index()
    counts = counts.sort_values(group_col)
    counts.to_csv(out_summary_csv, index=False)

    fitter = KaplanMeierFitter()
    plt.figure(figsize=(7, 5))
    levels = sorted(clean[group_col].unique())
    for level in levels:
        members = clean.loc[clean[group_col] == level]
        curve_label = f"{level} (n={len(members)})"
        fitter.fit(members[time_col], event_observed=members[event_col], label=curve_label)
        fitter.plot(ci_show=True)

    plt.title(title)
    plt.xlabel("Time (Months)")
    plt.ylabel("Survival probability")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(out_png, dpi=200)
    plt.close()
    print(f"✅ KM saved: {out_png}")

# -------------------------
# Cox helpers (stage categorical + imaging_risk continuous)
#   ✅ imaging_risk scaled to HR per 0.1 increase
# -------------------------
def build_design_for_cox(
    df: pd.DataFrame,
    stage_mode: str = "stage0_cat",   # "stage0_cat" | "stage_bin_cat" | "stage_bin_num"
    risk_unit: float = 0.1,           # HR per 0.1 increase (default)
    standardize_risk: bool = False,   # z-score OFF by default
    standardize_age: bool = False
) -> pd.DataFrame:
    """
    Build a numeric design matrix for a Cox model (lifelines CoxPHFitter).

    Returned columns, in order:
      - time, event  : duration and event indicator (event cast to int)
      - Age          : continuous; z-scored (ddof=0) when standardize_age=True
      - imaging_risk : continuous
          * standardize_risk=True  -> z-score (HR per 1-SD)
          * standardize_risk=False -> imaging_risk / risk_unit (HR per risk_unit)
      - stage columns, by stage_mode:
          * "stage0_cat"    -> one-hot stage0, first level dropped (reference)
          * "stage_bin_cat" -> one-hot stage_bin, first level dropped
          * "stage_bin_num" -> stage_bin as one numeric column
      - pathology_*  : one-hot pathology, first level dropped (reference)

    Rows with a missing, non-numeric or infinite value in any required column
    (including `event`) are dropped; a missing event is NOT assumed censored.
    A None / non-positive risk_unit falls back to 0.1. Constant columns are
    left unscaled when z-scoring. Dummy columns are 0/1 integers.

    Raises:
        ValueError: if stage_mode is not one of the supported modes.
    """
    valid_modes = ("stage0_cat", "stage_bin_cat", "stage_bin_num")
    if stage_mode not in valid_modes:
        # Validate before touching columns so a bad mode never surfaces as a KeyError.
        raise ValueError(f"Unknown stage_mode: {stage_mode}")

    stage_col = "stage0" if stage_mode == "stage0_cat" else "stage_bin"

    d = df.copy()
    for col in ("time", "event", "Age", "imaging_risk", stage_col):
        d[col] = pd.to_numeric(d[col], errors="coerce")

    req = ["time", "event", "Age", "imaging_risk", "pathology", stage_col]
    d = d[req].replace([np.inf, -np.inf], np.nan).dropna().copy()
    d["event"] = d["event"].astype(int)

    def _zscore(s: pd.Series) -> pd.Series:
        # Leave the series unscaled when the SD is zero/NaN (constant or empty).
        mu, sd = s.mean(), s.std(ddof=0)
        if sd and sd > 0 and not np.isnan(sd):
            return (s - mu) / sd
        return s

    # imaging_risk continuous scaling
    if standardize_risk:
        risk_cont = _zscore(d["imaging_risk"])
    else:
        unit = 0.1
        if risk_unit is not None and float(risk_unit) > 0:
            unit = float(risk_unit)
        risk_cont = d["imaging_risk"] / unit  # HR per `unit` increase

    age_cont = _zscore(d["Age"]) if standardize_age else d["Age"]

    # pathology one-hot (0/1 ints so the exported design CSV is numeric)
    pat_oh = pd.get_dummies(d["pathology"].astype("category"),
                            prefix="pathology", drop_first=True, dtype=int)

    # stage encoding
    if stage_mode == "stage_bin_num":
        stage_enc = d[["stage_bin"]].copy()
    else:
        st = d[stage_col].astype("Int64").astype("category")
        stage_enc = pd.get_dummies(st, prefix=stage_col, drop_first=True, dtype=int)

    core = pd.DataFrame(
        {"time": d["time"], "event": d["event"], "Age": age_cont, "imaging_risk": risk_cont},
        index=d.index,
    )

    out = pd.concat([core, stage_enc, pat_oh], axis=1)
    out = out.replace([np.inf, -np.inf], np.nan).dropna()
    return out

def run_cox_uv_mv(df_design, out_prefix):
    """
    Fit univariable (one covariate at a time) and multivariable Cox models.

    Args:
        df_design: numeric design frame with "time", "event" and covariate
            columns (e.g. the output of build_design_for_cox). Every column
            other than time/event is treated as a covariate.
        out_prefix: path prefix; writes f"{out_prefix}_UV.csv" and
            f"{out_prefix}_MV.csv".

    UV CSV: one row per covariate with coef, HR, p, HR 95% CI and n (rows
    fitted); a covariate whose fit fails gets a row with an "error" message.
    MV CSV: the full lifelines summary with the covariate name in a
    "variable" column (same key as the UV file), or a single "error" row.
    Fitting errors are recorded and printed, not raised (best-effort report).
    """
    time_col, event_col = "time", "event"
    covars = [c for c in df_design.columns if c not in (time_col, event_col)]

    # UV: one Cox fit per covariate
    uv_rows = []
    for c in covars:
        sub = df_design[[time_col, event_col, c]]
        cph = CoxPHFitter()
        try:
            cph.fit(sub, duration_col=time_col, event_col=event_col)
            s = cph.summary.loc[c]
            uv_rows.append({
                "variable": c,
                "coef": float(s["coef"]),
                "HR": float(s["exp(coef)"]),
                "p": float(s["p"]),
                "HR_95low": float(s["exp(coef) lower 95%"]),
                "HR_95high": float(s["exp(coef) upper 95%"]),
                # rows passed to the fit (avoids relying on lifelines' private _n_examples)
                "n": int(len(sub)),
            })
        except Exception as e:  # best-effort: e.g. convergence failure on one covariate
            print(f"⚠️ Cox UV failed for {c}: {e}")
            uv_rows.append({"variable": c, "error": str(e)})

    uv_path = f"{out_prefix}_UV.csv"
    pd.DataFrame(uv_rows).to_csv(uv_path, index=False)

    # MV: all covariates jointly
    mv_path = f"{out_prefix}_MV.csv"
    cph = CoxPHFitter()
    try:
        cph.fit(df_design, duration_col=time_col, event_col=event_col)
        # reset_index() names the column after the index (lifelines names it,
        # e.g. "covariate"), so renaming "index" was a no-op; rename the axis instead.
        mv = cph.summary.rename_axis("variable").reset_index()
        mv.to_csv(mv_path, index=False)
    except Exception as e:  # best-effort: record the failure in the output file
        print(f"⚠️ Cox MV failed: {e}")
        pd.DataFrame([{"error": str(e)}]).to_csv(mv_path, index=False)

    print(f"✅ Cox saved: {uv_path} , {mv_path}")

# ============================================================
# MAIN
# ============================================================
set_seed(BASE_SEED)

# The Cox imaging_risk unit comes from CONFIG (RISK_UNIT_FOR_COX), so changing
# it changes both the model and the output filenames. Validate it before any
# heavy work; build_design_for_cox would otherwise silently fall back to 0.1.
if RISK_UNIT_FOR_COX is None or not float(RISK_UNIT_FOR_COX) > 0:
    raise ValueError(f"RISK_UNIT_FOR_COX must be a positive number, got {RISK_UNIT_FOR_COX!r}")
COX_RISK_UNIT = float(RISK_UNIT_FOR_COX)
# Filename-safe tag: 0.1 -> "per0p1" (identical to the previous hard-coded suffix).
COX_RISK_TAG = "per" + f"{COX_RISK_UNIT:g}".replace(".", "p")

print("============================================")
print(f"[Reviewer KM Pack PFS] file{REP_FILE_IDX:02d} run{REP_RUN_IDX:02d} | T0={T0}")
print("MODEL_DIR_I =", MODEL_DIR_I)
print("DATA_CSV    =", DATA_CSV_PATH)
print("OUT_DIR     =", SAVE_ROOT_I)
print("============================================")

if not os.path.exists(DATA_CSV_PATH):
    raise FileNotFoundError(f"Missing internal data CSV: {DATA_CSV_PATH}")

df_all = ensure_time_event(pd.read_csv(DATA_CSV_PATH))
df_all = add_stage_bin_12_vs_34(df_all)

# -------------------------
# 1) KM by Stage Bin (ALL internal)
# -------------------------
stage_groups = ["IB–IIIC1 (stage0 1–2)", "IIIC2–IVB (stage0 3–4)"]
df_stage_int = df_all[["time","event","stage_bin_label"]].dropna(subset=["stage_bin_label"])

save_km_long_data(
    df_stage_int, "time", "event", "stage_bin_label",
    os.path.join(SAVE_ROOT_I, "DATA_INTERNAL_ALL_stagebin_long.csv")
)

out_png = os.path.join(SAVE_ROOT_I, "KM_INTERNAL_ALL_stagebin.png")
out_sum = os.path.join(SAVE_ROOT_I, "KM_INTERNAL_ALL_stagebin_summary.csv")
km_plot_two_groups(
    df_stage_int, "time","event","stage_bin_label", stage_groups,
    title=f"KM Internal (ALL) by Stage Bin (1–2 vs 3–4) | file{REP_FILE_IDX:02d}",
    out_png=out_png, out_summary_csv=out_sum
)

# -------------------------
# 2) KM by Imaging Risk (ALL internal) using representative model
# -------------------------
model, ct = load_model_and_ct(MODEL_DIR_I, REP_RUN_IDX, REP_FEATURE_FOR_IMAGING, device=DEVICE, num_components=2)

used_cols = IMG_COLS + CONT_COLS + CAT_COLS
missing_int = [c for c in used_cols if c not in df_all.columns]
if missing_int:
    raise ValueError(f"Internal missing required columns for imaging risk: {missing_int}")

X_int_df = df_all[used_cols].copy()
X_int = ct.transform(X_int_df)
X_int = pd.DataFrame(X_int).fillna(0).values.astype(np.float32)
X_int_t = torch.tensor(X_int, dtype=torch.float32).to(DEVICE)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    S_int_T0 = predict_survival(model, X_int_t, [T0])[0, :]
risk_int = 1.0 - S_int_T0

df_all = df_all.copy()
df_all["imaging_risk"] = risk_int

df_all[["time","event","stage0","stage_bin","stage_bin_label","Age","pathology","imaging_risk"]].to_csv(
    os.path.join(SAVE_ROOT_I, f"INTERNAL_ALL_with_imagingRisk_T{T0}.csv"), index=False
)

df_risk_int = df_all[["time","event","imaging_risk"]].rename(columns={"imaging_risk":"risk"}).dropna()
save_risk_long_data(
    df_risk_int,
    os.path.join(SAVE_ROOT_I, f"DATA_INTERNAL_ALL_imagingRisk_T{T0}.csv")
)

out_png = os.path.join(SAVE_ROOT_I, f"KM_INTERNAL_ALL_imagingRisk_T{T0}.png")
out_sum = os.path.join(SAVE_ROOT_I, f"KM_INTERNAL_ALL_imagingRisk_T{T0}_summary.csv")
internal_cut = km_plot_by_risk(
    df_risk_int,
    title=f"KM Internal (ALL) by Imaging Risk | {REP_FEATURE_FOR_IMAGING}\nrisk=1-S({T0}) | file{REP_FILE_IDX:02d} run{REP_RUN_IDX:02d}",
    out_png=out_png, out_summary_csv=out_sum, fixed_cut=None
)

# km_plot_by_risk returns None when too few rows remain to plot. Fall back to the
# same rule it uses (internal median risk) so the grouping below never does `>= None`.
if internal_cut is None:
    if df_risk_int.empty:
        raise RuntimeError("No valid internal imaging-risk values; cannot define a risk cutoff.")
    internal_cut = float(np.nanmedian(df_risk_int["risk"].values))
    print(f"⚠️ Using internal median imaging-risk cutoff={internal_cut:.4f} (KM plot skipped)")

df_risk_int2 = df_risk_int.copy()
df_risk_int2["cutoff"] = internal_cut
df_risk_int2["risk_group"] = np.where(df_risk_int2["risk"] >= internal_cut, "High", "Low")
df_risk_int2.to_csv(
    os.path.join(SAVE_ROOT_I, f"DATA_INTERNAL_ALL_imagingRisk_T{T0}_withGroup.csv"),
    index=False
)

# FULL roster (every input column) + imaging_risk, risk_group, cutoff, as described
# in the 2026-02-18 note of the file header. Rows without a risk value get no group.
df_roster = df_all.copy()
df_roster["risk_group"] = pd.Series(
    np.where(df_roster["imaging_risk"] >= internal_cut, "High", "Low"),
    index=df_roster.index,
).where(df_roster["imaging_risk"].notna())
df_roster["cutoff"] = internal_cut
df_roster.to_csv(
    os.path.join(SAVE_ROOT_I, f"INTERNAL_ALL_FULL_roster_with_imagingRisk_T{T0}.csv"),
    index=False
)

# -------------------------
# 3) 4-group KM: Stage Bin × Imaging Risk (ALL internal)
# -------------------------
df4 = df_all[["time","event","stage_bin_label","stage_bin","imaging_risk"]].dropna(subset=["stage_bin","imaging_risk"]).copy()
df4["risk_group"] = np.where(df4["imaging_risk"] >= internal_cut, "HighRisk", "LowRisk")
df4["group4"] = df4["stage_bin_label"] + " | " + df4["risk_group"]

save_km_long_data(
    df4, "time", "event", "group4",
    os.path.join(SAVE_ROOT_I, f"DATA_INTERNAL_ALL_stagebin_x_imagingRisk_T{T0}_long.csv")
)

out_png = os.path.join(SAVE_ROOT_I, f"KM_INTERNAL_ALL_stagebin_x_imagingRisk_T{T0}.png")
out_sum = os.path.join(SAVE_ROOT_I, f"KM_INTERNAL_ALL_stagebin_x_imagingRisk_T{T0}_summary.csv")
km_plot_four_groups(
    df4, "time","event","group4",
    title=f"KM Internal (ALL) | Stage Bin × Imaging Risk\ncutoff(from internal median risk) | T0={T0}",
    out_png=out_png, out_summary_csv=out_sum
)

# -------------------------
# 4) Cox UV/MV
#     - stage categorical safe
#     - imaging_risk continuous (HR per COX_RISK_UNIT increase, from RISK_UNIT_FOR_COX)
# -------------------------
df_cox = df_all[["time","event","Age","pathology","stage0","stage_bin","imaging_risk"]].copy()
df_cox = df_cox.dropna(subset=["time","event","Age","pathology","stage0","stage_bin","imaging_risk"])

# (a) Recommended: stage0 categorical (SAFE)
design_stage0 = build_design_for_cox(
    df_cox,
    stage_mode="stage0_cat",
    risk_unit=COX_RISK_UNIT,  # HR per RISK_UNIT_FOR_COX increase
    standardize_risk=False,   # not z-score
    standardize_age=False
)
save_design_matrix_for_cox(
    design_stage0,
    os.path.join(SAVE_ROOT_I, f"DATA_CoxDesign_INTERNAL_T{T0}_stage0CATEG_imgrisk_{COX_RISK_TAG}.csv")
)
out_prefix = os.path.join(SAVE_ROOT_I, f"Cox_INTERNAL_ALL_T{T0}_stage0CATEG_imgrisk_{COX_RISK_TAG}")
run_cox_uv_mv(design_stage0, out_prefix)

# (b) Optional: stage_bin categorical
design_stagebin_cat = build_design_for_cox(
    df_cox,
    stage_mode="stage_bin_cat",
    risk_unit=COX_RISK_UNIT,
    standardize_risk=False,
    standardize_age=False
)
save_design_matrix_for_cox(
    design_stagebin_cat,
    os.path.join(SAVE_ROOT_I, f"DATA_CoxDesign_INTERNAL_T{T0}_stageBINCATEG_imgrisk_{COX_RISK_TAG}.csv")
)
out_prefix = os.path.join(SAVE_ROOT_I, f"Cox_INTERNAL_ALL_T{T0}_stageBINCATEG_imgrisk_{COX_RISK_TAG}")
run_cox_uv_mv(design_stagebin_cat, out_prefix)

print("\n✅ DONE. Output directory:")
print(SAVE_ROOT_I)