In [None]:
# Mixture Stretched Exponential Survival with Debugging and NaN Handling
# Overlapping features across groups

import pandas as pd
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.metrics import roc_auc_score
from lifelines.utils import concordance_index
import matplotlib.pyplot as plt
import os

def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch's RNG with ``seed``."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

class MixtureStretchedExponentialSurvival(nn.Module):
    """MLP that outputs the parameters of a mixture of stretched-exponential components.

    Component k is parameterised by a weight ``pi_k``, a rate ``lam_k`` and an
    exponent ``a_k``. The loss and prediction helpers use
    ``S_k(t) = exp(-lam_k * t**a_k)`` and ``S(t) = sum_k pi_k * S_k(t)``.
    """

    def __init__(self, input_dim, num_components=2):
        super().__init__()
        hidden = 64
        # Layer creation order is significant: it fixes the RNG draw order for
        # seeded initialisation and the state_dict key order.
        self.backbone = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        self.pi_layer = nn.Linear(hidden, num_components)
        self.lam_layer = nn.Linear(hidden, num_components)
        self.alpha_layer = nn.Linear(hidden, num_components)

    def forward(self, x):
        """Return ``(pi, lam, a)``, one column per mixture component.

        ``pi`` is a softmax over dim 1, so each row sums to 1. ``lam`` and ``a``
        are softplus outputs shifted by 1e-3, which keeps them strictly positive.
        """
        shared = self.backbone(x)
        floor = 1e-3
        weights = self.pi_layer(shared).softmax(dim=1)
        rates = F.softplus(self.lam_layer(shared)) + floor
        exponents = F.softplus(self.alpha_layer(shared)) + floor
        return weights, rates, exponents

def mixture_stretched_nll(t, e, pi, lam, a, eps=1e-8):
    """Mean negative log-likelihood of right-censored times under the mixture.

    Component k has survival ``S_k(t) = exp(-lam_k * t**a_k)`` and density
    ``f_k(t) = lam_k * a_k * t**(a_k - 1) * S_k(t)``. Both are mixed with the
    weights ``pi``. Samples with an event contribute ``log f(t)``; censored
    samples contribute ``log S(t)``.

    Everything is computed in log space with ``logsumexp``. The earlier form
    ``log(sum(pi * S_k) + eps)`` flattened out at ``log(eps)`` whenever the
    mixture survival or density fell far below ``eps`` (e.g. ``lam * t**a`` of
    roughly 20 or more). That biased the loss and starved those samples of
    gradient.

    Args:
        t: observed times, one per sample (reshaped to a column).
        e: event indicator per sample (1 = event, 0 = censored).
        pi, lam, a: mixture weights, rates and exponents, one column per component.
        eps: added to ``t`` so that ``log(t)`` stays finite at ``t == 0``.

    Returns:
        Scalar tensor: mean negative log-likelihood over the batch.
    """
    t = t.view(-1, 1)
    log_t = torch.log(t + eps)
    # Clamping keeps log(pi) finite (and its gradient non-NaN) if softmax underflows to 0.
    log_pi = torch.log(pi.clamp_min(torch.finfo(pi.dtype).tiny))
    log_S_k = -lam * torch.exp(a * log_t)  # log S_k(t) = -lam * t**a
    log_f_k = torch.log(lam) + torch.log(a) + (a - 1) * log_t + log_S_k
    log_S = torch.logsumexp(log_pi + log_S_k, dim=1)
    log_f = torch.logsumexp(log_pi + log_f_k, dim=1)
    loglik = e * log_f + (1 - e) * log_S
    return -loglik.mean()

@torch.no_grad()
def predict_survival(model, x, times):
    """Evaluate the mixture survival function S(t) for every sample at each time in ``times``.

    Switches ``model`` to eval mode. Returns a NumPy array with one row per
    entry of ``times``; row r holds S(times[r]) for every sample in ``x``.
    """
    model.eval()
    weights, rates, exponents = model(x)

    def _survival_at(time_value):
        # S(t) = sum_k pi_k * exp(-lam_k * (t + 1e-8) ** a_k)
        tt = torch.tensor([time_value], dtype=torch.float32, device=x.device)
        component_surv = torch.exp(-rates * torch.pow(tt + 1e-8, exponents))
        return torch.sum(weights * component_surv, dim=1).cpu().numpy()

    return np.vstack([_survival_at(tv) for tv in times])

def calc_auc(surv_arr, y_df, times):
    """Time-dependent AUC at each horizon in ``times``.

    For horizon t, a sample is a case if it had an event at or before t.
    Every other sample is a control. The score is ``1 - S(t)``.

    NOTE(review): samples censored before t are counted as controls, which
    biases the estimate. An IPCW cumulative/dynamic AUC estimator would
    correct for this.

    Args:
        surv_arr: survival probabilities; row i corresponds to ``times[i]``,
            columns to samples.
        y_df: DataFrame with 'time' and 'event' columns, in the same sample
            order as ``surv_arr``'s columns.
        times: evaluation horizons.

    Returns:
        dict mapping each horizon to its AUC, or NaN where the AUC is undefined
        (e.g. only one class present, or NaN scores).
    """
    aucs = {}
    for row, horizon in enumerate(times):
        is_case = ((y_df["event"] == 1) & (y_df["time"] <= horizon)).astype(int)
        score = 1 - surv_arr[row, :]
        try:
            aucs[horizon] = roc_auc_score(is_case, score)
        except ValueError:
            # roc_auc_score raises ValueError for a single-class target or NaN input.
            # Catching only that lets real bugs and KeyboardInterrupt propagate.
            aucs[horizon] = np.nan
    return aucs

def safe_concordance_index(times, risks, events):
    """Harrell's C-index via lifelines, returning NaN when it cannot be computed.

    Samples with NaN in any input are dropped. NaN is returned (with a printed
    warning) when fewer than 2 samples remain, or when the scores are almost
    constant (std < 1e-6).

    ``risks`` is forwarded as lifelines' ``predicted_scores``, where higher
    values mean longer survival.
    """
    t_arr, r_arr, e_arr = (np.asarray(v) for v in (times, risks, events))
    keep = ~(np.isnan(t_arr) | np.isnan(r_arr) | np.isnan(e_arr))
    n_valid = np.sum(keep)
    if n_valid < 2:
        print("⚠️ Too few valid samples for C-index:", n_valid)
        return np.nan
    if np.std(r_arr[keep]) < 1e-6:
        print("⚠️ Low risk variance, skipping C-index")
        return np.nan
    return concordance_index(t_arr[keep], r_arr[keep], e_arr[keep])


# ---------------------------------------------------------------------------
# Experiment driver. For every BEiT feature group and every pre-split CSV
# (dh11_runXX.csv), each feature set is trained N_RUNS times on random 70/30
# splits. Per-time AUC / C-index tables and mean±std plots go to SAVE_ROOT.
# ---------------------------------------------------------------------------
os.makedirs("./survival_model/mixture_non_fix/models", exist_ok=True)
device = torch.device('cpu')

BASE_GROUPS = ["beit0", "beit0_o", "beit", "beit_o", "beit_resize", "beit_original", "beit1"]
GROUP = "n7_30_30"  # alternatives: n6_30_30, n5_30_30, n4_30_30
N_RUNS = 30
time_points = [12, 24, 36, 48, 60, 72]
DEVICE = torch.device('cpu')


def _snapshot_state(model):
    """Return a detached copy of ``model``'s state dict.

    ``state_dict()`` returns references to the live tensors. Without cloning,
    later optimizer steps would overwrite the saved "best" weights.
    """
    return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}


def _fit_model(X_tensor, t_tensor, e_tensor, max_epochs=1000, patience=10, lr=1e-4, min_delta=1e-6):
    """Train a fresh model full-batch with Adam, early-stopping on the TRAINING NLL.

    Returns the model restored to the parameters that produced the lowest training loss.
    """
    model = MixtureStretchedExponentialSurvival(input_dim=X_tensor.shape[1], num_components=2).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_train_loss = float('inf')
    best_state = None
    stall = 0
    for _ in range(max_epochs):
        model.train()
        optimizer.zero_grad()
        pi, lam, a = model(X_tensor)
        loss = mixture_stretched_nll(t_tensor, e_tensor, pi, lam, a)
        loss_value = loss.item()
        if loss_value < best_train_loss - min_delta:
            # Snapshot before stepping: these are the parameters that produced loss_value.
            best_train_loss = loss_value
            best_state = _snapshot_state(model)
            stall = 0
        else:
            # A NaN loss also lands here, so a diverged run stops and is rolled back below.
            stall += 1
            if stall >= patience:
                break
        loss.backward()
        optimizer.step()

    if best_state is not None:
        model.load_state_dict(best_state)
    return model


