# In [None]:  (Jupyter cell marker left over from the notebook export)
# Mixture Stretched Exponential Survival with Debugging and NaN Handling

import pandas as pd
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.metrics import roc_auc_score
from lifelines.utils import concordance_index
import matplotlib.pyplot as plt
import os

# Base seed used to fix global reproducibility
BASE_SEED = 20250903  # the training loop seeds each repetition with BASE_SEED + run
# Fix the random seeds
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

# Model definition
class MixtureStretchedExponentialSurvival(nn.Module):
    """Network emitting per-sample parameters of a stretched-exponential mixture.

    A two-layer ReLU MLP (hidden width 64) feeds three linear heads, each
    with ``num_components`` outputs: mixture weights, rates and shapes.

    Args:
        input_dim: Number of input features per sample.
        num_components: Number of mixture components.
    """

    def __init__(self, input_dim, num_components=2):
        super().__init__()
        hidden = 64
        # Layer creation order determines which random draws initialise
        # each layer under a given seed, so it is kept fixed.
        self.backbone = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        self.pi_layer = nn.Linear(hidden, num_components)
        self.lam_layer = nn.Linear(hidden, num_components)
        self.alpha_layer = nn.Linear(hidden, num_components)

    def forward(self, x):
        """Return ``(pi, lam, a)`` for a batch of inputs.

        ``pi`` is a softmax over dim 1 (rows sum to 1); ``lam`` and ``a``
        are softplus outputs shifted by 1e-3 so they stay positive.
        """
        features = self.backbone(x)
        weights = torch.softmax(self.pi_layer(features), dim=1)
        rates = F.softplus(self.lam_layer(features)) + 1e-3
        shapes = F.softplus(self.alpha_layer(features)) + 1e-3
        return weights, rates, shapes

# Loss definition
def mixture_stretched_nll(t, e, pi, lam, a, eps=1e-8):
    """Mean negative log-likelihood of a stretched-exponential mixture.

    Component ``k`` has survival ``S_k(t) = exp(-lam_k * t**a_k)`` and
    density ``f_k(t) = lam_k * a_k * t**(a_k - 1) * S_k(t)``; the mixture
    weights both by ``pi``. Samples with an observed event contribute
    ``log f(t)``, censored samples contribute ``log S(t)``.

    Args:
        t: Times; any shape holding N elements.
        e: Event indicators (1/True = event, 0/False = censored); any shape
            holding N elements, bool or numeric.
        pi: Mixture weights, expected shape (N, K).
        lam: Component rates, expected shape (N, K).
        a: Component shape exponents, expected shape (N, K).
        eps: Added to ``t`` before exponentiation and to the mixture
            density/survival before taking logs.

    Returns:
        Scalar tensor: the negated mean log-likelihood over the N samples.
    """
    # Column vector so times broadcast across the K components.
    shifted_t = t.reshape(-1, 1) + eps
    S_k = torch.exp(-lam * torch.pow(shifted_t, a))
    f_k = lam * a * torch.pow(shifted_t, a - 1) * S_k
    f = torch.sum(pi * f_k, dim=1) + eps
    S = torch.sum(pi * S_k, dim=1) + eps
    # Flatten and cast the indicators: an (N, 1) tensor would otherwise
    # broadcast against the (N,) likelihood terms into an (N, N) matrix,
    # and a bool tensor does not support `1 - e`.
    e = e.reshape(-1).to(f.dtype)
    loglik = e * torch.log(f) + (1 - e) * torch.log(S)
    return -loglik.mean()

# Survival curve prediction
@torch.no_grad()
def predict_survival(model, x, times):
    """Evaluate the mixture survival function at each requested time.

    Puts ``model`` in eval mode, runs it once on ``x`` and computes
    ``sum_k pi_k * exp(-lam_k * (t + 1e-8) ** a_k)`` for every ``t``.

    Args:
        model: Callable module returning ``(pi, lam, a)`` for ``x``.
        x: Input tensor; time tensors are created on ``x.device``.
        times: Iterable of scalar time points.

    Returns:
        NumPy array with one row per entry of ``times`` and one column per
        sample.
    """
    model.eval()
    weights, rates, shapes = model(x)

    def _mixture_survival(horizon):
        # Single-element tensor broadcasts against the per-sample parameters.
        horizon_t = torch.tensor([horizon], dtype=torch.float32, device=x.device)
        component_surv = torch.exp(-rates * torch.pow(horizon_t + 1e-8, shapes))
        return torch.sum(weights * component_surv, dim=1).cpu().numpy()

    return np.vstack([_mixture_survival(horizon) for horizon in times])

