# In [None]:
# KM stratification using DCA-selected model (file04/run06) - NO DCA export
# Mixture Stretched Exponential Survival + Model Saving + (Optional) External Eval + KM plots
#
# Key behavior:
# - Train + save 30 Monte-Carlo models per feature-set (same as before)
# - Then load the representative model (file04, run06) and draw KM plots
# - External KM uses INTERNAL (validation) median cutoff (no external re-median)

import os
import joblib
import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.metrics import roc_auc_score
from lifelines.utils import concordance_index
from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test
import matplotlib.pyplot as plt
import glob, re

# =========================
# Config
# =========================
# Experiment configuration shared by the training driver and the KM section.
# BASE_GROUPS: group names iterated by the training loop and embedded in data/model/result paths.
BASE_GROUPS = ["beit0"]
# GROUP: sub-directory name embedded in the same data/model/result paths.
GROUP = "n7_30_30"
# N_RUNS: number of seeded train/validation splits (and models) per feature set.
N_RUNS = 30
# time_points: horizons at which survival, AUC and C-index are evaluated
# (written to the result CSVs under the "Time (Months)" column).
time_points = [12, 24, 36, 48, 60, 72]
# DEVICE: torch device used for tensors, models and checkpoint loading.
DEVICE = torch.device("cpu")

# reproducibility: run k uses seed BASE_SEED + k for both set_seed and train_test_split
BASE_SEED = 20250903

# model save root (created at import time so later per-group sub-directories can be added)
os.makedirs("./survival_model/mixture_non_fix/models", exist_ok=True)

# =========================
# Utils
# =========================
def set_seed(seed: int):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with the same value."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

class MixtureStretchedExponentialSurvival(nn.Module):
    """Two-layer MLP that outputs per-sample mixture parameters.

    For each input row the network returns, per mixture component, a weight
    `pi` (softmax over components), a rate `lam` and an exponent `a`
    (both softplus + 1e-3, hence strictly positive).
    """

    def __init__(self, input_dim, num_components=2):
        super().__init__()
        hidden = 64
        # Layers are created in a fixed order (backbone, pi, lam, alpha) so that
        # seeded initialisation and state_dict keys match existing checkpoints.
        trunk = [
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        ]
        self.backbone = nn.Sequential(*trunk)
        self.pi_layer = nn.Linear(hidden, num_components)
        self.lam_layer = nn.Linear(hidden, num_components)
        self.alpha_layer = nn.Linear(hidden, num_components)

    def forward(self, x):
        """Return `(pi, lam, a)`, each with one column per mixture component."""
        features = self.backbone(x)
        weights = torch.softmax(self.pi_layer(features), dim=1)
        rates = F.softplus(self.lam_layer(features)) + 1e-3
        exponents = F.softplus(self.alpha_layer(features)) + 1e-3
        return weights, rates, exponents

def mixture_stretched_nll(t, e, pi, lam, a, eps=1e-8):
    """Mean negative log-likelihood of a mixture of stretched exponentials.

    Per component k: S_k(t) = exp(-lam_k * t**a_k) and
    f_k(t) = lam_k * a_k * t**(a_k - 1) * S_k(t). Rows with e == 1 contribute
    log f(t), rows with e == 0 contribute log S(t); `eps` is added to t and to
    the mixed density/survival to keep powers and logs finite.
    """
    shifted_t = t.view(-1, 1) + eps
    surv_k = torch.exp(-lam * torch.pow(shifted_t, a))
    dens_k = lam * a * torch.pow(shifted_t, a - 1) * surv_k
    density = (pi * dens_k).sum(dim=1) + eps
    survival = (pi * surv_k).sum(dim=1) + eps
    log_lik = e * torch.log(density) + (1 - e) * torch.log(survival)
    return -log_lik.mean()

@torch.no_grad()
def predict_survival(model, x, times):
    """Evaluate the mixture survival function S(t) for every row of `x`.

    The model is switched to eval mode and queried once; the returned array
    has one row per entry of `times` and one column per sample, i.e. (T, N).
    """
    model.eval()
    weights, rates, exponents = model(x)

    def _survival_at(t):
        # A one-element tensor broadcasts against the per-sample parameters.
        t_vec = torch.tensor([t], dtype=torch.float32, device=x.device)
        component_surv = torch.exp(-rates * torch.pow(t_vec + 1e-8, exponents))
        return (weights * component_surv).sum(dim=1).cpu().numpy()

    return np.vstack([_survival_at(t) for t in times])  # (T, N)

