# In [None]:
# Mixture Stretched Exponential Survival with Debugging, NaN Handling, Model Saving, and External Evaluation

import os
import joblib
import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.metrics import roc_auc_score
from lifelines.utils import concordance_index
import matplotlib.pyplot as plt
import glob, re
# =========================
# Config (only part of the top-level settings was edited)
# =========================
BASE_GROUPS = ["beit0"]  # groups iterated by the main loop; used in input/output paths
GROUP = "n7_30_30"  # sub-folder name used in input/output paths
N_RUNS = 30  # number of seeded train/validation splits per feature set
time_points = [12, 24, 36, 48, 60, 72]  # evaluation horizons (reported as "Time (Months)")
DEVICE = torch.device('cpu')

# Base seed for global reproducibility (each run uses BASE_SEED + run)
BASE_SEED = 20250903

# Output directory for saved models
os.makedirs("./survival_model/mixture_non_fix/models", exist_ok=True)

# =========================
# Utils
# =========================

def evaluate_external_for_all_models(MODEL_DIR, EXTERNAL_CSV, time_points, device=DEVICE):
    """Evaluate every saved checkpoint in MODEL_DIR on an external CSV.

    Checkpoints are discovered by the file pattern
    ``best_model_run<idx>_<label>.pt``; each one needs a matching
    ``ct_run<idx>_<label>.joblib`` preprocessor next to it.

    Args:
        MODEL_DIR: Directory holding the checkpoints and preprocessors.
        EXTERNAL_CSV: Path of the external dataset. ``survival``/``fu_date``
            columns are used as fallbacks for ``event``/``time``.
        time_points: Horizons at which survival is predicted and scored.
        device: Torch device used for loading and inference.

    Returns:
        ``(auc_rows, cindex_rows)``: two lists of row dictionaries, one row per
        (run, feature set, time point) plus one "Overall" row per model.
        ``(None, None)`` is returned when the CSV is missing, has no usable
        time/event columns, or no checkpoint is found. Both lists may be empty
        when every checkpoint was skipped.
    """
    if not os.path.exists(EXTERNAL_CSV):
        print(f"ℹ️ 외부 평가 스킵 (파일 없음): {EXTERNAL_CSV}")
        return None, None

    df_ext = pd.read_csv(EXTERNAL_CSV)

    # Derive event/time from their alternative column names when needed.
    if 'event' not in df_ext.columns and 'survival' in df_ext.columns:
        df_ext['event'] = df_ext['survival'].astype(int)
    if 'time' not in df_ext.columns and 'fu_date' in df_ext.columns:
        df_ext['time'] = df_ext['fu_date'].astype(np.float32)

    # Without both target columns nothing can be scored; bail out early
    # instead of failing with a KeyError after the models were loaded.
    missing_targets = [c for c in ('time', 'event') if c not in df_ext.columns]
    if missing_targets:
        print(f"⚠️ 외부 데이터에 타깃 컬럼 누락: {missing_targets} → 외부 평가 스킵: {EXTERNAL_CSV}")
        return None, None

    # Scan all checkpoints in MODEL_DIR (e.g. best_model_run13_Image_only.pt).
    ckpt_paths = glob.glob(os.path.join(MODEL_DIR, "best_model_run*.pt"))
    if not ckpt_paths:
        print(f"⚠️ {MODEL_DIR} 에 저장된 모델이 없습니다.")
        return None, None

    pattern = re.compile(r"best_model_run(\d+)_([^\.]+)\.pt")
    ext_rows_auc_all, ext_rows_cidx_all = [], []

    for ckpt_path in sorted(ckpt_paths):
        m = pattern.search(os.path.basename(ckpt_path))
        if not m:
            print(f"⚠️ 스킵(파일명 파싱 실패): {ckpt_path}")
            continue
        run_idx = int(m.group(1))
        label_raw = m.group(2)                  # e.g., Image_only
        label = label_raw.replace('_', ' ')     # "Image only"

        # The fitted ColumnTransformer must sit next to the checkpoint.
        ct_path = os.path.join(MODEL_DIR, f"ct_run{run_idx:02d}_{label_raw}.joblib")
        if not os.path.exists(ct_path):
            print(f"⚠️ 전처리기 누락 → 스킵: {ct_path}")
            continue

        # Load model and preprocessor.
        try:
            model, ct = load_model_and_ct(MODEL_DIR, run_idx, label, device=device, num_components=2)
        except FileNotFoundError:
            print(f"⚠️ 로드 실패 → 스킵: run{run_idx:02d}, {label}")
            continue

        # Collect the raw input columns the ColumnTransformer expects.
        required_columns = []
        for _name, trans, cols in ct.transformers_:
            # Dropped column groups (e.g. an implicit remainder) are not inputs.
            if isinstance(trans, str) and trans == 'drop':
                continue
            if cols is None:
                continue
            if isinstance(cols, (list, tuple, np.ndarray, pd.Index)):
                # An empty sequence simply contributes nothing.
                required_columns.extend(list(cols))
            else:
                required_columns.append(cols)
        # A column shared by several transformers must be selected only once,
        # otherwise the selected frame would carry duplicate column names.
        required_columns = list(dict.fromkeys(required_columns))

        missing = [c for c in required_columns if c not in df_ext.columns]
        if missing:
            print(f"⚠️ 외부 데이터에 '{label}' 입력 컬럼 누락: {missing} → 스킵(run{run_idx:02d})")
            continue

        # Transform and predict; NaNs left after preprocessing become 0.
        X_ext = ct.transform(df_ext[required_columns])
        X_ext = pd.DataFrame(X_ext).fillna(0).values.astype(np.float32)
        X_ext_tensor = torch.tensor(X_ext, dtype=torch.float32).to(device)

        surv_ext = predict_survival(model, X_ext_tensor, time_points)  # (T, N)
        y_ext = df_ext[['time', 'event']].copy()

        # AUC / C-index. The predicted survival probabilities are passed
        # directly as the scores for the concordance computation.
        auc_ext = calc_auc(surv_ext, y_ext.reset_index(drop=True), time_points)
        risk_ext = surv_ext.T
        cidx_ext = [safe_concordance_index(y_ext['time'], risk_ext[:, j], y_ext['event'])
                    for j in range(len(time_points))]

        # Accumulate per-time-point rows, then one "Overall" (nan-mean) row.
        run_tag = f"run{run_idx:02d}"
        for j, t in enumerate(time_points):
            ext_rows_auc_all.append({"RunFile": run_tag, "Feature Set": label,
                                     "Time (Months)": t, "AUC (External)": auc_ext[t]})
            ext_rows_cidx_all.append({"RunFile": run_tag, "Feature Set": label,
                                      "Time (Months)": t, "C-index (External)": cidx_ext[j]})
        ext_rows_auc_all.append({"RunFile": run_tag, "Feature Set": label,
                                 "Time (Months)": "Overall",
                                 "AUC (External)": np.nanmean(list(auc_ext.values()))})
        ext_rows_cidx_all.append({"RunFile": run_tag, "Feature Set": label,
                                  "Time (Months)": "Overall",
                                  "C-index (External)": np.nanmean(cidx_ext)})

    return ext_rows_auc_all, ext_rows_cidx_all

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators with `seed`."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