# AUC computation
def calc_auc(surv_arr, y_df, times):
    """Compute a ROC AUC per time horizon from predicted survival.

    For horizon ``t`` the positive class is "event observed with
    ``time <= t``" and the score is ``1 - S(t)``.

    Args:
        surv_arr: 2-D array indexed ``[time index, sample]`` holding
            predicted survival probabilities.
        y_df: Frame with ``"event"`` and ``"time"`` columns, row-aligned
            with the columns of ``surv_arr``.
        times: Horizons; ``times[i]`` pairs with ``surv_arr[i, :]``.

    Returns:
        Dict mapping each horizon to its AUC, or ``np.nan`` where
        ``roc_auc_score`` rejects the inputs with ``ValueError``.
    """
    # NOTE(review): samples censored before the horizon are labelled 0 here,
    # same as samples known to survive past it - confirm this is intended.
    aucs = {}
    for row, horizon in enumerate(times):
        observed = ((y_df["event"] == 1) & (y_df["time"] <= horizon)).astype(int)
        predicted = 1 - surv_arr[row, :]
        try:
            aucs[horizon] = roc_auc_score(observed, predicted)
        except ValueError:
            # Only the metric's own input rejection (presumably the
            # single-class / NaN-score cases) maps to NaN; a bare except
            # would also hide KeyboardInterrupt and programming errors.
            aucs[horizon] = np.nan
    return aucs

# Robust C-index computation
def safe_concordance_index(times, risks, events):
    """Compute the concordance index, returning NaN when it is undefined.

    Rows with a NaN in any of the three inputs are dropped first. NaN is
    returned (with a printed notice) when fewer than two rows remain, when
    the remaining scores are effectively constant (std < 1e-6), or when
    ``concordance_index`` reports no comparable pairs.

    Args:
        times: Observed times, array-like of length N.
        risks: Predicted scores, array-like of length N, passed to
            ``concordance_index`` as its second argument.
            NOTE(review): lifelines documents that argument as a score that
            should be higher for longer survival - confirm callers pass
            scores with that orientation.
        events: Event indicators, array-like of length N.

    Returns:
        The concordance index as a float, or ``np.nan``.
    """
    times = np.asarray(times)
    risks = np.asarray(risks)
    events = np.asarray(events)

    mask = ~(np.isnan(times) | np.isnan(risks) | np.isnan(events))
    n_valid = int(np.sum(mask))
    if n_valid < 2:
        print("‚ö†Ô∏è Too few valid samples for C-index:", n_valid)
        return np.nan
    if np.std(risks[mask]) < 1e-6:
        print("‚ö†Ô∏è Low risk variance, skipping C-index")
        return np.nan
    try:
        return concordance_index(times[mask], risks[mask], events[mask])
    except ZeroDivisionError:
        # lifelines raises this when there are no admissible pairs (for
        # example when every remaining sample is censored).
        print("Warning: no admissible pairs for C-index")
        return np.nan


# Ensure the models output directory exists (no error if it already does).
os.makedirs("./survival_model/mixture_non_fix/models", exist_ok=True)
# NOTE(review): `device` duplicates DEVICE below; the training loop uses DEVICE.
device = torch.device('cpu')

# Top-level group names; each is used to build input and output paths.
BASE_GROUPS = ["beit0"]
# Used both as the number of input CSV files per group (dh11_run01..)
# and as the number of seeded repetitions per feature set.
N_RUNS = 30
# Evaluation horizons; the result tables and plots label them as months.
time_points = [12, 24, 36, 48, 60, 72]
DEVICE = torch.device('cpu')