def calc_auc(surv_arr, y_df, times):
    """Compute one ROC AUC per time horizon and return them as `{t: auc}`.

    At horizon t the positive class is "event observed with time <= t" and
    the score is 1 - S(t), read from row i of `surv_arr`. Any failure of the
    AUC computation itself is recorded as NaN rather than raised.
    """
    def _auc_at(row_idx, horizon):
        had_event_by_t = ((y_df["event"] == 1) & (y_df["time"] <= horizon)).astype(int)
        event_prob = 1 - surv_arr[row_idx, :]
        try:
            return roc_auc_score(had_event_by_t, event_prob)
        except Exception:
            return np.nan

    return {horizon: _auc_at(row_idx, horizon) for row_idx, horizon in enumerate(times)}

def safe_concordance_index(times, risks, events):
    """Concordance index that returns NaN instead of failing on degenerate input.

    Rows where any of the three values is NaN are dropped. NaN is returned
    when fewer than two rows remain or when the remaining scores are
    (numerically) constant.
    """
    times, risks, events = (np.asarray(v) for v in (times, risks, events))
    keep = ~np.isnan(times) & ~np.isnan(risks) & ~np.isnan(events)
    if keep.sum() < 2:
        return np.nan
    kept_scores = risks[keep]
    if np.std(kept_scores) < 1e-6:
        return np.nan
    return concordance_index(times[keep], kept_scores, events[keep])

def save_model_and_ct(model_state, ct, save_dir, run_idx, label, input_dim):
    """Persist one run's model weights and its fitted preprocessor.

    Writes `best_model_<tag>.pt` (state dict plus input dimension) and
    `ct_<tag>.joblib` into `save_dir`, where `<tag>` is the zero-padded run
    index followed by the label with spaces replaced by underscores.
    """
    tag = f"run{run_idx:02d}_{label.replace(' ', '_')}"
    os.makedirs(save_dir, exist_ok=True)

    checkpoint = {"state_dict": model_state, "input_dim": input_dim}
    checkpoint_file = os.path.join(save_dir, f"best_model_{tag}.pt")
    transformer_file = os.path.join(save_dir, f"ct_{tag}.joblib")

    torch.save(checkpoint, checkpoint_file)
    joblib.dump(ct, transformer_file)

def load_model_and_ct(model_dir, run_idx, label, device=DEVICE, num_components=2):
    """Load a saved model (in eval mode) and its fitted preprocessor.

    File names follow the scheme used by `save_model_and_ct`. The network is
    rebuilt from the stored `input_dim`; `num_components` is not stored in
    the checkpoint and must match the value used at training time.

    Raises:
        FileNotFoundError: if either file is missing, or if the preprocessor
            file exists but cannot be loaded (the original error is chained).
    """
    tag = f"run{run_idx:02d}_{label.replace(' ', '_')}"
    ckpt_path = os.path.join(model_dir, f"best_model_{tag}.pt")
    ct_path = os.path.join(model_dir, f"ct_{tag}.joblib")

    if not (os.path.exists(ckpt_path) and os.path.exists(ct_path)):
        raise FileNotFoundError(f"Missing: {ckpt_path} or {ct_path}")

    checkpoint = torch.load(ckpt_path, map_location=device)
    network = MixtureStretchedExponentialSurvival(
        input_dim=checkpoint["input_dim"],
        num_components=num_components,
    ).to(device)
    network.load_state_dict(checkpoint["state_dict"])
    network.eval()

    try:
        transformer = joblib.load(ct_path)
    except Exception as load_error:
        # Callers skip a run on FileNotFoundError, so load failures are mapped onto it.
        raise FileNotFoundError(ct_path) from load_error

    return network, transformer