def _evaluate_split(model, X_tensor, y_split):
    """Return ``(aucs, cidx)`` for one split.

    ``aucs`` maps each time point to an AUC. ``cidx`` is a list aligned with ``time_points``.
    """
    surv = predict_survival(model, X_tensor, time_points)  # one row per time point
    aucs = calc_auc(surv, y_split.reset_index(drop=True), time_points)
    # lifelines expects higher score = longer survival, so S(t) is passed as-is.
    cidx = [safe_concordance_index(y_split['time'], surv[k, :], y_split['event'])
            for k in range(len(time_points))]
    return aucs, cidx


def _summarize(auc_runs, cidx_runs, split):
    """Return NaN-aware mean/std across runs, per time point, for one split ('train' or 'val')."""
    return {
        f'mean_auc_{split}': {t: np.nanmean([r[t] for r in auc_runs]) for t in time_points},
        f'std_auc_{split}': {t: np.nanstd([r[t] for r in auc_runs]) for t in time_points},
        f'mean_cidx_{split}': {t: np.nanmean([r[j] for r in cidx_runs]) for j, t in enumerate(time_points)},
        f'std_cidx_{split}': {t: np.nanstd([r[j] for r in cidx_runs]) for j, t in enumerate(time_points)},
    }


def _plot_metric(results_dict, metric_key, metric_name, marker, run_idx, save_root, file_stem):
    """Save an errorbar plot (mean ± std over runs) of one metric for every feature set."""
    plt.figure(figsize=(10, 5))
    for label, summary in results_dict.items():
        for split, line in (("train", "--"), ("val", "-")):
            plt.errorbar(time_points, list(summary[f"mean_{metric_key}_{split}"].values()),
                         yerr=list(summary[f"std_{metric_key}_{split}"].values()),
                         fmt=f"{line}{marker}", capsize=4,
                         label=f"{label} - {metric_name} {split.capitalize()}")
    plt.title(f"{metric_name} (Run {run_idx:02d})")
    plt.xlabel("Time (Months)")
    plt.ylabel(metric_name)
    plt.ylim(0.1, 1.0)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(save_root, f"{file_stem}_run{run_idx:02d}.png"))
    plt.close()


for base_group in BASE_GROUPS:
    print("\n\n============================")
    print(f"📁 BEiT 그룹 실행 중: {base_group}")
    print("============================")

    # test0 = overlapped features
    SAVE_ROOT_BASE = f"./survival_model/mixture_non_fix/tune/{base_group}/results/generalization/test0/dl0/{GROUP}"
    os.makedirs(SAVE_ROOT_BASE, exist_ok=True)
    SAVE_ROOT = SAVE_ROOT_BASE

    for i in range(1, N_RUNS + 1):
        fname = f"dh11_run{i:02d}.csv"
        csv_path = f"./dataset/{base_group}/test/dl0/{GROUP}/{fname}"
        print(f"\n🚀 실행 중: {base_group} - {fname}")

        df_all = pd.read_csv(csv_path)
        df_all['event'] = df_all['survival'].astype(bool)
        df_all['time'] = df_all['fu_date'].astype(np.float32)
        img_cols = ["feat_436", "feat_519"]  # can be replaced with the overlapped features of n6, n5, n4
        cont_cols = ['Age', 'stage0']
        cat_cols = ['pathology']

        feature_sets = {
            'Image only': (img_cols, []),
            'Clinical only': ([], cont_cols + cat_cols),
            'Image + Clinical': (img_cols, cont_cols + cat_cols)
        }

        results_dict = {}
        raw_rows_auc, raw_rows_cidx = [], []

        for label, (img_part, clinical_part) in feature_sets.items():
            print(f"\n📌 Feature Set: {label}")
            auc_train_list, auc_val_list = [], []
            cidx_train_list, cidx_val_list = [], []

            for run in range(N_RUNS):
                set_seed(run)

                X_df = df_all[img_part + clinical_part].copy()
                y_df = df_all[['time', 'event']].copy()
                X_train_df, X_val_df, y_train, y_val = train_test_split(X_df, y_df, test_size=0.3, random_state=run)

                transformers = []
                if img_part:
                    transformers.append(('img', StandardScaler(), img_part))
                cont = [c for c in clinical_part if c in cont_cols]
                cat = [c for c in clinical_part if c in cat_cols]
                if cont:
                    transformers.append(('cont', StandardScaler(), cont))
                if cat:
                    transformers.append(('cat', OneHotEncoder(sparse_output=False, handle_unknown='ignore'), cat))

                # Preprocessing is fitted on the training split only, then applied to validation.
                ct = ColumnTransformer(transformers)
                X_train = pd.DataFrame(ct.fit_transform(X_train_df)).fillna(0).values.astype(np.float32)
                X_val = pd.DataFrame(ct.transform(X_val_df)).fillna(0).values.astype(np.float32)

                X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(DEVICE)
                X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(DEVICE)
                t_train = torch.tensor(y_train['time'].values, dtype=torch.float32).to(DEVICE)
                e_train = torch.tensor(y_train['event'].values, dtype=torch.float32).to(DEVICE)

                model = _fit_model(X_train_tensor, t_train, e_train)
                # To persist the best model:
                # torch.save(model.state_dict(), os.path.join(SAVE_ROOT, f"best_model_run{i:02d}_{label.replace(' ', '_')}.pt"))

                auc_train, cidx_train = _evaluate_split(model, X_train_tensor, y_train)
                auc_val, cidx_val = _evaluate_split(model, X_val_tensor, y_val)
                auc_train_list.append(auc_train)
                auc_val_list.append(auc_val)
                cidx_train_list.append(cidx_train)
                cidx_val_list.append(cidx_val)

                for j, t in enumerate(time_points):
                    raw_rows_auc.append({"Feature Set": label, "Run": run, "Time (Months)": t, "AUC (Train)": auc_train[t], "AUC (Val)": auc_val[t], "Scope": "Time-wise"})
                    raw_rows_cidx.append({"Feature Set": label, "Run": run, "Time (Months)": t, "C-index (Train)": cidx_train[j], "C-index (Val)": cidx_val[j], "Scope": "Time-wise"})

                raw_rows_auc.append({"Feature Set": label, "Run": run, "Time (Months)": "Overall", "AUC (Train)": np.nanmean(list(auc_train.values())), "AUC (Val)": np.nanmean(list(auc_val.values())), "Scope": "Overall"})
                raw_rows_cidx.append({"Feature Set": label, "Run": run, "Time (Months)": "Overall", "C-index (Train)": np.nanmean(cidx_train), "C-index (Val)": np.nanmean(cidx_val), "Scope": "Overall"})

            results_dict[label] = {**_summarize(auc_train_list, cidx_train_list, 'train'),
                                   **_summarize(auc_val_list, cidx_val_list, 'val')}

        # Save the per-run metric tables
        raw_auc_path = os.path.join(SAVE_ROOT, f"raw_auc_per_time_run{i:02d}.csv")
        raw_cidx_path = os.path.join(SAVE_ROOT, f"raw_cindex_per_time_run{i:02d}.csv")
        pd.DataFrame(raw_rows_auc).to_csv(raw_auc_path, index=False)
        pd.DataFrame(raw_rows_cidx).to_csv(raw_cidx_path, index=False)
        print(f"✅ 저장 완료: run{i:02d}")

        # Save the plots
        _plot_metric(results_dict, 'auc', 'AUC', 'o', i, SAVE_ROOT, 'plot_auc')
        _plot_metric(results_dict, 'cidx', 'C-index', 's', i, SAVE_ROOT, 'plot_cindex')




📁 BEiT 그룹 실행 중: beit0

🚀 실행 중: beit0 - dh11_run01.csv

📌 Feature Set: Image only

📌 Feature Set: Clinical only