class MixtureStretchedExponentialSurvival(nn.Module):
    """MLP that outputs the parameters of a mixture of stretched exponentials.

    A two-layer ReLU backbone feeds three linear heads that produce, per
    sample and per component, the mixture weight, the rate and the exponent.
    """

    def __init__(self, input_dim, num_components=2):
        super().__init__()
        hidden = 64
        trunk = [
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        ]
        self.backbone = nn.Sequential(*trunk)
        self.pi_layer = nn.Linear(hidden, num_components)
        self.lam_layer = nn.Linear(hidden, num_components)
        self.alpha_layer = nn.Linear(hidden, num_components)

    def forward(self, x):
        """Return ``(pi, lam, a)``, each with one column per component.

        ``pi`` is a softmax over components; ``lam`` and ``a`` are softplus
        outputs shifted by 1e-3 so they stay strictly positive.
        """
        features = self.backbone(x)
        weights = F.softmax(self.pi_layer(features), dim=1)
        rate = F.softplus(self.lam_layer(features)) + 1e-3
        exponent = F.softplus(self.alpha_layer(features)) + 1e-3
        return weights, rate, exponent

def mixture_stretched_nll(t, e, pi, lam, a, eps=1e-8):
    """Mean negative log-likelihood of a stretched-exponential mixture.

    Each component has survival ``exp(-lam * t**a)``. Samples with ``e == 1``
    contribute the log mixture density, samples with ``e == 0`` the log
    mixture survival. ``eps`` is added to ``t`` and to both mixture sums
    before taking logs.
    """
    shifted_t = t.view(-1, 1) + eps  # column vector, broadcast over components
    comp_surv = torch.exp(-lam * torch.pow(shifted_t, a))
    comp_dens = lam * a * torch.pow(shifted_t, a - 1) * comp_surv
    density = torch.sum(pi * comp_dens, dim=1) + eps
    survival = torch.sum(pi * comp_surv, dim=1) + eps
    per_sample = e * torch.log(density) + (1 - e) * torch.log(survival)
    return -per_sample.mean()