def evaluate_external_for_all_models(MODEL_DIR, EXTERNAL_CSV, time_points, device=DEVICE):
    """Score every saved model in MODEL_DIR on an external cohort (optional step).

    Returns two lists of row dicts (AUC rows, C-index rows), one row per
    model and time point. Returns `(None, None)` when the external CSV does
    not exist or no checkpoint matches `best_model_run*.pt`. Models whose
    files cannot be loaded, or whose required input columns are absent from
    the external data, are skipped without a message.
    """
    if not os.path.exists(EXTERNAL_CSV):
        print(f"‚ÑπÔ∏è External eval skipped (missing): {EXTERNAL_CSV}")
        return None, None

    external = pd.read_csv(EXTERNAL_CSV)
    # Derive the standard outcome columns from their alternative names when needed.
    if 'survival' in external.columns and 'event' not in external.columns:
        external['event'] = external['survival'].astype(int)
    if 'fu_date' in external.columns and 'time' not in external.columns:
        external['time'] = external['fu_date'].astype(np.float32)

    checkpoint_files = sorted(glob.glob(os.path.join(MODEL_DIR, "best_model_run*.pt")))
    if not checkpoint_files:
        print(f"‚ö†Ô∏è No models in {MODEL_DIR}")
        return None, None

    name_re = re.compile(r"best_model_run(\d+)_([^\.]+)\.pt")
    ext_rows_auc_all = []
    ext_rows_cidx_all = []

    for checkpoint_file in checkpoint_files:
        hit = name_re.search(os.path.basename(checkpoint_file))
        if hit is None:
            continue
        run_idx = int(hit.group(1))
        # Undo the space -> underscore substitution applied when the file was saved.
        label = hit.group(2).replace("_", " ")

        try:
            model, ct = load_model_and_ct(MODEL_DIR, run_idx, label, device=device, num_components=2)
        except FileNotFoundError:
            continue

        # Collect the input columns the fitted ColumnTransformer was given.
        required_columns = []
        for _name, _transformer, cols in ct.transformers_:
            if cols is None or cols == []:
                continue
            if isinstance(cols, (list, tuple, np.ndarray, pd.Index)):
                required_columns.extend(list(cols))
            else:
                required_columns.append(cols)

        if any(col not in external.columns for col in required_columns):
            continue

        features = ct.transform(external[required_columns])
        features = pd.DataFrame(features).fillna(0).values.astype(np.float32)
        features_tensor = torch.tensor(features, dtype=torch.float32).to(device)

        surv_ext = predict_survival(model, features_tensor, time_points)
        outcomes = external[['time', 'event']].copy()

        auc_ext = calc_auc(surv_ext, outcomes.reset_index(drop=True), time_points)

        # The predicted survival probability S(t) itself is passed as the score.
        # NOTE(review): lifelines documents concordance_index scores as
        # "higher = longer survival", which S(t) satisfies - confirm before
        # switching to 1 - S(t).
        cidx_ext = [
            safe_concordance_index(outcomes['time'], surv_ext[j, :], outcomes['event'])
            for j in range(len(time_points))
        ]

        run_tag = f"run{run_idx:02d}"
        for j, t in enumerate(time_points):
            ext_rows_auc_all.append({
                "RunFile": run_tag, "Feature Set": label,
                "Time (Months)": t, "AUC (External)": auc_ext[t],
            })
            ext_rows_cidx_all.append({
                "RunFile": run_tag, "Feature Set": label,
                "Time (Months)": t, "C-index (External)": cidx_ext[j],
            })

    return ext_rows_auc_all, ext_rows_cidx_all