📌 Feature Set: Image + Clinical
✅ 저장 완료: run01

🚀 실행 중: beit0 - dh11_run02.csv

📌 Feature Set: Image only

📌 Feature Set: Clinical only

📌 Feature Set: Image + Clinical
✅ 저장 완료: run02

🚀 실행 중: beit0 - dh11_run03.csv

📌 Feature Set: Image only

📌 Feature Set: Clinical only

📌 Feature Set: Image + Clinical
✅ 저장 완료: run03

🚀 실행 중: beit0 - dh11_run04.csv

📌 Feature Set: Image only

📌 Feature Set: Clinical only

📌 Feature Set: Image + Clinical
✅ 저장 완료: run04

🚀 실행 중: beit0 - dh11_run05.csv

📌 Feature Set: Image only

📌 Feature Set: Clinical only

📌 Feature Set: Image + Clinical
✅ 저장 완료: run05

🚀 실행 중: beit0 - dh11_run06.csv

📌 Feature Set: Image only

📌 Feature Set: Clinical only

📌 Feature Set: Image + Clinical
✅ 저장 완료: run06

🚀 실행 중: beit0 - dh11_run07.csv

📌 Feature Set: Image only

📌 Feature Set: Clinical only

📌 Feature Set: Image + Clinical
✅ 저장 완료: run07

🚀 실행 중: beit0 - 

In [None]:
# Mixture Stretched Exponential Survival with Debugging and NaN Handling
# Overlapping features across groups - endpoints: progression (recur1 > 0), DM only, LP+DM, local progression only

import pandas as pd
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.metrics import roc_auc_score
from lifelines.utils import concordance_index
import matplotlib.pyplot as plt
import os

def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch's RNG with ``seed``."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

class MixtureStretchedExponentialSurvival(nn.Module):
    """MLP that outputs the parameters of a mixture of stretched-exponential components.

    Component k is parameterised by a weight ``pi_k``, a rate ``lam_k`` and an
    exponent ``a_k``. The loss and prediction helpers use
    ``S_k(t) = exp(-lam_k * t**a_k)`` and ``S(t) = sum_k pi_k * S_k(t)``.
    """

    def __init__(self, input_dim, num_components=2):
        super().__init__()
        hidden = 64
        # Layer creation order is significant: it fixes the RNG draw order for
        # seeded initialisation and the state_dict key order.
        self.backbone = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        self.pi_layer = nn.Linear(hidden, num_components)
        self.lam_layer = nn.Linear(hidden, num_components)
        self.alpha_layer = nn.Linear(hidden, num_components)

    def forward(self, x):
        """Return ``(pi, lam, a)``, one column per mixture component.

        ``pi`` is a softmax over dim 1, so each row sums to 1. ``lam`` and ``a``
        are softplus outputs shifted by 1e-3, which keeps them strictly positive.
        """
        shared = self.backbone(x)
        floor = 1e-3
        weights = self.pi_layer(shared).softmax(dim=1)
        rates = F.softplus(self.lam_layer(shared)) + floor
        exponents = F.softplus(self.alpha_layer(shared)) + floor
        return weights, rates, exponents

def mixture_stretched_nll(t, e, pi, lam, a, eps=1e-8):
    """Mean negative log-likelihood of right-censored times under the mixture.

    Component k has survival ``S_k(t) = exp(-lam_k * t**a_k)`` and density
    ``f_k(t) = lam_k * a_k * t**(a_k - 1) * S_k(t)``. Both are mixed with the
    weights ``pi``. Samples with an event contribute ``log f(t)``; censored
    samples contribute ``log S(t)``.

    Everything is computed in log space with ``logsumexp``. The earlier form
    ``log(sum(pi * S_k) + eps)`` flattened out at ``log(eps)`` whenever the
    mixture survival or density fell far below ``eps`` (e.g. ``lam * t**a`` of
    roughly 20 or more). That biased the loss and starved those samples of
    gradient.

    Args:
        t: observed times, one per sample (reshaped to a column).
        e: event indicator per sample (1 = event, 0 = censored).
        pi, lam, a: mixture weights, rates and exponents, one column per component.
        eps: added to ``t`` so that ``log(t)`` stays finite at ``t == 0``.

    Returns:
        Scalar tensor: mean negative log-likelihood over the batch.
    """
    t = t.view(-1, 1)
    log_t = torch.log(t + eps)
    # Clamping keeps log(pi) finite (and its gradient non-NaN) if softmax underflows to 0.
    log_pi = torch.log(pi.clamp_min(torch.finfo(pi.dtype).tiny))
    log_S_k = -lam * torch.exp(a * log_t)  # log S_k(t) = -lam * t**a
    log_f_k = torch.log(lam) + torch.log(a) + (a - 1) * log_t + log_S_k
    log_S = torch.logsumexp(log_pi + log_S_k, dim=1)
    log_f = torch.logsumexp(log_pi + log_f_k, dim=1)
    loglik = e * log_f + (1 - e) * log_S
    return -loglik.mean()

@torch.no_grad()
def predict_survival(model, x, times):
    """Evaluate the mixture survival function S(t) for every sample at each time in ``times``.

    Switches ``model`` to eval mode. Returns a NumPy array with one row per
    entry of ``times``; row r holds S(times[r]) for every sample in ``x``.
    """
    model.eval()
    weights, rates, exponents = model(x)

    def _survival_at(time_value):
        # S(t) = sum_k pi_k * exp(-lam_k * (t + 1e-8) ** a_k)
        tt = torch.tensor([time_value], dtype=torch.float32, device=x.device)
        component_surv = torch.exp(-rates * torch.pow(tt + 1e-8, exponents))
        return torch.sum(weights * component_surv, dim=1).cpu().numpy()

    return np.vstack([_survival_at(tv) for tv in times])

def calc_auc(surv_arr, y_df, times):
    """Time-dependent AUC at each horizon in ``times``.

    For horizon t, a sample is a case if it had an event at or before t.
    Every other sample is a control. The score is ``1 - S(t)``.

    NOTE(review): samples censored before t are counted as controls, which
    biases the estimate. An IPCW cumulative/dynamic AUC estimator would
    correct for this.

    Args:
        surv_arr: survival probabilities; row i corresponds to ``times[i]``,
            columns to samples.
        y_df: DataFrame with 'time' and 'event' columns, in the same sample
            order as ``surv_arr``'s columns.
        times: evaluation horizons.

    Returns:
        dict mapping each horizon to its AUC, or NaN where the AUC is undefined
        (e.g. only one class present, or NaN scores).
    """
    aucs = {}
    for row, horizon in enumerate(times):
        is_case = ((y_df["event"] == 1) & (y_df["time"] <= horizon)).astype(int)
        score = 1 - surv_arr[row, :]
        try:
            aucs[horizon] = roc_auc_score(is_case, score)
        except ValueError:
            # roc_auc_score raises ValueError for a single-class target or NaN input.
            # Catching only that lets real bugs and KeyboardInterrupt propagate.
            aucs[horizon] = np.nan
    return aucs

def safe_concordance_index(times, risks, events):
    """Harrell's C-index via lifelines, returning NaN when it cannot be computed.

    Samples with NaN in any input are dropped. NaN is returned (with a printed
    warning) when fewer than 2 samples remain, or when the scores are almost
    constant (std < 1e-6).

    ``risks`` is forwarded as lifelines' ``predicted_scores``, where higher
    values mean longer survival.
    """
    t_arr, r_arr, e_arr = (np.asarray(v) for v in (times, risks, events))
    keep = ~(np.isnan(t_arr) | np.isnan(r_arr) | np.isnan(e_arr))
    n_valid = np.sum(keep)
    if n_valid < 2:
        print("⚠️ Too few valid samples for C-index:", n_valid)
        return np.nan
    if np.std(r_arr[keep]) < 1e-6:
        print("⚠️ Low risk variance, skipping C-index")
        return np.nan
    return concordance_index(t_arr[keep], r_arr[keep], e_arr[keep])


# ---------------------------------------------------------------------------
# Experiment driver (recurrence endpoints). For every BEiT feature group and
# every pre-split CSV (dh11_runXX.csv), each feature set is trained N_RUNS
# times on random 70/30 splits. Per-time AUC / C-index tables and mean±std
# plots go to SAVE_ROOT.
# ---------------------------------------------------------------------------
os.makedirs("./survival_model/mixture_non_fix/models", exist_ok=True)
device = torch.device('cpu')