@torch.no_grad()
def predict_survival(model, x, times):
    """Predict mixture survival probabilities at each requested time.

    The model is switched to eval mode and called once on ``x``; the returned
    ``(pi, lam, a)`` are then evaluated at every entry of ``times``.

    Returns:
        NumPy array of shape ``(len(times), N)``, one row per time point.
    """
    model.eval()
    weights, rate, exponent = model(x)

    def _survival_at(horizon):
        # One-element tensor broadcast against the (N, K) parameter tensors.
        h = torch.tensor([horizon], dtype=torch.float32, device=x.device)
        comp_surv = torch.exp(-rate * torch.pow(h + 1e-8, exponent))
        return torch.sum(weights * comp_surv, dim=1).cpu().numpy()

    return np.vstack([_survival_at(t) for t in times])  # shape: (len(times), N)

def calc_auc(surv_arr, y_df, times):
    """Compute a ROC AUC per time point from predicted survival.

    For each time ``t`` (row ``i`` of ``surv_arr``) the positive class is
    "event observed with time <= t" and the score is ``1 - survival``.
    A time point whose AUC computation raises is recorded as NaN.

    Returns:
        Dict mapping each time point to its AUC (or NaN).
    """
    scores = {}
    for row, horizon in enumerate(times):
        event_by_horizon = (y_df["event"] == 1) & (y_df["time"] <= horizon)
        labels = event_by_horizon.astype(int)
        predicted = 1 - surv_arr[row, :]
        try:
            value = roc_auc_score(labels, predicted)
        except Exception:
            value = np.nan
        scores[horizon] = value
    return scores

def safe_concordance_index(times, risks, events):
    """Compute the concordance index, returning NaN when it is undefined.

    Inputs are coerced to float arrays so NaN filtering also works for
    integer, boolean or object-typed inputs. Rows with a NaN in any of the
    three inputs are dropped before scoring.

    Args:
        times: Observed times.
        risks: Predicted scores, passed to ``concordance_index`` unchanged.
        events: Event indicators.

    Returns:
        The concordance index, or ``np.nan`` when fewer than two valid rows
        remain, the scores are (near-)constant, or no admissible pair exists.
    """
    times = np.asarray(times, dtype=float)
    risks = np.asarray(risks, dtype=float)
    events = np.asarray(events, dtype=float)
    mask = ~(np.isnan(times) | np.isnan(risks) | np.isnan(events))
    n_valid = int(np.sum(mask))
    if n_valid < 2:
        print("⚠️ Too few valid samples for C-index:", n_valid)
        return np.nan
    if np.std(risks[mask]) < 1e-6:
        print("⚠️ Low risk variance, skipping C-index")
        return np.nan
    try:
        return concordance_index(times[mask], risks[mask], events[mask])
    except ZeroDivisionError:
        # Raised when no comparable pair exists (e.g. every sample censored).
        print("⚠️ No admissible pairs, skipping C-index")
        return np.nan

def save_model_and_ct(model_state, ct, save_dir, run_idx, label, input_dim):
    """Persist a model state dict and its fitted preprocessor.

    Writes ``best_model_run<idx>_<label>.pt`` (state dict plus ``input_dim``)
    and ``ct_run<idx>_<label>.joblib`` into ``save_dir``, creating the
    directory if needed. Spaces in ``label`` become underscores.
    """
    safe_label = label.replace(' ', '_')
    tag = f"run{run_idx:02d}_{safe_label}"
    os.makedirs(save_dir, exist_ok=True)

    checkpoint = {
        "state_dict": model_state,
        "input_dim": input_dim
    }
    checkpoint_file = os.path.join(save_dir, f"best_model_{tag}.pt")
    preprocessor_file = os.path.join(save_dir, f"ct_{tag}.joblib")

    torch.save(checkpoint, checkpoint_file)
    joblib.dump(ct, preprocessor_file)