# =========================
# KM helper (fixed cutoff support)
# =========================
def make_km_plot(df_time_event_risk: pd.DataFrame,
                title: str,
                out_png: str,
                out_summary_csv: str,
                fixed_cut: float = None):
    """
    Plot Kaplan-Meier curves for a High/Low risk split and save plot + summary.

    df_time_event_risk must provide the columns: time, event, risk.
    Rows with NaN/inf in any of them are dropped first; with fewer than 4
    rows left nothing is written and None is returned.
    Samples with risk >= cutoff form the "High" group, the rest "Low".
    The cutoff is `fixed_cut` when given (e.g. an internal median applied to
    external data), otherwise the median of the remaining `risk` values.
    A per-group summary CSV and a PNG with the log-rank p-value are written.
    Returns the cutoff used.
    """
    data = (df_time_event_risk.copy()
            .replace([np.inf, -np.inf], np.nan)
            .dropna(subset=["time", "event", "risk"]))
    if len(data) < 4:
        print(f"‚ö†Ô∏è Too few samples for KM: {out_png}")
        return None

    if fixed_cut is None:
        cut = float(np.nanmedian(data["risk"].values))
    else:
        cut = float(fixed_cut)
    data["risk_group"] = np.where(data["risk"] >= cut, "High", "Low")

    # Per-group table: size, number of events, median risk, plus the cutoff used.
    per_group = data.groupby("risk_group").agg(
        n=("risk_group", "size"),
        events=("event", "sum"),
        median_risk=("risk", "median"),
    )
    summary = per_group.reset_index()
    summary["cut_used"] = cut
    summary.to_csv(out_summary_csv, index=False)

    groups = {name: data[data["risk_group"] == name] for name in ("Low", "High")}

    fitter = KaplanMeierFitter()
    plt.figure(figsize=(7, 5))
    for name, members in groups.items():
        # A group with fewer than two samples gets no curve.
        if len(members) >= 2:
            fitter.fit(members["time"], event_observed=members["event"],
                       label=f"{name} (n={len(members)})")
            fitter.plot(ci_show=True)

    low, high = groups["Low"], groups["High"]
    pval = np.nan
    if len(low) >= 2 and len(high) >= 2:
        pval = logrank_test(low["time"], high["time"],
                            event_observed_A=low["event"],
                            event_observed_B=high["event"]).p_value

    # Show the cutoff both in the title and as text inside the figure.
    plt.title(f"{title}\ncutoff (median risk) = {cut:.4f} | log-rank p = {pval:.3g}")

    # In-figure annotation (remove this statement if the extra label is unwanted).
    axes = plt.gca()
    axes.text(0.02, 0.02, f"cutoff = {cut:.4f}", transform=axes.transAxes,
              fontsize=10, verticalalignment="bottom")

    plt.xlabel("Time (Months)")
    plt.ylabel("Survival probability")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(out_png, dpi=200)
    plt.close()

    print(f"‚úÖ KM saved: {out_png}")
    print(f"‚úÖ KM summary saved: {out_summary_csv}")
    print(f"   ‚Ü≥ cutoff (median risk) = {cut:.6f}")
    return cut


# =========================
# Main: train + save models (kept)
# =========================
# Training driver: for each base group and input file, fit N_RUNS models per
# feature set, save every model together with its fitted ColumnTransformer,
# and write per-run train/validation metrics (plus optional external metrics) to CSV.
device = DEVICE