BASE_GROUPS = ["beit0"]
GROUP = "n7_30_30"  # alternatives: n6_30_30, n5_30_30, n4_30_30
N_RUNS = 30
time_points = [12, 24, 36, 48, 60, 72]
DEVICE = torch.device('cpu')


def _snapshot_state(model):
    """Return a detached copy of ``model``'s state dict.

    ``state_dict()`` returns references to the live tensors. Without cloning,
    later optimizer steps would overwrite the saved "best" weights.
    """
    return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}


def _fit_model(X_tensor, t_tensor, e_tensor, max_epochs=1000, patience=10, lr=1e-4, min_delta=1e-6):
    """Train a fresh model full-batch with Adam, early-stopping on the TRAINING NLL.

    Returns the model restored to the parameters that produced the lowest training loss.
    """
    model = MixtureStretchedExponentialSurvival(input_dim=X_tensor.shape[1], num_components=2).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_train_loss = float('inf')
    best_state = None
    stall = 0
    for _ in range(max_epochs):
        model.train()
        optimizer.zero_grad()
        pi, lam, a = model(X_tensor)
        loss = mixture_stretched_nll(t_tensor, e_tensor, pi, lam, a)
        loss_value = loss.item()
        if loss_value < best_train_loss - min_delta:
            # Snapshot before stepping: these are the parameters that produced loss_value.
            best_train_loss = loss_value
            best_state = _snapshot_state(model)
            stall = 0
        else:
            # A NaN loss also lands here, so a diverged run stops and is rolled back below.
            stall += 1
            if stall >= patience:
                break
        loss.backward()
        optimizer.step()

    if best_state is not None:
        model.load_state_dict(best_state)
    return model


def _evaluate_split(model, X_tensor, y_split):
    """Return ``(aucs, cidx)`` for one split.

    ``aucs`` maps each time point to an AUC. ``cidx`` is a list aligned with ``time_points``.
    """
    surv = predict_survival(model, X_tensor, time_points)  # one row per time point
    aucs = calc_auc(surv, y_split.reset_index(drop=True), time_points)
    # lifelines expects higher score = longer survival, so S(t) is passed as-is.
    cidx = [safe_concordance_index(y_split['time'], surv[k, :], y_split['event'])
            for k in range(len(time_points))]
    return aucs, cidx


def _summarize(auc_runs, cidx_runs, split):
    """Return NaN-aware mean/std across runs, per time point, for one split ('train' or 'val')."""
    return {
        f'mean_auc_{split}': {t: np.nanmean([r[t] for r in auc_runs]) for t in time_points},
        f'std_auc_{split}': {t: np.nanstd([r[t] for r in auc_runs]) for t in time_points},
        f'mean_cidx_{split}': {t: np.nanmean([r[j] for r in cidx_runs]) for j, t in enumerate(time_points)},
        f'std_cidx_{split}': {t: np.nanstd([r[j] for r in cidx_runs]) for j, t in enumerate(time_points)},
    }


def _plot_metric(results_dict, metric_key, metric_name, marker, run_idx, save_root, file_stem):
    """Save an errorbar plot (mean ± std over runs) of one metric for every feature set."""
    plt.figure(figsize=(10, 5))
    for label, summary in results_dict.items():
        for split, line in (("train", "--"), ("val", "-")):
            plt.errorbar(time_points, list(summary[f"mean_{metric_key}_{split}"].values()),
                         yerr=list(summary[f"std_{metric_key}_{split}"].values()),
                         fmt=f"{line}{marker}", capsize=4,
                         label=f"{label} - {metric_name} {split.capitalize()}")
    plt.title(f"{metric_name} (Run {run_idx:02d})")
    plt.xlabel("Time (Months)")
    plt.ylabel(metric_name)
    plt.ylim(0.1, 1.0)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(save_root, f"{file_stem}_run{run_idx:02d}.png"))
    plt.close()


for base_group in BASE_GROUPS:
    print("\n\n============================")
    print(f"📁 BEiT 그룹 실행 중: {base_group}")
    print("============================")

    # test1 = progression, test2 = DM only, test3 = LP+DM, test4 = local progression
    SAVE_ROOT_BASE = f"./survival_model/mixture_non_fix/tune/{base_group}/results/generalization/test2/dl0/{GROUP}"
    os.makedirs(SAVE_ROOT_BASE, exist_ok=True)
    SAVE_ROOT = SAVE_ROOT_BASE

    for i in range(1, N_RUNS + 1):
        fname = f"dh11_run{i:02d}.csv"
        csv_path = f"./dataset/{base_group}/test/dl0/{GROUP}/{fname}"
        print(f"\n🚀 실행 중: {base_group} - {fname}")

        df_all = pd.read_csv(csv_path)
        # Endpoint selection via recur1 codes: [2] here; [3] = LP+DM, [4] = LP only.
        # For any progression use: df_all['event'] = df_all['recur'].astype(bool)
        df_all['event'] = df_all['recur1'].isin([2])
        df_all['time'] = df_all['recur_date'].astype(np.float32)

        img_cols = ["feat_436", "feat_519"]  # can be replaced with the overlapped features of n6, n5, n4
        cont_cols = ['Age', 'stage0']
        cat_cols = ['pathology']

        feature_sets = {
            'Image only': (img_cols, []),
            'Clinical only': ([], cont_cols + cat_cols),
            'Image + Clinical': (img_cols, cont_cols + cat_cols)
        }

        results_dict = {}
        raw_rows_auc, raw_rows_cidx = [], []

        for label, (img_part, clinical_part) in feature_sets.items():
            print(f"\n📌 Feature Set: {label}")
            auc_train_list, auc_val_list = [], []
            cidx_train_list, cidx_val_list = [], []

            for run in range(N_RUNS):
                set_seed(run)

                X_df = df_all[img_part + clinical_part].copy()
                y_df = df_all[['time', 'event']].copy()
                X_train_df, X_val_df, y_train, y_val = train_test_split(X_df, y_df, test_size=0.3, random_state=run)

                transformers = []
                if img_part:
                    transformers.append(('img', StandardScaler(), img_part))
                cont = [c for c in clinical_part if c in cont_cols]
                cat = [c for c in clinical_part if c in cat_cols]
                if cont:
                    transformers.append(('cont', StandardScaler(), cont))
                if cat:
                    transformers.append(('cat', OneHotEncoder(sparse_output=False, handle_unknown='ignore'), cat))

                # Preprocessing is fitted on the training split only, then applied to validation.
                ct = ColumnTransformer(transformers)
                X_train = pd.DataFrame(ct.fit_transform(X_train_df)).fillna(0).values.astype(np.float32)
                X_val = pd.DataFrame(ct.transform(X_val_df)).fillna(0).values.astype(np.float32)

                X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(DEVICE)
                X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(DEVICE)
                t_train = torch.tensor(y_train['time'].values, dtype=torch.float32).to(DEVICE)
                e_train = torch.tensor(y_train['event'].values, dtype=torch.float32).to(DEVICE)

                model = _fit_model(X_train_tensor, t_train, e_train)
                # To persist the best model:
                # torch.save(model.state_dict(), os.path.join(SAVE_ROOT, f"best_model_run{i:02d}_{label.replace(' ', '_')}.pt"))

                auc_train, cidx_train = _evaluate_split(model, X_train_tensor, y_train)
                auc_val, cidx_val = _evaluate_split(model, X_val_tensor, y_val)
                auc_train_list.append(auc_train)
                auc_val_list.append(auc_val)
                cidx_train_list.append(cidx_train)
                cidx_val_list.append(cidx_val)

                for j, t in enumerate(time_points):
                    raw_rows_auc.append({"Feature Set": label, "Run": run, "Time (Months)": t, "AUC (Train)": auc_train[t], "AUC (Val)": auc_val[t], "Scope": "Time-wise"})
                    raw_rows_cidx.append({"Feature Set": label, "Run": run, "Time (Months)": t, "C-index (Train)": cidx_train[j], "C-index (Val)": cidx_val[j], "Scope": "Time-wise"})

                raw_rows_auc.append({"Feature Set": label, "Run": run, "Time (Months)": "Overall", "AUC (Train)": np.nanmean(list(auc_train.values())), "AUC (Val)": np.nanmean(list(auc_val.values())), "Scope": "Overall"})
                raw_rows_cidx.append({"Feature Set": label, "Run": run, "Time (Months)": "Overall", "C-index (Train)": np.nanmean(cidx_train), "C-index (Val)": np.nanmean(cidx_val), "Scope": "Overall"})

            results_dict[label] = {**_summarize(auc_train_list, cidx_train_list, 'train'),
                                   **_summarize(auc_val_list, cidx_val_list, 'val')}

        # Save the per-run metric tables
        raw_auc_path = os.path.join(SAVE_ROOT, f"raw_auc_per_time_run{i:02d}.csv")
        raw_cidx_path = os.path.join(SAVE_ROOT, f"raw_cindex_per_time_run{i:02d}.csv")
        pd.DataFrame(raw_rows_auc).to_csv(raw_auc_path, index=False)
        pd.DataFrame(raw_rows_cidx).to_csv(raw_cidx_path, index=False)
        print(f"✅ 저장 완료: run{i:02d}")

        # Save the plots
        _plot_metric(results_dict, 'auc', 'AUC', 'o', i, SAVE_ROOT, 'plot_auc')
        _plot_metric(results_dict, 'cidx', 'C-index', 's', i, SAVE_ROOT, 'plot_cindex')