# Image feature columns (img_cols) defined per group
GROUP_IMG_COLS = {
    "n4_30_30": [
        "feat_10", "feat_15", "feat_25", "feat_121", "feat_123", "feat_125", "feat_143", "feat_152", "feat_163", "feat_167",
        "feat_169", "feat_181", "feat_194", "feat_203", "feat_210", "feat_220", "feat_240", "feat_255", "feat_289", "feat_309",
        "feat_328", "feat_352", "feat_361", "feat_368", "feat_378", "feat_389", "feat_402", "feat_407", "feat_420", "feat_439",
        "feat_451", "feat_468", "feat_498", "feat_507", "feat_514", "feat_560", "feat_565", "feat_576", "feat_578", "feat_605",
        "feat_617", "feat_633", "feat_653", "feat_656", "feat_666", "feat_710", "feat_747"
    ],
    "n5_30_30": [
        "feat_2", "feat_55", "feat_80", "feat_107", "feat_109", "feat_137", "feat_173", "feat_209", "feat_223", "feat_327",
        "feat_374", "feat_391", "feat_499", "feat_554", "feat_577", "feat_583", "feat_657", "feat_715"
    ],
    "n6_30_30": [
        "feat_213", "feat_266", "feat_215"
    ],
    "n7_30_30": [
        "feat_436", "feat_519"
    ],
}

# Clinical variables shared by all groups: continuous columns are
# standardised and categorical columns are one-hot encoded in the loop below.
cont_cols = ['Age']
cat_cols  = ['pathology', 'stage0']