for base_group in BASE_GROUPS:
    print(f"\n\n============================")
    print(f"üìÅ BEiT Í∑∏Î£π Ïã§Ìñâ Ï§ë: {base_group}")
    print(f"============================")

    # Result CSVs/plots go under SAVE_ROOT_BASE, model artifacts under MODEL_DIR_ROOT.
    SAVE_ROOT_BASE = f"./survival_model/mixture_non_fix/non_nest/{base_group}/results/generalization/test_km/dl0/{GROUP}"
    MODEL_DIR_ROOT = f"./survival_model/mixture_non_fix/models/{base_group}/{GROUP}"
    os.makedirs(SAVE_ROOT_BASE, exist_ok=True)
    os.makedirs(MODEL_DIR_ROOT, exist_ok=True)

    # Input columns: image-derived features, continuous clinical, categorical clinical.
    img_cols = ["feat_436", "feat_519"]
    cont_cols = ["Age"]
    cat_cols  = ["pathology", "stage0"]

    # label -> (image columns, clinical columns); the label also names the saved files.
    feature_sets = {
        "Image only": (img_cols, []),
        "Clinical only": ([], cont_cols + cat_cols),
        "Image + Clinical": (img_cols, cont_cols + cat_cols),
    }

    # Only input file index 4 is processed; widen the range to train on more files.
    for i in range(4, 5):
        ONLY_RUN_IDX = i
        EXTERNAL_CSV = f"./external/external{ONLY_RUN_IDX}.csv"

        fname = f"dh11_run{ONLY_RUN_IDX:02d}.csv"
        csv_path = f"./deephit/{base_group}/test/dl0/{GROUP}/{fname}"
        print(f"\nüöÄ Training: {base_group} - {fname} | EXTERNAL: external{ONLY_RUN_IDX}.csv")

        # Per-file output directories (fileNN) for models and results.
        MODEL_DIR_I = os.path.join(MODEL_DIR_ROOT, f"file{ONLY_RUN_IDX:02d}")
        SAVE_ROOT   = os.path.join(SAVE_ROOT_BASE, f"file{ONLY_RUN_IDX:02d}")
        os.makedirs(MODEL_DIR_I, exist_ok=True)
        os.makedirs(SAVE_ROOT, exist_ok=True)

        if not os.path.exists(csv_path):
            print(f"‚ö†Ô∏è Missing internal data: {csv_path}")
            continue

        df_all = pd.read_csv(csv_path)
        # Derive the standard outcome columns from their alternative names when needed.
        if 'event' not in df_all.columns and 'survival' in df_all.columns:
            df_all['event'] = df_all['survival'].astype(int)
        if 'time' not in df_all.columns and 'fu_date' in df_all.columns:
            df_all['time']  = df_all['fu_date'].astype(np.float32)

        # results_dict is filled below but not read anywhere in the visible part of this file.
        results_dict = {}
        raw_rows_auc, raw_rows_cidx = [], []

        for label, (img_part, clinical_part) in feature_sets.items():
            print(f"\nüìå Feature Set: {label}")

            auc_train_list, auc_val_list = [], []
            cidx_train_list, cidx_val_list = [], []

            for run in range(N_RUNS):
                # Seed before the split and model construction so run k is reproducible.
                set_seed(BASE_SEED + run)

                used_cols = img_part + clinical_part
                X_df = df_all[used_cols].copy()
                y_df = df_all[["time", "event"]].copy()

                # 70/30 split; the KM section re-creates it with the same random_state.
                X_train_df, X_val_df, y_train, y_val = train_test_split(
                    X_df, y_df, test_size=0.3, random_state=BASE_SEED + run
                )

                # Preprocessing: standardise image and continuous columns, one-hot
                # encode categoricals (categories unseen at fit time are ignored).
                transformers = []
                if img_part:
                    transformers.append(("img", StandardScaler(), img_part))

                cont = [c for c in clinical_part if c in cont_cols]
                cat  = [c for c in clinical_part if c in cat_cols]
                if cont:
                    transformers.append(("cont", StandardScaler(), cont))
                if cat:
                    transformers.append(("cat", OneHotEncoder(sparse_output=False, handle_unknown="ignore"), cat))

                ct = ColumnTransformer(transformers)

                # Fit on the training part only, then apply to validation.
                X_train = ct.fit_transform(X_train_df)
                X_val   = ct.transform(X_val_df)

                # Any NaN left after preprocessing is replaced by 0.
                X_train = pd.DataFrame(X_train).fillna(0).values.astype(np.float32)
                X_val   = pd.DataFrame(X_val).fillna(0).values.astype(np.float32)

                X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
                X_val_tensor   = torch.tensor(X_val, dtype=torch.float32).to(device)

                t_train = torch.tensor(y_train["time"].values, dtype=torch.float32).to(device)
                e_train = torch.tensor(y_train["event"].values, dtype=torch.float32).to(device)

                model = MixtureStretchedExponentialSurvival(input_dim=X_train.shape[1], num_components=2).to(device)
                optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

                best_loss = float("inf")
                patience, patience_counter = 10, 0
                best_model_state = None

                # Full-batch training: one optimizer step per epoch on the whole training tensor.
                for epoch in range(1000):
                    model.train()
                    optimizer.zero_grad()
                    pi, lam, a = model(X_train_tensor)
                    loss = mixture_stretched_nll(t_train, e_train, pi, lam, a)
                    loss.backward()
                    optimizer.step()

                    # Early stopping monitors the TRAINING loss (not validation):
                    # stop after `patience` epochs without an improvement > 1e-6.
                    # The snapshot is taken after optimizer.step(), i.e. it holds the
                    # weights produced by the step that followed the best loss.
                    if loss.item() < best_loss - 1e-6:
                        best_loss = loss.item()
                        best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
                        patience_counter = 0
                    else:
                        patience_counter += 1
                        if patience_counter >= patience:
                            break

                # best_model_state stays None only if no epoch's loss compared below
                # best_loss (e.g. a NaN loss); then nothing is saved for this run and
                # the evaluation below uses the model's current weights.
                if best_model_state is not None:
                    model.load_state_dict(best_model_state)
                    save_model_and_ct(best_model_state, ct, MODEL_DIR_I, run, label, input_dim=X_train.shape[1])

                # Internal evaluation: survival matrices of shape (len(time_points), n_samples).
                surv_train = predict_survival(model, X_train_tensor, time_points)
                surv_val   = predict_survival(model, X_val_tensor, time_points)

                auc_train = calc_auc(surv_train, y_train.reset_index(drop=True), time_points)
                auc_val   = calc_auc(surv_val,   y_val.reset_index(drop=True), time_points)
                auc_train_list.append(auc_train)
                auc_val_list.append(auc_val)

                # Despite the variable names, the values passed as scores are the
                # survival probabilities S(t), one column per time point.
                # NOTE(review): lifelines documents concordance_index scores as
                # "higher = longer survival", which S(t) satisfies - confirm before changing.
                risk_train = surv_train.T
                risk_val   = surv_val.T
                cidx_train = [safe_concordance_index(y_train["time"], risk_train[:, j], y_train["event"])
                              for j in range(len(time_points))]
                cidx_val   = [safe_concordance_index(y_val["time"],   risk_val[:,   j], y_val["event"])
                              for j in range(len(time_points))]
                cidx_train_list.append(cidx_train)
                cidx_val_list.append(cidx_val)

                # One output row per (feature set, run, time point) for each metric.
                for j, t in enumerate(time_points):
                    raw_rows_auc.append({
                        "Feature Set": label, "Run": run, "Time (Months)": t,
                        "AUC (Train)": auc_train[t], "AUC (Val)": auc_val[t], "Scope": "Time-wise"
                    })
                    raw_rows_cidx.append({
                        "Feature Set": label, "Run": run, "Time (Months)": t,
                        "C-index (Train)": cidx_train[j], "C-index (Val)": cidx_val[j], "Scope": "Time-wise"
                    })

            # NaN-aware mean/std of the AUC over the runs, per time point.
            results_dict[label] = {
                "mean_auc_train": {t: np.nanmean([r[t] for r in auc_train_list]) for t in time_points},
                "mean_auc_val":   {t: np.nanmean([r[t] for r in auc_val_list])   for t in time_points},
                "std_auc_train":  {t: np.nanstd([r[t] for r in auc_train_list])  for t in time_points},
                "std_auc_val":    {t: np.nanstd([r[t] for r in auc_val_list])    for t in time_points},
            }

        # Save the per-run internal metrics for this input file.
        raw_auc_path  = os.path.join(SAVE_ROOT, f"raw_auc_per_time_run{ONLY_RUN_IDX:02d}.csv")
        raw_cidx_path = os.path.join(SAVE_ROOT, f"raw_cindex_per_time_run{ONLY_RUN_IDX:02d}.csv")
        pd.DataFrame(raw_rows_auc).to_csv(raw_auc_path, index=False)
        pd.DataFrame(raw_rows_cidx).to_csv(raw_cidx_path, index=False)
        print(f"‚úÖ Internal eval saved: run{ONLY_RUN_IDX:02d}")

        # Optional external evaluation of every model saved for this file;
        # returns (None, None) when the external CSV or the models are missing,
        # in which case no external CSV is written.
        ext_auc_rows, ext_cidx_rows = evaluate_external_for_all_models(
            MODEL_DIR=MODEL_DIR_I,
            EXTERNAL_CSV=EXTERNAL_CSV,
            time_points=time_points,
            device=DEVICE,
        )
        if ext_auc_rows:
            pd.DataFrame(ext_auc_rows).to_csv(
                os.path.join(SAVE_ROOT, f"external_auc_ALL_runs_from_file{ONLY_RUN_IDX:02d}.csv"),
                index=False
            )
        if ext_cidx_rows:
            pd.DataFrame(ext_cidx_rows).to_csv(
                os.path.join(SAVE_ROOT, f"external_cindex_ALL_runs_from_file{ONLY_RUN_IDX:02d}.csv"),
                index=False
            )