In [None]:
# Mixture Stretched Exponential Survival with Debugging and NaN Handling
# Per-group features (each group's own feature set)

import pandas as pd
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.metrics import roc_auc_score
from lifelines.utils import concordance_index
import matplotlib.pyplot as plt
import os

def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch's RNG with ``seed``."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

class MixtureStretchedExponentialSurvival(nn.Module):
    """MLP that outputs the parameters of a mixture of stretched-exponential components.

    Component k is parameterised by a weight ``pi_k``, a rate ``lam_k`` and an
    exponent ``a_k``. The loss and prediction helpers use
    ``S_k(t) = exp(-lam_k * t**a_k)`` and ``S(t) = sum_k pi_k * S_k(t)``.
    """

    def __init__(self, input_dim, num_components=2):
        super().__init__()
        hidden = 64
        # Layer creation order is significant: it fixes the RNG draw order for
        # seeded initialisation and the state_dict key order.
        self.backbone = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        self.pi_layer = nn.Linear(hidden, num_components)
        self.lam_layer = nn.Linear(hidden, num_components)
        self.alpha_layer = nn.Linear(hidden, num_components)

    def forward(self, x):
        """Return ``(pi, lam, a)``, one column per mixture component.

        ``pi`` is a softmax over dim 1, so each row sums to 1. ``lam`` and ``a``
        are softplus outputs shifted by 1e-3, which keeps them strictly positive.
        """
        shared = self.backbone(x)
        floor = 1e-3
        weights = self.pi_layer(shared).softmax(dim=1)
        rates = F.softplus(self.lam_layer(shared)) + floor
        exponents = F.softplus(self.alpha_layer(shared)) + floor
        return weights, rates, exponents

def mixture_stretched_nll(t, e, pi, lam, a, eps=1e-8):
    """Mean negative log-likelihood of right-censored times under the mixture.

    Component k has survival ``S_k(t) = exp(-lam_k * t**a_k)`` and density
    ``f_k(t) = lam_k * a_k * t**(a_k - 1) * S_k(t)``. Both are mixed with the
    weights ``pi``. Samples with an event contribute ``log f(t)``; censored
    samples contribute ``log S(t)``.

    Everything is computed in log space with ``logsumexp``. The earlier form
    ``log(sum(pi * S_k) + eps)`` flattened out at ``log(eps)`` whenever the
    mixture survival or density fell far below ``eps`` (e.g. ``lam * t**a`` of
    roughly 20 or more). That biased the loss and starved those samples of
    gradient.

    Args:
        t: observed times, one per sample (reshaped to a column).
        e: event indicator per sample (1 = event, 0 = censored).
        pi, lam, a: mixture weights, rates and exponents, one column per component.
        eps: added to ``t`` so that ``log(t)`` stays finite at ``t == 0``.

    Returns:
        Scalar tensor: mean negative log-likelihood over the batch.
    """
    t = t.view(-1, 1)
    log_t = torch.log(t + eps)
    # Clamping keeps log(pi) finite (and its gradient non-NaN) if softmax underflows to 0.
    log_pi = torch.log(pi.clamp_min(torch.finfo(pi.dtype).tiny))
    log_S_k = -lam * torch.exp(a * log_t)  # log S_k(t) = -lam * t**a
    log_f_k = torch.log(lam) + torch.log(a) + (a - 1) * log_t + log_S_k
    log_S = torch.logsumexp(log_pi + log_S_k, dim=1)
    log_f = torch.logsumexp(log_pi + log_f_k, dim=1)
    loglik = e * log_f + (1 - e) * log_S
    return -loglik.mean()

@torch.no_grad()
def predict_survival(model, x, times):
    """Evaluate the mixture survival function S(t) for every sample at each time in ``times``.

    Switches ``model`` to eval mode. Returns a NumPy array with one row per
    entry of ``times``; row r holds S(times[r]) for every sample in ``x``.
    """
    model.eval()
    weights, rates, exponents = model(x)

    def _survival_at(time_value):
        # S(t) = sum_k pi_k * exp(-lam_k * (t + 1e-8) ** a_k)
        tt = torch.tensor([time_value], dtype=torch.float32, device=x.device)
        component_surv = torch.exp(-rates * torch.pow(tt + 1e-8, exponents))
        return torch.sum(weights * component_surv, dim=1).cpu().numpy()

    return np.vstack([_survival_at(tv) for tv in times])

def calc_auc(surv_arr, y_df, times):
    """Time-dependent AUC at each horizon in ``times``.

    For horizon t, a sample is a case if it had an event at or before t.
    Every other sample is a control. The score is ``1 - S(t)``.

    NOTE(review): samples censored before t are counted as controls, which
    biases the estimate. An IPCW cumulative/dynamic AUC estimator would
    correct for this.

    Args:
        surv_arr: survival probabilities; row i corresponds to ``times[i]``,
            columns to samples.
        y_df: DataFrame with 'time' and 'event' columns, in the same sample
            order as ``surv_arr``'s columns.
        times: evaluation horizons.

    Returns:
        dict mapping each horizon to its AUC, or NaN where the AUC is undefined
        (e.g. only one class present, or NaN scores).
    """
    aucs = {}
    for row, horizon in enumerate(times):
        is_case = ((y_df["event"] == 1) & (y_df["time"] <= horizon)).astype(int)
        score = 1 - surv_arr[row, :]
        try:
            aucs[horizon] = roc_auc_score(is_case, score)
        except ValueError:
            # roc_auc_score raises ValueError for a single-class target or NaN input.
            # Catching only that lets real bugs and KeyboardInterrupt propagate.
            aucs[horizon] = np.nan
    return aucs

def safe_concordance_index(times, risks, events):
    """Harrell's C-index via lifelines, returning NaN when it cannot be computed.

    Samples with NaN in any input are dropped. NaN is returned (with a printed
    warning) when fewer than 2 samples remain, or when the scores are almost
    constant (std < 1e-6).

    ``risks`` is forwarded as lifelines' ``predicted_scores``, where higher
    values mean longer survival.
    """
    t_arr, r_arr, e_arr = (np.asarray(v) for v in (times, risks, events))
    keep = ~(np.isnan(t_arr) | np.isnan(r_arr) | np.isnan(e_arr))
    n_valid = np.sum(keep)
    if n_valid < 2:
        print("⚠️ Too few valid samples for C-index:", n_valid)
        return np.nan
    if np.std(r_arr[keep]) < 1e-6:
        print("⚠️ Low risk variance, skipping C-index")
        return np.nan
    return concordance_index(t_arr[keep], r_arr[keep], e_arr[keep])


os.makedirs("./survival_model/mixture_non_fix/models", exist_ok=True)
# NOTE(review): `device` is not used in this cell (DEVICE below is); kept in case other cells reference it.
device = torch.device('cpu')