def load_model_and_ct(model_dir, run_idx, label, device=DEVICE, num_components=2):
    """Load a saved model and its fitted preprocessor.

    Reads ``best_model_run<idx>_<label>.pt`` and ``ct_run<idx>_<label>.joblib``
    from ``model_dir`` (spaces in ``label`` become underscores), rebuilds the
    network with the stored ``input_dim`` and returns it in eval mode.

    Returns:
        ``(model, ct)``.
    """
    tag = "run{:02d}_{}".format(run_idx, label.replace(' ', '_'))

    checkpoint_file = os.path.join(model_dir, f"best_model_{tag}.pt")
    checkpoint = torch.load(checkpoint_file, map_location=device)

    net = MixtureStretchedExponentialSurvival(
        input_dim=checkpoint["input_dim"],
        num_components=num_components,
    ).to(device)
    net.load_state_dict(checkpoint["state_dict"])

    preprocessor_file = os.path.join(model_dir, f"ct_{tag}.joblib")
    ct = joblib.load(preprocessor_file)

    net.eval()
    return net, ct

def evaluate_on_dataframe(model, ct, df, feature_cols, time_points, device=DEVICE):
    """Preprocess ``df[feature_cols]`` with ``ct`` and predict survival.

    NaNs remaining after the transform are replaced by 0 before inference.

    Returns:
        Array of shape ``(len(time_points), N)`` from ``predict_survival``.
    """
    transformed = ct.transform(df[feature_cols].copy())
    dense = pd.DataFrame(transformed).fillna(0).values.astype(np.float32)
    inputs = torch.tensor(dense, dtype=torch.float32).to(device)
    return predict_survival(model, inputs, time_points)  # (T, N)
# =========================
# Main (replace this whole block)
# =========================
# For every base group and every input file index 1..30, this script:
#   1. trains N_RUNS seeded models per feature set on a 70/30 split,
#   2. saves each best model together with its fitted ColumnTransformer,
#   3. writes per-run AUC / C-index tables and summary plots,
#   4. evaluates all saved models on the matching external CSV.
device = DEVICE