# =========================
# KM ONLY: representative model (file04, run06) and KM plots
# =========================
# Representative model used for the KM plots: input file 04, run 06,
# with predicted risk read at T0 months.
REP_FILE_IDX, REP_RUN_IDX = 4, 6
REP_FEATURES = ["Image only", "Clinical only", "Image + Clinical"]
T0 = 60

base_group = BASE_GROUPS[0]

# Where the trained models live and where the KM outputs are written.
MODEL_DIR_ROOT = f"./survival_model/mixture_non_fix/models/{base_group}/{GROUP}"
MODEL_DIR_I = os.path.join(MODEL_DIR_ROOT, f"file{REP_FILE_IDX:02d}")

DATA_CSV_PATH = f"./deephit/{base_group}/test/dl0/{GROUP}/dh11_run{REP_FILE_IDX:02d}.csv"
EXTERNAL_CSV_PATH = f"./external/external{REP_FILE_IDX}.csv"

SAVE_ROOT_BASE = f"./survival_model/mixture_non_fix/non_nest/{base_group}/results/generalization/test_km/dl0/{GROUP}"
SAVE_ROOT_I = os.path.join(SAVE_ROOT_BASE, f"file{REP_FILE_IDX:02d}")
os.makedirs(SAVE_ROOT_I, exist_ok=True)