# Shared settings
N_RUNS = 30    # random train/val splits per feature set (seeds 0 .. N_RUNS-1)
N_FILES = 30   # input files per group: dh11_run01.csv .. dh11_run30.csv
time_points = [12, 24, 36, 48, 60, 72]
DEVICE = torch.device('cpu')

for group in ["n4_30_30", "n5_30_30", "n6_30_30", "n7_30_30", "n8_30_30", "n9_30_30"]:
    print(f"\n============================")
    print(f"📂 그룹 실행 중: {group}")
    print(f"============================")

    # "test" in the path selects the per-dataset features. "beit0_o" can be replaced with
    # another feature group (beit0, beit, beit_o, beit_resize, beit_original (beit_resize_o));
    # keep it in sync with the dataset path used for csv_path below.
    SAVE_ROOT = f"./survival_model/mixture_non_fix/tune/beit0_o/results/generalization/test/dl0/{group}"
    os.makedirs(SAVE_ROOT, exist_ok=True)

    # Iterate over the per-run input files (run01 .. run30)
    for i in range(1, N_FILES + 1):
        fname = f"dh11_run{i:02d}.csv"
        csv_path = f"./dataset/beit0_o/test/dl0/{group}/{fname}"  # same feature-group name as SAVE_ROOT
        print(f"\n🚀 실행 중: {fname}")

        df_all = pd.read_csv(csv_path)
        df_all['event'] = df_all['survival'].astype(bool)
        df_all['time'] = df_all['fu_date'].astype(np.float32)

        # Image features: columns from index 22 up to (but excluding) the last two.
        # The slice assumes 'event'/'time' were newly appended above (not already in the CSV).
        # NOTE(review): confirm 'survival'/'fu_date' do not fall inside this range,
        # otherwise the labels leak into the image features.
        img_cols = df_all.columns[22:-2].tolist()
        cont_cols = ['Age', 'stage0']
        cat_cols = ['pathology']

        feature_sets = {
            'Image only': (img_cols, []),
            'Clinical only': ([], cont_cols + cat_cols),
            'Image + Clinical': (img_cols, cont_cols + cat_cols)
        }

        results_dict = {}
        raw_rows_auc, raw_rows_cidx = [], []

        for label, (img_part, clinical_part) in feature_sets.items():
            print(f"\n📌 Feature Set: {label}")
            auc_train_list, auc_val_list = [], []
            cidx_train_list, cidx_val_list = [], []

            for run in range(N_RUNS):
                set_seed(run)

                X_df = df_all[img_part + clinical_part].copy()
                y_df = df_all[['time', 'event']].copy()
                X_train_df, X_val_df, y_train, y_val = train_test_split(X_df, y_df, test_size=0.3, random_state=run)

                # Scale continuous/image features, one-hot encode categoricals (fit on train only).
                transformers = []
                if img_part:
                    transformers.append(('img', StandardScaler(), img_part))
                cont = [c for c in clinical_part if c in cont_cols]
                cat = [c for c in clinical_part if c in cat_cols]
                if cont:
                    transformers.append(('cont', StandardScaler(), cont))
                if cat:
                    transformers.append(('cat', OneHotEncoder(sparse_output=False, handle_unknown='ignore'), cat))

                ct = ColumnTransformer(transformers)
                X_train = ct.fit_transform(X_train_df)
                X_val = ct.transform(X_val_df)

                # Missing values left after scaling are replaced by 0.
                X_train = pd.DataFrame(X_train).fillna(0).values.astype(np.float32)
                X_val = pd.DataFrame(X_val).fillna(0).values.astype(np.float32)

                X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(DEVICE)
                X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(DEVICE)
                t_train = torch.tensor(y_train['time'].values, dtype=torch.float32).to(DEVICE)
                e_train = torch.tensor(y_train['event'].values, dtype=torch.float32).to(DEVICE)

                model = MixtureStretchedExponentialSurvival(input_dim=X_train.shape[1], num_components=2).to(DEVICE)
                optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

                # Full-batch training with early stopping on the *training* loss
                # (no separate validation loss is computed here).
                best_train_loss = float('inf')
                patience, patience_counter = 10, 0
                best_model_state = None

                for epoch in range(1000):
                    model.train()
                    optimizer.zero_grad()
                    pi, lam, a = model(X_train_tensor)
                    loss = mixture_stretched_nll(t_train, e_train, pi, lam, a)
                    current_loss = loss.item()
                    if not np.isfinite(current_loss):
                        # Diverged (NaN/inf loss): stop and fall back to the best snapshot, if any.
                        break
                    if current_loss < best_train_loss - 1e-6:
                        best_train_loss = current_loss
                        # state_dict() holds live references to the parameters, so it must be
                        # cloned or later optimizer steps overwrite the "best" snapshot. It is
                        # taken before optimizer.step() so it holds the weights behind current_loss.
                        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
                        patience_counter = 0
                    else:
                        patience_counter += 1
                        if patience_counter >= patience:
                            break
                    loss.backward()
                    optimizer.step()

                if best_model_state is not None:
                    model.load_state_dict(best_model_state)

                    # Save best model
                    #model_save_path = os.path.join(SAVE_ROOT, f"best_model_run{i:02d}_{label.replace(' ', '_')}.pt")
                    #torch.save(best_model_state, model_save_path)

                surv_train = predict_survival(model, X_train_tensor, time_points)
                surv_val = predict_survival(model, X_val_tensor, time_points)

                auc_train = calc_auc(surv_train, y_train.reset_index(drop=True), time_points)
                auc_val = calc_auc(surv_val, y_val.reset_index(drop=True), time_points)
                auc_train_list.append(auc_train)
                auc_val_list.append(auc_val)

                # S(t) is used directly as the C-index score (higher = longer survival).
                risk_train = surv_train.T
                risk_val = surv_val.T
                cidx_train = [safe_concordance_index(y_train['time'], risk_train[:, j], y_train['event']) for j in range(len(time_points))]
                cidx_val = [safe_concordance_index(y_val['time'], risk_val[:, j], y_val['event']) for j in range(len(time_points))]
                cidx_train_list.append(cidx_train)
                cidx_val_list.append(cidx_val)

                for j, t in enumerate(time_points):
                    raw_rows_auc.append({"Feature Set": label, "Run": run, "Time (Months)": t, "AUC (Train)": auc_train[t], "AUC (Val)": auc_val[t], "Scope": "Time-wise"})
                    raw_rows_cidx.append({"Feature Set": label, "Run": run, "Time (Months)": t, "C-index (Train)": cidx_train[j], "C-index (Val)": cidx_val[j], "Scope": "Time-wise"})

                raw_rows_auc.append({"Feature Set": label, "Run": run, "Time (Months)": "Overall", "AUC (Train)": np.nanmean(list(auc_train.values())), "AUC (Val)": np.nanmean(list(auc_val.values())), "Scope": "Overall"})
                raw_rows_cidx.append({"Feature Set": label, "Run": run, "Time (Months)": "Overall", "C-index (Train)": np.nanmean(cidx_train), "C-index (Val)": np.nanmean(cidx_val), "Scope": "Overall"})

            # Aggregate mean/std across splits for each horizon.
            results_dict[label] = {
                'mean_auc_train': {t: np.nanmean([r[t] for r in auc_train_list]) for t in time_points},
                'mean_auc_val': {t: np.nanmean([r[t] for r in auc_val_list]) for t in time_points},
                'std_auc_train': {t: np.nanstd([r[t] for r in auc_train_list]) for t in time_points},
                'std_auc_val': {t: np.nanstd([r[t] for r in auc_val_list]) for t in time_points},
                'mean_cidx_train': {t: np.nanmean([r[j] for r in cidx_train_list]) for j, t in enumerate(time_points)},
                'mean_cidx_val': {t: np.nanmean([r[j] for r in cidx_val_list]) for j, t in enumerate(time_points)},
                'std_cidx_train': {t: np.nanstd([r[j] for r in cidx_train_list]) for j, t in enumerate(time_points)},
                'std_cidx_val': {t: np.nanstd([r[j] for r in cidx_val_list]) for j, t in enumerate(time_points)}
            }

        # Save raw per-split results
        raw_auc_path = os.path.join(SAVE_ROOT, f"raw_auc_per_time_run{i:02d}.csv")
        raw_cidx_path = os.path.join(SAVE_ROOT, f"raw_cindex_per_time_run{i:02d}.csv")
        pd.DataFrame(raw_rows_auc).to_csv(raw_auc_path, index=False)
        pd.DataFrame(raw_rows_cidx).to_csv(raw_cidx_path, index=False)
        print(f"✅ 저장 완료: run{i:02d}")

        # Save plots
        plt.figure(figsize=(10, 5))
        for label in feature_sets:
            plt.errorbar(time_points, list(results_dict[label]['mean_auc_train'].values()),
                        yerr=list(results_dict[label]['std_auc_train'].values()), fmt='--o', capsize=4, label=f"{label} - AUC Train")
            plt.errorbar(time_points, list(results_dict[label]['mean_auc_val'].values()),
                        yerr=list(results_dict[label]['std_auc_val'].values()), fmt='-o', capsize=4, label=f"{label} - AUC Val")
        plt.title(f"AUC (Run {i:02d})")
        plt.xlabel("Time (Months)")
        plt.ylabel("AUC")
        plt.ylim(0.1, 1.0)
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(SAVE_ROOT, f"plot_auc_run{i:02d}.png"))
        plt.close()

        plt.figure(figsize=(10, 5))
        for label in feature_sets:
            plt.errorbar(time_points, list(results_dict[label]['mean_cidx_train'].values()),
                        yerr=list(results_dict[label]['std_cidx_train'].values()), fmt='--s', capsize=4, label=f"{label} - C-index Train")
            plt.errorbar(time_points, list(results_dict[label]['mean_cidx_val'].values()),
                        yerr=list(results_dict[label]['std_cidx_val'].values()), fmt='-s', capsize=4, label=f"{label} - C-index Val")
        plt.title(f"C-index (Run {i:02d})")
        plt.xlabel("Time (Months)")
        plt.ylabel("C-index")
        plt.ylim(0.1, 1.0)
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(SAVE_ROOT, f"plot_cindex_run{i:02d}.png"))
        plt.close()