for base_group in BASE_GROUPS:
    print(f"\n\n============================")
    print(f"📁 BEiT 그룹 실행 중: {base_group}")
    print(f"============================")

    SAVE_ROOT_BASE = f"./survival_model/mixture_non_fix/tune/{base_group}/results/generalization/test0/dl0/{GROUP}"
    MODEL_DIR_ROOT = f"./survival_model/mixture_non_fix/models/{base_group}/{GROUP}"
    os.makedirs(SAVE_ROOT_BASE, exist_ok=True)
    os.makedirs(MODEL_DIR_ROOT, exist_ok=True)

    # Column definitions shared by all feature sets
    img_cols = ["feat_436", "feat_519"]
    cont_cols = ['Age', 'stage0']
    cat_cols  = ['pathology']

    # label -> (image columns, clinical columns)
    feature_sets = {
        'Image only': (img_cols, []),
        'Clinical only': ([], cont_cols + cat_cols),
        'Image + Clinical': (img_cols, cont_cols + cat_cols)
    }

    # Iterate ONLY_RUN_IDX over 1..30, pairing each input file with
    # EXTERNAL_CSV = ./external/external{idx}.csv
    for i in range(1, 31):
        ONLY_RUN_IDX = i
        EXTERNAL_CSV = f"./external/external{ONLY_RUN_IDX}.csv"

        fname = f"dh11_run{ONLY_RUN_IDX:02d}.csv"
        csv_path = f"./deephit/{base_group}/test/dl0/{GROUP}/{fname}"
        print(f"\n🚀 실행 중: {base_group} - {fname} | EXTERNAL: external{ONLY_RUN_IDX}.csv")

        # Per-file (per-i) output folders
        MODEL_DIR_I = os.path.join(MODEL_DIR_ROOT, f"file{ONLY_RUN_IDX:02d}")
        SAVE_ROOT   = os.path.join(SAVE_ROOT_BASE, f"file{ONLY_RUN_IDX:02d}")
        os.makedirs(MODEL_DIR_I, exist_ok=True)
        os.makedirs(SAVE_ROOT, exist_ok=True)

        if not os.path.exists(csv_path):
            print(f"⚠️ 내부 데이터 없음 → 스킵: {csv_path}")
            continue

        df_all = pd.read_csv(csv_path)
        # Map event/time columns from their alternative names when needed
        if 'event' not in df_all.columns and 'survival' in df_all.columns:
            df_all['event'] = df_all['survival'].astype(int)
        if 'time' not in df_all.columns and 'fu_date' in df_all.columns:
            df_all['time']  = df_all['fu_date'].astype(np.float32)

        results_dict = {}
        raw_rows_auc, raw_rows_cidx = [], []

        for label, (img_part, clinical_part) in feature_sets.items():
            print(f"\n📌 Feature Set: {label}")
            auc_train_list, auc_val_list = [], []
            cidx_train_list, cidx_val_list = [], []

            # ----- Monte Carlo N_RUNS -----
            for run in range(N_RUNS):
                # Fix the seed: global base seed + run (always reproducible)
                set_seed(BASE_SEED + run)

                # Design X / y
                used_cols = img_part + clinical_part
                X_df = df_all[used_cols].copy()
                y_df = df_all[['time', 'event']].copy()

                X_train_df, X_val_df, y_train, y_val = train_test_split(
                    X_df, y_df, test_size=0.3, random_state=BASE_SEED + run
                )

                # Build the preprocessor: scale image and continuous columns,
                # one-hot encode categorical columns.
                transformers = []
                if img_part:
                    transformers.append(('img', StandardScaler(), img_part))
                cont = [c for c in clinical_part if c in cont_cols]
                cat  = [c for c in clinical_part if c in cat_cols]
                if cont:
                    transformers.append(('cont', StandardScaler(), cont))
                if cat:
                    transformers.append(('cat', OneHotEncoder(sparse_output=False, handle_unknown='ignore'), cat))

                # Fit on the training split only, then apply to validation.
                ct = ColumnTransformer(transformers)
                X_train = ct.fit_transform(X_train_df)
                X_val   = ct.transform(X_val_df)

                # NaNs remaining after preprocessing are replaced by 0.
                X_train = pd.DataFrame(X_train).fillna(0).values.astype(np.float32)
                X_val   = pd.DataFrame(X_val).fillna(0).values.astype(np.float32)

                X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
                X_val_tensor   = torch.tensor(X_val,   dtype=torch.float32).to(device)
                t_train = torch.tensor(y_train['time'].values,  dtype=torch.float32).to(device)
                e_train = torch.tensor(y_train['event'].values, dtype=torch.float32).to(device)

                model = MixtureStretchedExponentialSurvival(
                    input_dim=X_train.shape[1], num_components=2
                ).to(device)
                optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

                best_val_loss = float('inf')
                patience, patience_counter = 10, 0
                best_model_state = None

                # Full-batch training: one optimizer step per epoch.
                for epoch in range(1000):
                    model.train()
                    optimizer.zero_grad()
                    pi, lam, a = model(X_train_tensor)
                    loss = mixture_stretched_nll(t_train, e_train, pi, lam, a)
                    loss.backward()
                    optimizer.step()

                    # Simple early stopping (based on the train loss; kept from
                    # the original code). NOTE: despite its name, best_val_loss
                    # tracks the training loss, not a validation loss.
                    if loss.item() < best_val_loss - 1e-6:
                        best_val_loss = loss.item()
                        best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
                        patience_counter = 0
                    else:
                        patience_counter += 1
                        if patience_counter >= patience:
                            break

                # Restore and save the best model (skipped when no epoch
                # ever improved on the initial infinite loss)
                if best_model_state is not None:
                    model.load_state_dict(best_model_state)
                    save_model_and_ct(best_model_state, ct, MODEL_DIR_I, run, label, input_dim=X_train.shape[1])

                # Internal evaluation
                surv_train = predict_survival(model, X_train_tensor, time_points)
                surv_val   = predict_survival(model, X_val_tensor,   time_points)

                auc_train = calc_auc(surv_train, y_train.reset_index(drop=True), time_points)
                auc_val   = calc_auc(surv_val,   y_val.reset_index(drop=True),   time_points)
                auc_train_list.append(auc_train)
                auc_val_list.append(auc_val)

                # Transposed to (N, T); column j holds the predicted survival
                # at time_points[j] and is used as the C-index score.
                risk_train = surv_train.T
                risk_val   = surv_val.T
                cidx_train = [safe_concordance_index(y_train['time'], risk_train[:, j], y_train['event']) for j in range(len(time_points))]
                cidx_val   = [safe_concordance_index(y_val['time'],   risk_val[:,   j], y_val['event'])   for j in range(len(time_points))]
                cidx_train_list.append(cidx_train)
                cidx_val_list.append(cidx_val)

                # Accumulate raw rows for saving: one row per time point,
                # followed by one "Overall" row holding the nan-mean.
                for j, t in enumerate(time_points):
                    raw_rows_auc.append({"Feature Set": label, "Run": run, "Time (Months)": t,
                                         "AUC (Train)": auc_train[t], "AUC (Val)": auc_val[t], "Scope": "Time-wise"})
                    raw_rows_cidx.append({"Feature Set": label, "Run": run, "Time (Months)": t,
                                          "C-index (Train)": cidx_train[j], "C-index (Val)": cidx_val[j], "Scope": "Time-wise"})
                raw_rows_auc.append({"Feature Set": label, "Run": run, "Time (Months)": "Overall",
                                     "AUC (Train)": np.nanmean(list(auc_train.values())),
                                     "AUC (Val)":   np.nanmean(list(auc_val.values())), "Scope": "Overall"})
                raw_rows_cidx.append({"Feature Set": label, "Run": run, "Time (Months)": "Overall",
                                      "C-index (Train)": np.nanmean(cidx_train),
                                      "C-index (Val)":   np.nanmean(cidx_val), "Scope": "Overall"})

            # Summary statistics (nan-aware mean/std across runs, per time point)
            results_dict[label] = {
                'mean_auc_train': {t: np.nanmean([r[t] for r in auc_train_list]) for t in time_points},
                'mean_auc_val':   {t: np.nanmean([r[t] for r in auc_val_list])   for t in time_points},
                'std_auc_train':  {t: np.nanstd([r[t] for r in auc_train_list])  for t in time_points},
                'std_auc_val':    {t: np.nanstd([r[t] for r in auc_val_list])    for t in time_points},
                'mean_cidx_train':{t: np.nanmean([r[j] for r in cidx_train_list]) for j, t in enumerate(time_points)},
                'mean_cidx_val':  {t: np.nanmean([r[j] for r in cidx_val_list])   for j, t in enumerate(time_points)},
                'std_cidx_train': {t: np.nanstd([r[j] for r in cidx_train_list])  for j, t in enumerate(time_points)},
                'std_cidx_val':   {t: np.nanstd([r[j] for r in cidx_val_list])    for j, t in enumerate(time_points)}
            }

        # Save result CSVs
        raw_auc_path  = os.path.join(SAVE_ROOT, f"raw_auc_per_time_run{ONLY_RUN_IDX:02d}.csv")
        raw_cidx_path = os.path.join(SAVE_ROOT, f"raw_cindex_per_time_run{ONLY_RUN_IDX:02d}.csv")
        pd.DataFrame(raw_rows_auc).to_csv(raw_auc_path, index=False)
        pd.DataFrame(raw_rows_cidx).to_csv(raw_cidx_path, index=False)
        print(f"✅ 내부 평가 저장 완료: run{ONLY_RUN_IDX:02d}")

        # Save plot (AUC): mean ± std across runs, train dashed / val solid
        plt.figure(figsize=(10, 5))
        for label in feature_sets:
            plt.errorbar(time_points, list(results_dict[label]['mean_auc_train'].values()),
                         yerr=list(results_dict[label]['std_auc_train'].values()),
                         fmt='--o', capsize=4, label=f"{label} - AUC Train")
            plt.errorbar(time_points, list(results_dict[label]['mean_auc_val'].values()),
                         yerr=list(results_dict[label]['std_auc_val'].values()),
                         fmt='-o', capsize=4, label=f"{label} - AUC Val")
        plt.title(f"AUC (Run {ONLY_RUN_IDX:02d})")
        plt.xlabel("Time (Months)")
        plt.ylabel("AUC")
        plt.ylim(0.1, 1.0)
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(SAVE_ROOT, f"plot_auc_run{ONLY_RUN_IDX:02d}.png"))
        plt.close()

        # Save plot (C-index): mean ± std across runs, train dashed / val solid
        plt.figure(figsize=(10, 5))
        for label in feature_sets:
            plt.errorbar(time_points, list(results_dict[label]['mean_cidx_train'].values()),
                         yerr=list(results_dict[label]['std_cidx_train'].values()),
                         fmt='--s', capsize=4, label=f"{label} - C-index Train")
            plt.errorbar(time_points, list(results_dict[label]['mean_cidx_val'].values()),
                         yerr=list(results_dict[label]['std_cidx_val'].values()),
                         fmt='-s', capsize=4, label=f"{label} - C-index Val")
        plt.title(f"C-index (Run {ONLY_RUN_IDX:02d})")
        plt.xlabel("Time (Months)")
        plt.ylabel("C-index")
        plt.ylim(0.1, 1.0)
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(SAVE_ROOT, f"plot_cindex_run{ONLY_RUN_IDX:02d}.png"))
        plt.close()

        # =========================
        # External evaluation (every run × feature set saved for this i)
        # =========================
        ext_auc_rows, ext_cidx_rows = evaluate_external_for_all_models(
            MODEL_DIR=MODEL_DIR_I,
            EXTERNAL_CSV=EXTERNAL_CSV,
            time_points=time_points,
            device=DEVICE
        )

        # Each table is written only when it has at least one row
        # (the helper may return None or an empty list).
        if ext_auc_rows:
            ext_auc_path  = os.path.join(SAVE_ROOT, f"external_auc_ALL_runs_from_file{ONLY_RUN_IDX:02d}.csv")
            pd.DataFrame(ext_auc_rows).to_csv(ext_auc_path, index=False)
        if ext_cidx_rows:
            ext_cidx_path = os.path.join(SAVE_ROOT, f"external_cindex_ALL_runs_from_file{ONLY_RUN_IDX:02d}.csv")
            pd.DataFrame(ext_cidx_rows).to_csv(ext_cidx_path, index=False)

        if (ext_auc_rows and len(ext_auc_rows)) or (ext_cidx_rows and len(ext_cidx_rows)):
            print("✅ 외부 평가(모든 저장 모델) 저장 완료")
        else:
            print("ℹ️ 저장된 모델이 없거나, 외부 데이터 컬럼 누락으로 평가가 스킵되었습니다.")