print(
    "\n============================================",
    f"üìå KM ONLY: FILE run{REP_FILE_IDX:02d}, MODEL run{REP_RUN_IDX:02d}, T0={T0}",
    "============================================",
    sep="\n",
)

# Internal data is mandatory for the KM section.
if not os.path.exists(DATA_CSV_PATH):
    raise FileNotFoundError(f"Missing internal data: {DATA_CSV_PATH}")
df_all = pd.read_csv(DATA_CSV_PATH)

# Derive the standard outcome columns from their alternative names when needed.
if 'survival' in df_all.columns and 'event' not in df_all.columns:
    df_all['event'] = df_all['survival'].astype(int)
if 'fu_date' in df_all.columns and 'time' not in df_all.columns:
    df_all['time'] = df_all['fu_date'].astype(np.float32)

# External data is optional; df_ext stays None when the file does not exist.
df_ext = pd.read_csv(EXTERNAL_CSV_PATH) if os.path.exists(EXTERNAL_CSV_PATH) else None
if df_ext is not None:
    if 'survival' in df_ext.columns and 'event' not in df_ext.columns:
        df_ext['event'] = df_ext['survival'].astype(int)
    if 'fu_date' in df_ext.columns and 'time' not in df_ext.columns:
        df_ext['time'] = df_ext['fu_date'].astype(np.float32)

# Feature columns (same lists as used in the training section).
img_cols = ["feat_436", "feat_519"]
cont_cols = ["Age"]
cat_cols = ["pathology", "stage0"]

# KM plots for the representative model, one pass per feature set:
#   1) internal validation split -> High/Low by the median predicted risk at T0
#   2) external cohort           -> High/Low by that SAME internal cutoff
# Column lists per feature-set label (same composition as at training time,
# so the saved ColumnTransformer receives the columns it was fitted on).
km_feature_columns = {
    "Image only": img_cols,
    "Clinical only": cont_cols + cat_cols,
    "Image + Clinical": img_cols + cont_cols + cat_cols,
}