# Experiment driver: for every base group, image-feature group and input CSV,
# train the mixture survival model N_RUNS times per feature set on seeded
# 70/30 splits, then save per-run AUC / C-index tables and summary plots.
for base_group in BASE_GROUPS:
    print("\n\n============================")
    print(f"üìÅ BEiT Í∑∏Î£π Ïã§Ìñâ Ï§ë: {base_group}")
    print("============================")

    # Iterate over the n4/n5/n6/n7 groups
    for GROUP, img_cols in GROUP_IMG_COLS.items():
        print(f"\n========== GROUP: {GROUP} ==========")

        SAVE_ROOT_BASE = f"./survival_model/mixture_non_fix/non_nest/{base_group}/results/generalization/test0/dl0/{GROUP}"
        os.makedirs(SAVE_ROOT_BASE, exist_ok=True)
        SAVE_ROOT = SAVE_ROOT_BASE

        for i in range(1, N_RUNS + 1):
            fname = f"dh11_run{i:02d}.csv"
            csv_path = f"./deephit/{base_group}/test/dl0/{GROUP}/{fname}"
            print(f"\nüöÄ Ïã§Ìñâ Ï§ë: {base_group} - {GROUP} - {fname}")

            if not os.path.exists(csv_path):
                print(f"‚ö†Ô∏è ÌååÏùº ÏóÜÏùå, Ïä§ÌÇµ: {csv_path}")
                continue

            # Load one input file and derive the survival targets.
            df_all = pd.read_csv(csv_path)
            df_all['event'] = df_all['survival'].astype(bool)
            df_all['time']  = df_all['fu_date'].astype(np.float32)

            # label -> (image columns, clinical columns)
            feature_sets = {
                'Image only': (img_cols, []),
                'Clinical only': ([], cont_cols + cat_cols),
                'Image + Clinical': (img_cols, cont_cols + cat_cols)
            }

            results_dict = {}
            raw_rows_auc, raw_rows_cidx = [], []

            for label, (img_part, clinical_part) in feature_sets.items():
                print(f"\nüìå Feature Set: {label}")
                auc_train_list, auc_val_list = [], []
                cidx_train_list, cidx_val_list = [], []

                # These do not depend on the seed, so build them once per
                # feature set instead of once per repetition.
                X_df = df_all[img_part + clinical_part].copy()
                y_df = df_all[['time', 'event']].copy()
                cont = [c for c in clinical_part if c in cont_cols]
                cat  = [c for c in clinical_part if c in cat_cols]

                for run in range(N_RUNS):
                    seed = BASE_SEED + run
                    set_seed(seed)

                    X_train_df, X_val_df, y_train, y_val = train_test_split(
                        X_df, y_df,
                        test_size=0.3,
                        random_state=seed  # uses BASE_SEED + run
                    )

                    # Standardise image/continuous columns, one-hot encode
                    # categorical ones; fitted on the training split only.
                    transformers = []
                    if img_part:
                        transformers.append(('img', StandardScaler(), img_part))
                    if cont:
                        transformers.append(('cont', StandardScaler(), cont))
                    if cat:
                        transformers.append(('cat', OneHotEncoder(sparse_output=False,
                                                                  handle_unknown='ignore'),
                                             cat))

                    ct = ColumnTransformer(transformers)
                    X_train = ct.fit_transform(X_train_df)
                    X_val   = ct.transform(X_val_df)

                    # Replace remaining NaNs with 0 before building tensors.
                    X_train = pd.DataFrame(X_train).fillna(0).values.astype(np.float32)
                    X_val   = pd.DataFrame(X_val).fillna(0).values.astype(np.float32)

                    X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(DEVICE)
                    X_val_tensor   = torch.tensor(X_val,   dtype=torch.float32).to(DEVICE)
                    t_train = torch.tensor(y_train['time'].values,  dtype=torch.float32).to(DEVICE)
                    e_train = torch.tensor(y_train['event'].values, dtype=torch.float32).to(DEVICE)

                    model = MixtureStretchedExponentialSurvival(
                        input_dim=X_train.shape[1],
                        num_components=2
                    ).to(DEVICE)
                    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

                    # NOTE(review): despite its name, best_val_loss tracks the
                    # full-batch *training* loss; no validation loss is computed.
                    best_val_loss = float('inf')
                    patience, patience_counter = 10, 0
                    best_model_state = None

                    for epoch in range(1000):
                        model.train()
                        optimizer.zero_grad()
                        pi, lam, a = model(X_train_tensor)
                        loss = mixture_stretched_nll(t_train, e_train, pi, lam, a)
                        current_loss = loss.item()

                        if current_loss < best_val_loss - 1e-6:
                            best_val_loss = current_loss
                            # state_dict() returns references to the live
                            # parameters, so clone them; otherwise later
                            # optimizer steps overwrite the snapshot. Taken
                            # before the step so the stored weights are the
                            # ones that produced current_loss.
                            best_model_state = {
                                name: tensor.detach().clone()
                                for name, tensor in model.state_dict().items()
                            }
                            patience_counter = 0
                        else:
                            patience_counter += 1
                            if patience_counter >= patience:
                                break

                        loss.backward()
                        optimizer.step()

                    if best_model_state is not None:
                        model.load_state_dict(best_model_state)

                    surv_train = predict_survival(model, X_train_tensor, time_points)
                    surv_val   = predict_survival(model, X_val_tensor,   time_points)

                    auc_train = calc_auc(surv_train, y_train.reset_index(drop=True), time_points)
                    auc_val   = calc_auc(surv_val,   y_val.reset_index(drop=True),   time_points)
                    auc_train_list.append(auc_train)
                    auc_val_list.append(auc_val)

                    # Transposed to [sample, time index]. These hold survival
                    # probabilities (higher = longer predicted survival), not
                    # hazards, despite the variable names.
                    risk_train = surv_train.T
                    risk_val   = surv_val.T
                    cidx_train = [safe_concordance_index(y_train['time'], risk_train[:, j], y_train['event'])
                                  for j in range(len(time_points))]
                    cidx_val   = [safe_concordance_index(y_val['time'],   risk_val[:,   j], y_val['event'])
                                  for j in range(len(time_points))]
                    cidx_train_list.append(cidx_train)
                    cidx_val_list.append(cidx_val)

                    # One row per horizon for this repetition ...
                    for j, t in enumerate(time_points):
                        raw_rows_auc.append({
                            "Feature Set": label, "Run": run,
                            "Time (Months)": t,
                            "AUC (Train)": auc_train[t],
                            "AUC (Val)":   auc_val[t],
                            "Scope": "Time-wise"
                        })
                        raw_rows_cidx.append({
                            "Feature Set": label, "Run": run,
                            "Time (Months)": t,
                            "C-index (Train)": cidx_train[j],
                            "C-index (Val)":   cidx_val[j],
                            "Scope": "Time-wise"
                        })

                    # ... plus one NaN-ignoring average across horizons.
                    raw_rows_auc.append({
                        "Feature Set": label, "Run": run,
                        "Time (Months)": "Overall",
                        "AUC (Train)": np.nanmean(list(auc_train.values())),
                        "AUC (Val)":   np.nanmean(list(auc_val.values())),
                        "Scope": "Overall"
                    })
                    raw_rows_cidx.append({
                        "Feature Set": label, "Run": run,
                        "Time (Months)": "Overall",
                        "C-index (Train)": np.nanmean(cidx_train),
                        "C-index (Val)":   np.nanmean(cidx_val),
                        "Scope": "Overall"
                    })

                # Mean / std across repetitions, per horizon, ignoring NaNs.
                results_dict[label] = {
                    'mean_auc_train': {t: np.nanmean([r[t] for r in auc_train_list]) for t in time_points},
                    'mean_auc_val':   {t: np.nanmean([r[t] for r in auc_val_list])   for t in time_points},
                    'std_auc_train':  {t: np.nanstd([r[t] for r in auc_train_list])  for t in time_points},
                    'std_auc_val':    {t: np.nanstd([r[t] for r in auc_val_list])    for t in time_points},
                    'mean_cidx_train':{t: np.nanmean([r[j] for r in cidx_train_list]) for j, t in enumerate(time_points)},
                    'mean_cidx_val':  {t: np.nanmean([r[j] for r in cidx_val_list])   for j, t in enumerate(time_points)},
                    'std_cidx_train': {t: np.nanstd([r[j] for r in cidx_train_list])  for j, t in enumerate(time_points)},
                    'std_cidx_val':   {t: np.nanstd([r[j] for r in cidx_val_list])    for j, t in enumerate(time_points)}
                }

            # Save results (one pair of files per GROUP + input file i)
            raw_auc_path  = os.path.join(SAVE_ROOT, f"raw_auc_per_time_run{i:02d}.csv")
            raw_cidx_path = os.path.join(SAVE_ROOT, f"raw_cindex_per_time_run{i:02d}.csv")
            pd.DataFrame(raw_rows_auc).to_csv(raw_auc_path, index=False)
            pd.DataFrame(raw_rows_cidx).to_csv(raw_cidx_path, index=False)
            print(f"‚úÖ Ï†ÄÏû• ÏôÑÎ£å: GROUP={GROUP}, run{i:02d}")

            # Save visualisation (AUC)
            plt.figure(figsize=(10, 5))
            for label in feature_sets:
                plt.errorbar(time_points, list(results_dict[label]['mean_auc_train'].values()),
                             yerr=list(results_dict[label]['std_auc_train'].values()),
                             fmt='--o', capsize=4, label=f"{label} - AUC Train")
                plt.errorbar(time_points, list(results_dict[label]['mean_auc_val'].values()),
                             yerr=list(results_dict[label]['std_auc_val'].values()),
                             fmt='-o', capsize=4, label=f"{label} - AUC Val")
            plt.title(f"AUC (GROUP {GROUP}, File run{i:02d})")
            plt.xlabel("Time (Months)")
            plt.ylabel("AUC")
            plt.ylim(0.1, 1.0)
            plt.grid(True)
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(SAVE_ROOT, f"plot_auc_run{i:02d}.png"))
            plt.close()

            # Save visualisation (C-index)
            plt.figure(figsize=(10, 5))
            for label in feature_sets:
                plt.errorbar(time_points, list(results_dict[label]['mean_cidx_train'].values()),
                             yerr=list(results_dict[label]['std_cidx_train'].values()),
                             fmt='--s', capsize=4, label=f"{label} - C-index Train")
                plt.errorbar(time_points, list(results_dict[label]['mean_cidx_val'].values()),
                             yerr=list(results_dict[label]['std_cidx_val'].values()),
                             fmt='-s', capsize=4, label=f"{label} - C-index Val")
            plt.title(f"C-index (GROUP {GROUP}, File run{i:02d})")
            plt.xlabel("Time (Months)")
            plt.ylabel("C-index")
            plt.ylim(0.1, 1.0)
            plt.grid(True)
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(SAVE_ROOT, f"plot_cindex_run{i:02d}.png"))
            plt.close()