# NOTE: the four list literals below are bare expression statements; they are
# evaluated and discarded, and none of them is assigned to a name.
# They look like progressively smaller candidate image-feature subsets
# (70, 23, 5 and 2 names) — TODO confirm and move to notes or delete.
# The last one equals the `img_cols` value used in the main loop.
["feat_213", "feat_194", "feat_163", "feat_266", "feat_407", "feat_468", "feat_499", "feat_169", "feat_436", "feat_560", 
                    "feat_2", "feat_327", "feat_391", "feat_519", "feat_173", "feat_181", "feat_389", "feat_715", "feat_107", "feat_203", 
                    "feat_361", "feat_439", "feat_451", "feat_565", "feat_747", "feat_80", "feat_10", "feat_123", "feat_137", "feat_15", 
                    "feat_209", "feat_215", "feat_289", "feat_368", "feat_374", "feat_55", "feat_576", "feat_578", "feat_121", "feat_125", 
                    "feat_143", "feat_223", "feat_240", "feat_25", "feat_309", "feat_498", "feat_514", "feat_554", "feat_577", "feat_617", 
                    "feat_653", "feat_710", "feat_109", "feat_210", "feat_220", "feat_352", "feat_420", "feat_507", "feat_583", "feat_605", 
                    "feat_657", "feat_666", "feat_152", "feat_167", "feat_255", "feat_328", "feat_378", "feat_402", "feat_633", "feat_656"]

["feat_213", "feat_266", "feat_499", "feat_436", "feat_2", "feat_327", "feat_391", "feat_519", "feat_173", "feat_715",
 "feat_107", "feat_80", "feat_137", "feat_209", "feat_215", "feat_374", "feat_55", "feat_223", "feat_554", "feat_577", 
 "feat_109", "feat_583", "feat_657"]


["feat_213", "feat_266", "feat_436", "feat_519", "feat_215"]

["feat_436", "feat_519"]