for REP_FEATURE in REP_FEATURES:
    print(f"\n‚û°Ô∏è KM: Feature Set = {REP_FEATURE}")

    used_cols = km_feature_columns.get(REP_FEATURE)
    if used_cols is None:
        # Unknown feature-set label: nothing to plot.
        continue

    # Re-create the train/validation split of the representative run
    # (same data file, same random_state as during training).
    X_df = df_all[used_cols].copy()
    y_df = df_all[["time", "event"]].copy()

    set_seed(BASE_SEED + REP_RUN_IDX)
    X_train_df, X_val_df, y_train, y_val = train_test_split(
        X_df, y_df, test_size=0.3, random_state=BASE_SEED + REP_RUN_IDX
    )

    # Load the representative model and its fitted preprocessor.
    try:
        model, ct = load_model_and_ct(
            MODEL_DIR_I, REP_RUN_IDX, REP_FEATURE, device=DEVICE, num_components=2
        )
    except FileNotFoundError:
        print(f"‚ö†Ô∏è Missing model/ct: {MODEL_DIR_I} run{REP_RUN_IDX:02d} {REP_FEATURE}")
        continue

    # Shared suffix of all output file names for this feature set.
    file_tag = f"file{REP_FILE_IDX:02d}_run{REP_RUN_IDX:02d}_{REP_FEATURE.replace(' ', '_')}_T{T0}"

    # -------- Internal validation: risk = 1 - S(T0)
    X_val = ct.transform(X_val_df)
    X_val = pd.DataFrame(X_val).fillna(0).values.astype(np.float32)
    X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(DEVICE)

    surv_val_T0 = predict_survival(model, X_val_tensor, [T0])[0, :]
    risk_val = 1.0 - surv_val_T0

    df_km_internal = pd.DataFrame({
        "time": y_val["time"].values,
        "event": y_val["event"].values.astype(int),
        "risk": risk_val,
    })

    out_png_int = os.path.join(SAVE_ROOT_I, f"km_internal_{file_tag}.png")
    out_sum_int = os.path.join(SAVE_ROOT_I, f"km_internal_summary_{file_tag}.csv")

    # The internal median becomes the cutoff; make_km_plot returns None when
    # the internal plot could not be produced (too few usable samples).
    internal_cut = make_km_plot(
        df_km_internal,
        title=f"KM (Internal Validation) by Median Predicted Risk @ {T0} mo\n{REP_FEATURE} | file{REP_FILE_IDX:02d} run{REP_RUN_IDX:02d}",
        out_png=out_png_int,
        out_summary_csv=out_sum_int,
        fixed_cut=None
    )

    # -------- External: apply INTERNAL cutoff
    if df_ext is None:
        print("   ‚ÑπÔ∏è External missing ‚Üí skip external KM")
        continue

    # Without an internal cutoff, make_km_plot(fixed_cut=None) would fall back
    # to the external median, which this analysis explicitly rules out.
    if internal_cut is None:
        print("   External KM skipped: no internal cutoff available (internal KM was not produced)")
        continue

    # The outcome columns are required as well, not only the model inputs.
    missing = [c for c in used_cols + ["time", "event"] if c not in df_ext.columns]
    if missing:
        print(f"   ‚ö†Ô∏è External missing columns for {REP_FEATURE}: {missing} ‚Üí skip external KM")
        continue

    X_ext_df = df_ext[used_cols].copy()
    X_ext = ct.transform(X_ext_df)
    X_ext = pd.DataFrame(X_ext).fillna(0).values.astype(np.float32)
    X_ext_tensor = torch.tensor(X_ext, dtype=torch.float32).to(DEVICE)

    surv_ext_T0 = predict_survival(model, X_ext_tensor, [T0])[0, :]
    risk_ext = 1.0 - surv_ext_T0

    df_km_external = pd.DataFrame({
        "time": df_ext["time"].values,
        "event": df_ext["event"].values.astype(int),
        "risk": risk_ext,
    })

    out_png_ext = os.path.join(SAVE_ROOT_I, f"km_external_{file_tag}.png")
    out_sum_ext = os.path.join(SAVE_ROOT_I, f"km_external_summary_{file_tag}.csv")

    make_km_plot(
        df_km_external,
        title=f"KM (External) by INTERNAL Cutoff @ {T0} mo\n{REP_FEATURE} | file{REP_FILE_IDX:02d} run{REP_RUN_IDX:02d}",
        out_png=out_png_ext,
        out_summary_csv=out_sum_ext,
        fixed_cut=internal_cut  # internal median applied to external data
    )