📂 그룹 실행 중: n4_30_30

🚀 실행 중: dh11_run01.csv

📌 Feature Set: Image only

📌 Feature Set: Clinical only

📌 Feature Set: Image + Clinical
✅ 저장 완료: run01

🚀 실행 중: dh11_run02.csv

📌 Feature Set: Image only

📌 Feature Set: Clinical only

📌 Feature Set: Image + Clinical
✅ 저장 완료: run02

🚀 실행 중: dh11_run03.csv

📌 Feature Set: Image only

📌 Feature Set: Clinical only

📌 Feature Set: Image + Clinical
✅ 저장 완료: run03

🚀 실행 중: dh11_run04.csv

📌 Feature Set: Image only

📌 Feature Set: Clinical only

📌 Feature Set: Image + Clinical
✅ 저장 완료: run04

🚀 실행 중: dh11_run05.csv

📌 Feature Set: Image only

📌 Feature Set: Clinical only

📌 Feature Set: Image + Clinical
✅ 저장 완료: run05

🚀 실행 중: dh11_run06.csv

📌 Feature Set: Image only

📌 Feature Set: Clinical only

📌 Feature Set: Image + Clinical
✅ 저장 완료: run06

🚀 실행 중: dh11_run07.csv

📌 Feature Set: Image only

📌 Feature Set: Clinical only

📌 Feature Set: Image + Clinical
✅ 저장 완료: run07

🚀 실행 중: dh11_run08.csv

📌 Feature Set: Image only

📌 Feature Set: Clinical 

In [None]:
# CoxPH pipeline using S(t) directly for lifelines' C-index

import pandas as pd
import numpy as np
import random
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.metrics import roc_auc_score
from lifelines.utils import concordance_index
from lifelines import CoxPHFitter
import matplotlib.pyplot as plt
import os

# -----------------------------
# Utils
# -----------------------------
def set_seed(seed):
    """Make draws from `random`, NumPy's global RNG and PyTorch reproducible for *seed*."""
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    for apply_seed in seeders:
        apply_seed(seed)

def calc_auc(surv_arr, y_df, times):
    """Per-horizon ROC AUC from predicted survival probabilities.

    Args:
        surv_arr: array of shape (len(times), n_samples) holding S(t).
        y_df: DataFrame with ``event`` and ``time`` columns whose rows line up
            by position with the sample axis of ``surv_arr``.
        times: evaluation horizons.

    Returns:
        dict mapping each horizon t to its ROC AUC. The value is NaN where the
        AUC is undefined, e.g. only one class is present at t or the
        predictions contain NaN.

    The score is 1 - S(t), so a higher predicted event probability ranks as
    more positive. Positives are subjects with an observed event at or before
    t; every other subject counts as a negative, including those censored
    before t. This is not a censoring-adjusted time-dependent AUC.
    """
    aucs = {}
    for i, t in enumerate(times):
        true = ((y_df["event"] == 1) & (y_df["time"] <= t)).astype(int)
        pred = 1 - surv_arr[i, :]  # lower survival -> higher event prob
        try:
            aucs[t] = roc_auc_score(true, pred)
        except ValueError:
            # roc_auc_score raises ValueError for single-class labels or
            # non-finite scores; anything else is a real bug and should surface.
            aucs[t] = np.nan
    return aucs

def safe_concordance_index(times, scores, events):
    """Harrell's C-index that returns NaN instead of failing on degenerate input.

    lifelines' concordance_index expects scores where higher means longer
    survival, so S(t) is passed in unchanged. Rows with NaN in any input are
    dropped. NaN is returned, with a printed warning, when fewer than two rows
    remain or the scores are effectively constant (std < 1e-6).
    """
    t_arr = np.asarray(times)
    s_arr = np.asarray(scores)
    e_arr = np.asarray(events)
    keep = ~np.isnan(t_arr) & ~np.isnan(s_arr) & ~np.isnan(e_arr)
    kept_count = np.sum(keep)
    if kept_count < 2:
        print("⚠️ Too few valid samples for C-index:", kept_count)
        return np.nan
    kept_scores = s_arr[keep]
    if np.std(kept_scores) < 1e-6:
        print("⚠️ Low score variance, skipping C-index")
        return np.nan
    return concordance_index(t_arr[keep], kept_scores, e_arr[keep])

def predict_survival_cox(cph, X_df, times):
    """Return S(t) from a fitted lifelines CoxPHFitter as a NumPy array.

    Rows follow ``times`` and columns follow the rows of ``X_df``, giving shape
    (len(times), n_samples). This is the layout calc_auc indexes as
    ``surv_arr[i, :]``.
    """
    horizon_index = pd.Index(times)
    survival_frame = cph.predict_survival_function(X_df, times=horizon_index)
    return survival_frame.to_numpy()

# -----------------------------
# Paths & Const
# -----------------------------
os.makedirs("./survival_model/mixture_non_fix/models", exist_ok=True)
DEVICE = torch.device('cpu')  # NOTE(review): not used by the Cox pipeline below

BASE_GROUPS = ["beit0"]
GROUP = "n7_30_30"
N_RUNS = 30   # used both as the number of input files and the number of splits per file
time_points = [12, 24, 36, 48, 60, 72]

for base_group in BASE_GROUPS:
    print(f"\n\n============================")
    print(f"📁 BEiT 그룹 실행 중: {base_group}")
    print(f"============================")

    SAVE_ROOT_BASE = f"./survival_model/mixture_non_fix/tune/{base_group}/results/generalization/test_cox/dl0/{GROUP}"
    os.makedirs(SAVE_ROOT_BASE, exist_ok=True)
    SAVE_ROOT = SAVE_ROOT_BASE

    for i in range(1, N_RUNS + 1):
        fname = f"dh11_run{i:02d}.csv"
        csv_path = f"./dataset/{base_group}/test/dl0/{GROUP}/{fname}"
        print(f"\n🚀 실행 중: {base_group} - {fname}")

        # -----------------------------
        # Load & Columns
        # -----------------------------
        df_all = pd.read_csv(csv_path)
        df_all['event'] = df_all['survival'].astype(bool)
        df_all['time'] = df_all['fu_date'].astype(np.float32)

        img_cols = ["feat_436", "feat_519"]
        cont_cols = ['Age', 'stage0']
        cat_cols  = ['pathology']

        feature_sets = {
            'Image only': (img_cols, []),
            'Clinical only': ([], cont_cols + cat_cols),
            'Image + Clinical': (img_cols, cont_cols + cat_cols)
        }

        results_dict = {}
        raw_rows_auc, raw_rows_cidx = [], []

        # -----------------------------
        # Loops
        # -----------------------------
        for label, (img_part, clinical_part) in feature_sets.items():
            print(f"\n📌 Feature Set: {label}")
            auc_train_list, auc_val_list = [], []
            cidx_train_list, cidx_val_list = [], []

            for run in range(N_RUNS):
                set_seed(run)

                X_df = df_all[img_part + clinical_part].copy()
                y_df = df_all[['time', 'event']].copy()
                X_train_df, X_val_df, y_train, y_val = train_test_split(
                    X_df, y_df, test_size=0.3, random_state=run
                )

                # ----- Preprocess (fit on train only) -----
                transformers = []
                if img_part:
                    transformers.append(('img', StandardScaler(), img_part))
                cont = [c for c in clinical_part if c in cont_cols]
                cat  = [c for c in clinical_part if c in cat_cols]
                if cont:
                    transformers.append(('cont', StandardScaler(), cont))
                if cat:
                    transformers.append(('cat', OneHotEncoder(sparse_output=False, handle_unknown='ignore'), cat))

                ct = ColumnTransformer(transformers)
                X_train_np = ct.fit_transform(X_train_df)
                X_val_np   = ct.transform(X_val_df)

                feat_names = ct.get_feature_names_out()
                X_train_c = (pd.DataFrame(X_train_np, columns=feat_names)
                               .replace([np.inf, -np.inf], np.nan).fillna(0.0))
                X_val_c   = (pd.DataFrame(X_val_np,   columns=feat_names)
                               .replace([np.inf, -np.inf], np.nan).fillna(0.0))

                # ----- CoxPH Fit -----
                train_df_for_cox = X_train_c.copy()
                train_df_for_cox['time']  = y_train['time'].values
                train_df_for_cox['event'] = y_train['event'].astype(int).values

                cph = CoxPHFitter(penalizer=0.1)  # small L2 for stability
                try:
                    cph.fit(train_df_for_cox, duration_col='time', event_col='event', show_progress=False)
                except ValueError as err:
                    # lifelines' ConvergenceError subclasses ValueError. Record this split as
                    # undefined (all-NaN survival -> NaN AUC / C-index downstream) instead of
                    # aborting every remaining split and file.
                    print(f"⚠️ CoxPH fit failed ({label}, run {run}): {err}")
                    surv_train = np.full((len(time_points), len(X_train_c)), np.nan)
                    surv_val   = np.full((len(time_points), len(X_val_c)),   np.nan)
                else:
                    # ----- Survival prediction: shape (T, n), values S(t) -----
                    surv_train = predict_survival_cox(cph, X_train_c, time_points)
                    surv_val   = predict_survival_cox(cph, X_val_c,   time_points)

                # ----- AUC -----
                auc_train = calc_auc(surv_train, y_train.reset_index(drop=True), time_points)
                auc_val   = calc_auc(surv_val,   y_val.reset_index(drop=True),   time_points)
                auc_train_list.append(auc_train)
                auc_val_list.append(auc_val)

                # ----- C-index (lifelines expects "higher = longer survival"): use S(t) directly) -----
                score_train = surv_train.T  # (n_train, T), higher S(t) => longer survival
                score_val   = surv_val.T
                cidx_train = [safe_concordance_index(y_train['time'], score_train[:, j], y_train['event'])
                              for j in range(len(time_points))]
                cidx_val   = [safe_concordance_index(y_val['time'],   score_val[:, j],   y_val['event'])
                              for j in range(len(time_points))]
                cidx_train_list.append(cidx_train)
                cidx_val_list.append(cidx_val)

                # ----- Raw rows -----
                for j, t in enumerate(time_points):
                    raw_rows_auc.append({
                        "Feature Set": label, "Run": run, "Time (Months)": t,
                        "AUC (Train)": auc_train[t], "AUC (Val)": auc_val[t], "Scope": "Time-wise"
                    })
                    raw_rows_cidx.append({
                        "Feature Set": label, "Run": run, "Time (Months)": t,
                        "C-index (Train)": cidx_train[j], "C-index (Val)": cidx_val[j], "Scope": "Time-wise"
                    })

                raw_rows_auc.append({
                    "Feature Set": label, "Run": run, "Time (Months)": "Overall",
                    "AUC (Train)": np.nanmean(list(auc_train.values())),
                    "AUC (Val)":   np.nanmean(list(auc_val.values())), "Scope": "Overall"
                })
                raw_rows_cidx.append({
                    "Feature Set": label, "Run": run, "Time (Months)": "Overall",
                    "C-index (Train)": np.nanmean(cidx_train),
                    "C-index (Val)":   np.nanmean(cidx_val), "Scope": "Overall"
                })

            # ----- Aggregate stats -----
            results_dict[label] = {
                'mean_auc_train': {t: np.nanmean([r[t] for r in auc_train_list]) for t in time_points},
                'mean_auc_val':   {t: np.nanmean([r[t] for r in auc_val_list])   for t in time_points},
                'std_auc_train':  {t: np.nanstd([r[t] for r in auc_train_list])  for t in time_points},
                'std_auc_val':    {t: np.nanstd([r[t] for r in auc_val_list])    for t in time_points},
                'mean_cidx_train':{t: np.nanmean([r[j] for r in cidx_train_list]) for j, t in enumerate(time_points)},
                'mean_cidx_val':  {t: np.nanmean([r[j] for r in cidx_val_list])   for j, t in enumerate(time_points)},
                'std_cidx_train': {t: np.nanstd([r[j] for r in cidx_train_list])  for j, t in enumerate(time_points)},
                'std_cidx_val':   {t: np.nanstd([r[j] for r in cidx_val_list])    for j, t in enumerate(time_points)}
            }

        # -----------------------------
        # Save raw CSVs
        # -----------------------------
        raw_auc_path  = os.path.join(SAVE_ROOT, f"raw_auc_per_time_run{i:02d}.csv")
        raw_cidx_path = os.path.join(SAVE_ROOT, f"raw_cindex_per_time_run{i:02d}.csv")
        pd.DataFrame(raw_rows_auc).to_csv(raw_auc_path, index=False)
        pd.DataFrame(raw_rows_cidx).to_csv(raw_cidx_path, index=False)
        print(f"✅ 저장 완료: run{i:02d}")

        # -----------------------------
        # Plots
        # -----------------------------
        plt.figure(figsize=(10, 5))
        for label in feature_sets:
            plt.errorbar(
                time_points, list(results_dict[label]['mean_auc_train'].values()),
                yerr=list(results_dict[label]['std_auc_train'].values()),
                fmt='--o', capsize=4, label=f"{label} - AUC Train"
            )
            plt.errorbar(
                time_points, list(results_dict[label]['mean_auc_val'].values()),
                yerr=list(results_dict[label]['std_auc_val'].values()),
                fmt='-o', capsize=4, label=f"{label} - AUC Val"
            )
        plt.title(f"AUC (Run {i:02d})")
        plt.xlabel("Time (Months)")
        plt.ylabel("AUC")
        plt.ylim(0.1, 1.0)
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(SAVE_ROOT, f"plot_auc_run{i:02d}.png"))
        plt.close()

        plt.figure(figsize=(10, 5))
        for label in feature_sets:
            plt.errorbar(
                time_points, list(results_dict[label]['mean_cidx_train'].values()),
                yerr=list(results_dict[label]['std_cidx_train'].values()),
                fmt='--s', capsize=4, label=f"{label} - C-index Train"
            )
            plt.errorbar(
                time_points, list(results_dict[label]['mean_cidx_val'].values()),
                yerr=list(results_dict[label]['std_cidx_val'].values()),
                fmt='-s', capsize=4, label=f"{label} - C-index Val"
            )
        plt.title(f"C-index (Run {i:02d})")
        plt.xlabel("Time (Months)")
        plt.ylabel("C-index")
        plt.ylim(0.1, 1.0)
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(SAVE_ROOT, f"plot_cindex_run{i:02d}.png"))
        plt.close()