<a href="https://colab.research.google.com/github/osun24/nsclc-adj-chemo/blob/main/TorchSurv_DeepSurv_with_Optuna.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Install necessary packages
!pip install torchsurv scikit-survival
!pip install optuna optuna-dashboard scikit-survival portpicker

# Import required packages
import os
import time
import datetime
import itertools
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sksurv.metrics import concordance_index_censored

# (Optional) Mount Google Drive if you plan to load/save files there
from google.colab import drive
drive.mount('/content/drive')

from optuna_dashboard import run_server
import threading, time, portpicker
from google.colab import output

PORT = portpicker.pick_unused_port()

def _serve():
    run_server("sqlite:///deepsurv_optuna.db", host="0.0.0.0", port=PORT)

threading.Thread(target=_serve, daemon=True).start()
time.sleep(2)
print("Dashboard:", output.eval_js(f"google.colab.kernel.proxyPort({PORT}, {{'cache': false}})"))


Mounted at /content/drive


Exception in thread Thread-5 (_serve):
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/sqlalchemy/engine/base.py", line 1967, in _exec_single_context
    self.dialect.do_execute(
  File "/usr/local/lib/python3.12/dist-packages/sqlalchemy/engine/default.py", line 951, in do_execute
    cursor.execute(statement, parameters)
sqlite3.OperationalError: no such table: version_info

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/optuna/storages/_rdb/storage.py", line 77, in _create_scoped_session
    yield session
  File "/usr/local/lib/python3.12/dist-packages/optuna/storages/_rdb/storage.py", line 1046, in _init_version_info_model
    version_info = models.VersionInfoModel.find(session)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/optuna/storages/_rdb/models.py", line 591, in find
    version_

Dashboard: https://40159-gpu-a100-s-ux6s319d4nv1-f.us-central1-1.prod.colab.dev


In [3]:
# ============================================================
# Colab-ready single cell: DeepSurv + Optuna HPO + Dashboard
# - Keeps ONLY requested clinical vars + genes with Prop==1
# - Sorts Train/Val by OS_MONTHS/OS_STATUS (desc)
# - Standardizes using TRAIN-only (applies to VAL); after HPO
#   restandardizes on TRAIN+VAL and evaluates TEST C-index
# - Optuna + Successive Halving pruner
# - Optuna Dashboard (proxied URL printed)  [fixed run_server args]
# - Encodes architectures as strings to avoid Optuna warning
# ============================================================

# ---------- (Optional) Mount Google Drive ----------
from google.colab import drive
drive.mount('/content/drive')

# ---------- Imports ----------
import os, math, copy, warnings, random, gc, time, threading
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Sampler

# Cox loss: Prefer torchsurv if available; fallback to Efron implementation
try:
    from torchsurv.loss.cox import neg_partial_log_likelihood
    _HAS_TORCHSURV = True
except Exception:
    _HAS_TORCHSURV = False

from sklearn.preprocessing import StandardScaler
from sksurv.metrics import concordance_index_censored

import optuna
from optuna.pruners import SuccessiveHalvingPruner

warnings.filterwarnings("ignore", message="Ties in event time detected; using efron's method to handle ties.")

# Reproducibility
torch.manual_seed(0); np.random.seed(0); random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ============================================================
# Cox loss fallback (Efron) if torchsurv isn't available
# ============================================================
def _cox_negloglik_efron(pred, event, time):
    eta = pred.reshape(-1)
    e = event.to(torch.float32).reshape(-1)
    t = time.reshape(-1)

    order = torch.argsort(t, descending=True)
    t = t[order]; e = e[order]; eta = eta[order]
    exp_eta = torch.exp(eta)
    cum_exp = torch.cumsum(exp_eta, dim=0)

    uniq_mask = torch.ones_like(t, dtype=torch.bool)
    uniq_mask[1:] = t[1:] != t[:-1]
    idxs = torch.nonzero(uniq_mask, as_tuple=False).reshape(-1)
    idxs = torch.cat([idxs, torch.tensor([len(t)], device=t.device)])

    nll = torch.tensor(0.0, device=t.device)
    for k in range(len(idxs)-1):
        start, end = idxs[k].item(), idxs[k+1].item()
        e_slice = e[start:end]
        d = int(e_slice.sum().item())
        if d == 0: continue
        eta_events = eta[start:end][e_slice.bool()]
        exp_events = exp_eta[start:end][e_slice.bool()]
        s_eta = eta_events.sum()
        risk_sum = cum_exp[end-1]
        s_exp = exp_events.sum()
        eps = 1e-12
        log_terms = 0.0
        for j in range(d):
            log_terms = log_terms + torch.log(risk_sum - (j / d) * s_exp + eps)
        nll = nll - (s_eta - log_terms)
    return nll / t.numel()

def cox_negloglik(pred, event, time):
    if _HAS_TORCHSURV:
        return neg_partial_log_likelihood(pred, event, time, reduction='mean')
    return _cox_negloglik_efron(pred, event, time)

# ============================================================
# Model, Dataset, Sampler, Utilities
# ============================================================
class DeepSurvMLP(nn.Module):
    def __init__(self, in_features, hidden_layers, dropout=0.0, activation=nn.ReLU()):
        super().__init__()
        layers, d = [], in_features
        for units in hidden_layers:
            layers += [nn.Linear(d, units), activation]
            if dropout > 0: layers.append(nn.Dropout(dropout))
            d = units
        layers.append(nn.Linear(d, 1))
        self.model = nn.Sequential(*layers)
    def forward(self, x): return self.model(x)

class SurvivalDataset(Dataset):
    def __init__(self, features, time_vals, events):
        self.x = torch.tensor(features, dtype=torch.float32)
        self.time = torch.tensor(time_vals, dtype=torch.float32)
        self.event = torch.tensor(events.astype(bool), dtype=torch.bool)
    def __len__(self): return len(self.x)
    def __getitem__(self, idx): return self.x[idx], self.time[idx], self.event[idx]

class EventBalancedBatchSampler(Sampler):
    def __init__(self, events_numpy, batch_size, seed=0):
        events = np.asarray(events_numpy).astype(bool)
        self.pos_idx = np.where(events)[0]
        self.neg_idx = np.where(~events)[0]
        assert len(self.pos_idx) > 0, "No events in training set — cannot balance batches."
        self.bs = int(batch_size)
        self.rng = np.random.default_rng(seed)
    def __iter__(self):
        pos = self.rng.permutation(self.pos_idx)
        neg = self.rng.permutation(self.neg_idx)
        n_total = len(pos) + len(neg)
        n_batches = math.ceil(n_total / self.bs)
        pi = ni = 0
        for _ in range(n_batches):
            take_pos = 1 if pi < len(pos) else 0
            avail_neg = max(0, len(neg) - ni)
            take_neg = min(self.bs - take_pos, avail_neg)
            need = self.bs - (take_pos + take_neg)
            extra_pos = min(need, max(0, len(pos) - (pi + take_pos)))
            take_pos += extra_pos
            batch = np.concatenate([pos[pi:pi+take_pos], neg[ni:ni+take_neg]])
            pi += take_pos; ni += take_neg
            if batch.size == 0: break
            self.rng.shuffle(batch)
            yield batch.tolist()
    def __len__(self):
        return math.ceil((len(self.pos_idx) + len(self.neg_idx)) / self.bs)

def make_optimizer(model, lr, wd):
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
    last_linear = linears[-1] if len(linears) > 0 else None
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad: continue
        if name.endswith('bias'):
            no_decay.append(p); continue
        if (last_linear is not None) and (p is last_linear.weight):
            no_decay.append(p); continue
        decay.append(p)
    param_groups = [{'params': decay, 'weight_decay': wd},
                    {'params': no_decay, 'weight_decay': 0.0}]
    return optim.AdamW(param_groups, lr=lr)

def set_dropout_p(model, p):
    for m in model.modules():
        if isinstance(m, nn.Dropout): m.p = float(p)

def set_weight_decay(optimizer, wd):
    for g in optimizer.param_groups: g['weight_decay'] = float(wd)

def l1_penalty_first_layer(model):
    for m in model.modules():
        if isinstance(m, nn.Linear): return m.weight.abs().sum()
    return torch.tensor(0.0, device=next(model.parameters()).device)

@torch.no_grad()
def evaluate_ci(model, dataloader, device):
    model.eval()
    preds, times, events = [], [], []
    for x, t, e in dataloader:
        y = torch.clamp(model(x.to(device)), -20, 20)
        preds.append(y.cpu().numpy().ravel())
        times.append(t.numpy()); events.append(e.numpy())
    preds = np.concatenate(preds)
    times = np.concatenate(times); events = np.concatenate(events)
    if np.isnan(preds).any(): return -np.inf
    return concordance_index_censored(events.astype(bool), times, preds)[0]

def train_one_epoch(model, optimizer, dataloader, device, l1_lambda=0.0, epoch=0, warmup_epochs=20):
    model.train()
    warm = min(1.0, (epoch + 1) / float(warmup_epochs))
    loss_sum, n_seen = 0.0, 0
    for x, t, e in dataloader:
        if e.sum().item() == 0:  # safety (shouldn't happen with balanced sampler)
            continue
        x, t, e = x.to(device), t.to(device), e.to(device)
        optimizer.zero_grad(set_to_none=True)
        out = torch.clamp(model(x), -20, 20)
        loss = cox_negloglik(out, e, t)
        if l1_lambda > 0:
            loss = loss + (l1_lambda * warm) * l1_penalty_first_layer(model)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        loss_sum += loss.item() * x.size(0)
        n_seen += x.size(0)
    return {'avg_loss': loss_sum / max(n_seen, 1), 'warm': warm}

def full_risk_set_step(model, optimizer, ds, device, l1_lambda=0.0, warm=1.0):
    model.train()
    X_all = ds.x.to(device); t_all = ds.time.to(device); e_all = ds.event.to(device)
    optimizer.zero_grad(set_to_none=True)
    out_all = torch.clamp(model(X_all), -20, 20)
    loss_full = cox_negloglik(out_all, e_all, t_all)
    if l1_lambda > 0:
        loss_full = loss_full + (l1_lambda * warm) * l1_penalty_first_layer(model)
    loss_full.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
    optimizer.step()
    return float(loss_full.detach().cpu().item())

# ============================================================
# Data loading & preprocessing
# ============================================================
# Default CSV paths (adjust if needed)
TRAIN_CSV = "/content/drive/MyDrive/affyfRMATrain.csv"
VALID_CSV = "/content/drive/MyDrive/affyfRMAValidation.csv"
TEST_CSV  = "/content/drive/MyDrive/affyfRMATest.csv"

# Genes list path (uploaded or Drive)
GENES_CSV = "/mnt/data/Genes.csv"
if not os.path.exists(GENES_CSV):
    if os.path.exists("/content/Genes.csv"):
        GENES_CSV = "/content/Genes.csv"
    elif os.path.exists("/content/drive/MyDrive/Genes.csv"):
        GENES_CSV = "/content/drive/MyDrive/Genes.csv"
print("Genes.csv path:", GENES_CSV)

# Clinical variables to KEEP (exact names)
CLINICAL_VARS = [
    "Adjuvant Chemo","Age","IS_MALE",
    "Stage_IA","Stage_IB","Stage_II","Stage_III",
    "Histology_Adenocarcinoma","Histology_Large Cell Carcinoma","Histology_Squamous Cell Carcinoma",
    "Race_African American","Race_Asian","Race_Caucasian","Race_Native Hawaiian or Other Pacific Islander","Race_Unknown",
    "Smoked?_No","Smoked?_Unknown","Smoked?_Yes"
]

def load_genes_list(genes_csv):
    g = pd.read_csv(genes_csv)
    if not {"Gene","Prop"}.issubset(set(g.columns)):
        raise ValueError(f"Genes.csv must contain 'Gene' and 'Prop' columns. Found: {list(g.columns)}")
    genes = g.loc[g["Prop"] == 1, "Gene"].astype(str).tolist()
    print(f"[Genes] Selected {len(genes)} genes with Prop == 1")
    return genes

def coerce_survival_cols(df):
    # Map to integers {0,1}
    if df["OS_STATUS"].dtype == object:
        df["OS_STATUS"] = df["OS_STATUS"].replace({"DECEASED":1,"LIVING":0,"Dead":1,"Alive":0}).astype(int)
    else:
        df["OS_STATUS"] = pd.to_numeric(df["OS_STATUS"], errors="coerce").fillna(0).astype(int)
    df["OS_MONTHS"] = pd.to_numeric(df["OS_MONTHS"], errors="coerce").fillna(0.0).astype(float)
    return df

def preprocess_split(df, clinical_vars, gene_names):
    if "Adjuvant Chemo" in df.columns:
        df["Adjuvant Chemo"] = df["Adjuvant Chemo"].replace({"OBS":0, "ACT":1})
    for col in ["Adjuvant Chemo","IS_MALE"]:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0).astype(int)
    df = coerce_survival_cols(df)
    keep_cols = [c for c in clinical_vars if c in df.columns] + [g for g in gene_names if g in df.columns]
    missing_clin = [c for c in clinical_vars if c not in df.columns]
    if missing_clin:
        print(f"[WARN] Missing clinical columns: {missing_clin}")
    if len(keep_cols) == 0:
        raise ValueError("No feature columns found after filtering clinical+genes.")
    cols = ["OS_STATUS","OS_MONTHS"] + keep_cols
    return df[cols].copy()

# Load CSVs
train_raw = pd.read_csv(TRAIN_CSV)
valid_raw = pd.read_csv(VALID_CSV)
test_raw  = pd.read_csv(TEST_CSV)

# Load genes (Prop==1)
GENE_LIST = load_genes_list(GENES_CSV)

# Reduce to requested columns per split
train_df = preprocess_split(train_raw, CLINICAL_VARS, GENE_LIST)
valid_df = preprocess_split(valid_raw, CLINICAL_VARS, GENE_LIST)
test_df  = preprocess_split(test_raw,  CLINICAL_VARS, GENE_LIST)

# Ensure consistent columns across splits (intersection)
feat_candidates = [c for c in (CLINICAL_VARS + GENE_LIST)
                   if c in train_df.columns and c in valid_df.columns and c in test_df.columns]
if len(feat_candidates) == 0:
    raise ValueError("After filtering, no common features across train/val/test.")
print(f"[Features] Using {len(feat_candidates)} common features.")

# Sort Train/Val by event time & status (descending)
train_df = train_df.sort_values(by=["OS_MONTHS","OS_STATUS"], ascending=[False, False]).reset_index(drop=True)
valid_df = valid_df.sort_values(by=["OS_MONTHS","OS_STATUS"], ascending=[False, False]).reset_index(drop=True)

# Build arrays & TRAIN-only standardization (apply to VAL)
X_train = train_df[feat_candidates].values.astype(np.float32)
X_valid = valid_df[feat_candidates].values.astype(np.float32)

train_medians = np.nanmedian(X_train, axis=0)
X_train = np.where(np.isnan(X_train), train_medians, X_train)
X_valid = np.where(np.isnan(X_valid), train_medians, X_valid)

scaler_tv = StandardScaler().fit(X_train)
X_train = scaler_tv.transform(X_train).astype(np.float32)
X_valid = scaler_tv.transform(X_valid).astype(np.float32)

ytr_time = train_df["OS_MONTHS"].values.astype(np.float32)
ytr_event = train_df["OS_STATUS"].values.astype(int)
yva_time = valid_df["OS_MONTHS"].values.astype(np.float32)
yva_event = valid_df["OS_STATUS"].values.astype(int)

BATCH_SIZE = 64
train_ds = SurvivalDataset(X_train, ytr_time, ytr_event)
valid_ds = SurvivalDataset(X_valid, yva_time, yva_event)

train_sampler = EventBalancedBatchSampler(ytr_event, BATCH_SIZE, seed=42)
train_loader  = DataLoader(train_ds, batch_sampler=train_sampler, num_workers=0)
train_eval_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
valid_loader      = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

in_dim = X_train.shape[1]
print("Input dim:", in_dim)

# ============================================================
# Optuna Objective & Study  (arch encoded as strings -> parsed)
# ============================================================# ============================================================
# Optuna Objective & Study — Expanded Space to Reduce Overfitting
#   * Architectures include smaller, narrower and bottlenecks
#   * Stronger regularizers: higher dropout, input dropout, L1/L2
#   * Optionally apply WD to final layer
#   * Batch size, grad clip, scheduler, epochs per trial
#   * Input Gaussian noise
# ============================================================

# Wider set of architectures (strings -> parsed lists)
ARCH_CHOICES = (
    # very small
    "16", "32",
    # small / medium singles
    "64", "128", "256",
    # conservative big single
    "512",
    # 2-layer bottlenecks and symmetric
    "32-16", "32-32", "64-32", "64-64",
    "128-64", "128-128", "256-128", "256-256",
    "512-256", "512-512",
    # 3-layer (narrowing)
    "64-32-16", "128-64-32", "256-128-64"
)

def layers_from_arch(arch_str: str):
    return [int(x) for x in arch_str.split("-") if x.strip()]

def suggest_hparams(trial):
    # ----- STATIC spaces only -----
    # Main-gene budget: choose from a fixed set (truncate to MAX_GENES at runtime)
    TOPK_MAIN_CHOICES = tuple([k for k in (32, 64, 128, 256, 512, 800, MAX_GENES) if k <= MAX_GENES])
    top_k_genes = trial.suggest_categorical("top_k_genes", TOPK_MAIN_CHOICES)

    # Interaction-gene budget: FIXED choices; we’ll clamp to <= top_k_genes after sampling
    TOPK_INTER_CHOICES = tuple([k for k in (0, 16, 32, 64, 128, 256, 512) if k <= MAX_GENES])
    top_k_inter_raw = trial.suggest_categorical("top_k_inter", TOPK_INTER_CHOICES)

    # net & regularizers
    arch = trial.suggest_categorical("arch", ARCH_CHOICES)
    dropout = trial.suggest_float("dropout", 0.10, 0.70)
    input_dropout = trial.suggest_float("input_dropout", 0.00, 0.30)
    noise_std = trial.suggest_float("noise_std", 0.0, 0.10)

    wd = trial.suggest_float("wd", 1e-6, 1e-1, log=True)
    apply_final_wd = trial.suggest_categorical("apply_final_wd", (0, 1))
    use_l1 = trial.suggest_categorical("use_l1", (0, 1))
    l1 = 0.0 if use_l1 == 0 else trial.suggest_float("l1", 1e-8, 3e-3, log=True)

    lr = trial.suggest_float("lr", 1e-5, 5e-4, log=True)
    sched = trial.suggest_categorical("sched", ("cosine", "cawr", "none"))
    if sched == "cawr":
        cawr_T0 = trial.suggest_int("cawr_T0", 16, 80, step=8)
        cawr_Tmult = trial.suggest_categorical("cawr_Tmult", (1, 2, 3))
    else:
        cawr_T0 = cawr_Tmult = None

    epochs = trial.suggest_int("epochs", 64, 384, step=32)
    batch_size = trial.suggest_categorical("batch_size", (32, 64, 128))
    grad_clip = trial.suggest_float("grad_clip", 1.0, 10.0)

    # Clamp interactions to main-gene budget (no dynamic space in Optuna, just post-hoc enforcement)
    top_k_inter = int(min(top_k_inter_raw, top_k_genes))

    return dict(
        arch=arch, top_k_genes=int(top_k_genes), top_k_inter=int(top_k_inter),
        dropout=dropout, input_dropout=input_dropout, noise_std=noise_std,
        wd=wd, apply_final_wd=int(apply_final_wd), use_l1=int(use_l1), l1=float(l1),
        lr=lr, sched=sched, cawr_T0=cawr_T0, cawr_Tmult=cawr_Tmult,
        epochs=int(epochs), batch_size=int(batch_size), grad_clip=float(grad_clip)
    )


# Warmups (as before)
MAX_EPOCHS_CAP = 512  # absolute cap (safety)
WARMUP_EPOCHS_L1 = 30
WARMUP_EPOCHS_DROPOUT = 30
WARMUP_EPOCHS_WD = 30
DROPOUT_START = 0.15
WD_START = 0.0

# Local helpers that add input dropout/noise and variable grad clip
def _apply_input_dropout(x, p):
    if p <= 0.0: return x
    # inverted dropout on features
    keep = 1.0 - p
    mask = torch.bernoulli(torch.full_like(x, keep))
    return x * mask / max(keep, 1e-6)

def train_one_epoch_reg(model, optimizer, dataloader, device,
                        l1_lambda=0.0, epoch=0, warmup_epochs=20,
                        input_dropout=0.0, noise_std=0.0, grad_clip=5.0):
    model.train()
    warm = min(1.0, (epoch + 1) / float(warmup_epochs))
    loss_sum, n_seen = 0.0, 0
    for x, t, e in dataloader:
        if e.sum().item() == 0:
            continue
        x, t, e = x.to(device), t.to(device), e.to(device)

        # data-level regularization
        if input_dropout > 0.0:
            x = _apply_input_dropout(x, input_dropout)
        if noise_std > 0.0:
            x = x + noise_std * torch.randn_like(x)

        optimizer.zero_grad(set_to_none=True)
        out = torch.clamp(model(x), -20, 20)
        loss = cox_negloglik(out, e, t)
        if l1_lambda > 0:
            loss = loss + (l1_lambda * warm) * l1_penalty_first_layer(model)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), float(grad_clip))
        optimizer.step()
        loss_sum += loss.item() * x.size(0)
        n_seen += x.size(0)
    return {'avg_loss': loss_sum / max(n_seen, 1), 'warm': warm}

def full_risk_set_step_reg(model, optimizer, ds, device,
                           l1_lambda=0.0, warm=1.0,
                           input_dropout=0.0, noise_std=0.0, grad_clip=5.0):
    model.train()
    X_all = ds.x.to(device); t_all = ds.time.to(device); e_all = ds.event.to(device)
    XX = X_all
    if input_dropout > 0.0:
        XX = _apply_input_dropout(XX, input_dropout)
    if noise_std > 0.0:
        XX = XX + noise_std * torch.randn_like(XX)

    optimizer.zero_grad(set_to_none=True)
    out_all = torch.clamp(model(XX), -20, 20)
    loss_full = cox_negloglik(out_all, e_all, t_all)
    if l1_lambda > 0:
        loss_full = loss_full + (l1_lambda * warm) * l1_penalty_first_layer(model)
    loss_full.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), float(grad_clip))
    optimizer.step()
    return float(loss_full.detach().cpu().item())

# Optimizer that can optionally apply WD to the final layer
def make_optimizer_hpo(model, lr, wd, apply_final_wd=False):
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
    last_linear = linears[-1] if len(linears) > 0 else None
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if name.endswith('bias'):
            no_decay.append(p); continue
        if (last_linear is not None) and (p is last_linear.weight) and not apply_final_wd:
            no_decay.append(p); continue
        decay.append(p)
    param_groups = [{'params': decay, 'weight_decay': float(wd)},
                    {'params': no_decay, 'weight_decay': 0.0}]
    return optim.AdamW(param_groups, lr=float(lr))

def objective(trial):
    hp = suggest_hparams(trial)
    layers = layers_from_arch(hp["arch"])

    # Build loaders with TRIAL-SPECIFIC batch size (smaller batches often regularize more)
    bs = int(hp["batch_size"])
    tr_sampler = EventBalancedBatchSampler(ytr_event, bs, seed=42)
    tr_loader = DataLoader(train_ds, batch_sampler=tr_sampler, num_workers=0)
    tr_eval_loader = DataLoader(train_ds, batch_size=bs, shuffle=False, num_workers=0)
    va_loader = DataLoader(valid_ds, batch_size=bs, shuffle=False, num_workers=0)

    # Model / optimizer
    model = DeepSurvMLP(in_dim, layers, dropout=hp["dropout"]).to(device)
    optimizer = make_optimizer_hpo(model, lr=hp["lr"], wd=hp["wd"],
                                   apply_final_wd=bool(hp["apply_final_wd"]))

    # Scheduler per trial
    epochs = int(min(hp["epochs"], MAX_EPOCHS_CAP))
    if hp["sched"] == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
        def step_sched(epoch_idx): scheduler.step()
    elif hp["sched"] == "cawr":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=int(hp["cawr_T0"]), T_mult=int(hp["cawr_Tmult"])
        )
        def step_sched(epoch_idx): scheduler.step(epoch_idx + 1)
    else:
        scheduler = None
        def step_sched(epoch_idx): pass

    best_val_ci = -np.inf
    best_epoch = 0

    for epoch in range(epochs):
        # Warm-up schedules for dropout & WD
        frac_d = min(1.0, epoch / float(WARMUP_EPOCHS_DROPOUT))
        frac_w = min(1.0, epoch / float(WARMUP_EPOCHS_WD))
        set_dropout_p(model, DROPOUT_START + (hp['dropout'] - DROPOUT_START) * frac_d)
        set_weight_decay(optimizer, WD_START + (hp['wd'] - WD_START) * frac_w)

        # One epoch + full risk-set step with extra regularizers
        stats = train_one_epoch_reg(
            model, optimizer, tr_loader, device,
            l1_lambda=hp["l1"], epoch=epoch, warmup_epochs=WARMUP_EPOCHS_L1,
            input_dropout=hp["input_dropout"], noise_std=hp["noise_std"],
            grad_clip=hp["grad_clip"]
        )
        _ = full_risk_set_step_reg(
            model, optimizer, train_ds, device,
            l1_lambda=hp["l1"], warm=stats['warm'],
            input_dropout=hp["input_dropout"], noise_std=hp["noise_std"],
            grad_clip=hp["grad_clip"]
        )

        # Evaluate
        val_ci = evaluate_ci(model, va_loader, device)
        step_sched(epoch)

        # Report for pruning
        trial.report(val_ci, step=epoch)
        if val_ci > best_val_ci:
            best_val_ci = val_ci
            best_epoch = epoch + 1

        if trial.should_prune():
            # Clean up GPU mem before pruning
            del model, optimizer, scheduler
            if torch.cuda.is_available(): torch.cuda.empty_cache()
            gc.collect()
            raise optuna.TrialPruned()

    trial.set_user_attr("best_epoch", int(best_epoch))

    # Clean up
    del model, optimizer, scheduler
    if torch.cuda.is_available(): torch.cuda.empty_cache()
    gc.collect()
    return best_val_ci

# ---- Study / Sampler / Pruner (tune TPE a bit for broader exploration) ----
storage = "sqlite:///deepsurv_optuna.db"
study_name = "deepsurv_cox_hpo_overfit_reducer"

sampler = optuna.samplers.TPESampler(
    seed=42,
    multivariate=True,
    group=True,
    n_startup_trials=40,       # explore more before exploitation
    constant_liar=True,        # better parallel behavior if you run parallel
    consider_prior=True
)

# Hyperband/SH both work; Hyperband gives a bit more flexibility across resource levels
pruner = optuna.pruners.HyperbandPruner(min_resource=16, reduction_factor=3)

study = optuna.create_study(
    direction="maximize",       # single objective (Val CI)
    study_name=study_name,
    storage=storage,
    load_if_exists=True,
    sampler=sampler,
    pruner=pruner
)


# ============================================================
# Launch Optuna Dashboard (Colab proxied URL printed)
#   (No unsupported args; previous TypeError fixed)
# ============================================================
try:
    from optuna_dashboard import run_server
    from google.colab import output
    import portpicker
    PORT = portpicker.pick_unused_port()
    def _start_dashboard():
        # NOTE: do NOT pass unsupported kwargs like reload/quiet
        run_server(storage, host="0.0.0.0", port=PORT)
    t = threading.Thread(target=_start_dashboard, daemon=True)
    t.start()
    time.sleep(2)
    dash_url = output.eval_js(f"google.colab.kernel.proxyPort({PORT}, {{'cache': false}})")
    print("Optuna Dashboard:", dash_url)
except Exception as ex:
    print("[Optuna Dashboard] Could not start dashboard automatically.", ex)
    print("You can run it locally with:  optuna-dashboard sqlite:///deepsurv_optuna.db")

# ============================================================
# Run Optimization
# ============================================================
N_TRIALS = 100  # adjust as needed
print(f"Starting optimization: {N_TRIALS} trials × up to {MAX_EPOCHS_CAP} epochs")
study.optimize(objective, n_trials=N_TRIALS, gc_after_trial=True)

print("\n[Best] Val CI:", study.best_value)
print("[Best] Params:", study.best_params)
print("[Best] Best epoch:", study.best_trial.user_attrs.get("best_epoch", MAX_EPOCHS_CAP))

# ============================================================
# Retrain on Train+Val with best hyperparams; evaluate Test
# - Combine Train+Val, sort, restandardize; apply to Test
# ============================================================
# Prepare Train+Val
trainval_df = pd.concat([train_df, valid_df], axis=0, ignore_index=True)
trainval_df = trainval_df.sort_values(by=["OS_MONTHS","OS_STATUS"], ascending=[False, False]).reset_index(drop=True)

X_trv = trainval_df[feat_candidates].values.astype(np.float32)
y_trv_time = trainval_df["OS_MONTHS"].values.astype(np.float32)
y_trv_event = trainval_df["OS_STATUS"].values.astype(int)

# Median impute by Train+Val medians (for retraining phase)
trv_medians = np.nanmedian(X_trv, axis=0)
X_trv = np.where(np.isnan(X_trv), trv_medians, X_trv)

scaler_trv = StandardScaler().fit(X_trv)
X_trv = scaler_trv.transform(X_trv).astype(np.float32)

# Test set standardized with Train+Val scaler
X_test = test_df[feat_candidates].values.astype(np.float32)
X_test = np.where(np.isnan(X_test), trv_medians, X_test)
X_test = scaler_trv.transform(X_test).astype(np.float32)
y_te_time = test_df["OS_MONTHS"].values.astype(np.float32)
y_te_event = test_df["OS_STATUS"].values.astype(int)

# Loaders
BATCH_SIZE = 64
trv_ds = SurvivalDataset(X_trv, y_trv_time, y_trv_event)
te_ds  = SurvivalDataset(X_test, y_te_time, y_te_event)

trv_sampler = EventBalancedBatchSampler(y_trv_event, BATCH_SIZE, seed=7)
trv_loader  = DataLoader(trv_ds, batch_sampler=trv_sampler, num_workers=0)
trv_eval_loader = DataLoader(trv_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
te_loader  = DataLoader(te_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Build & train final model
best_hp = study.best_params
l1_final = float(best_hp.get("l1", 0.0))
best_layers = layers_from_arch(best_hp["arch"])
best_n_epochs = int(study.best_trial.user_attrs.get("best_epoch", MAX_EPOCHS_CAP))
best_n_epochs = max(16, min(best_n_epochs, MAX_EPOCHS_CAP))

model_final = DeepSurvMLP(in_dim, best_layers, dropout=best_hp["dropout"]).to(device)
opt_final = make_optimizer(model_final, lr=best_hp["lr"], wd=best_hp["wd"])
sched_final = torch.optim.lr_scheduler.CosineAnnealingLR(opt_final, T_max=best_n_epochs)

for epoch in range(best_n_epochs):
    frac_d = min(1.0, epoch / float(WARMUP_EPOCHS_DROPOUT))
    frac_w = min(1.0, epoch / float(WARMUP_EPOCHS_WD))
    set_dropout_p(model_final, DROPOUT_START + (best_hp['dropout'] - DROPOUT_START) * frac_d)
    set_weight_decay(opt_final, WD_START + (best_hp['wd'] - WD_START) * frac_w)

    stats = train_one_epoch(model_final, opt_final, trv_loader, device, l1_lambda=best_hp["l1"],
                            epoch=epoch, warmup_epochs=WARMUP_EPOCHS_L1)
    _ = full_risk_set_step(model_final, opt_final, trv_ds, device, l1_lambda=best_hp["l1"], warm=stats['warm'])
    sched_final.step()

# Evaluate
trainval_ci = evaluate_ci(model_final, trv_eval_loader, device)
test_ci = evaluate_ci(model_final, te_loader, device)
print(f"\n[Final] Train+Val CI: {trainval_ci:.4f}")
print(f"[Final] Test CI:      {test_ci:.4f}")

# (Optional) Save artifacts to Drive
OUT_DIR = "/content/drive/MyDrive/deepsurv_results_optuna"
os.makedirs(OUT_DIR, exist_ok=True)
torch.save(model_final.state_dict(), os.path.join(OUT_DIR, "deepsurv_best.pt"))
with open(os.path.join(OUT_DIR, "best_params.txt"), "w") as f:
    f.write(str(study.best_params))
print("Saved final model and best params to:", OUT_DIR)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Device: cuda
Genes.csv path: /content/drive/MyDrive/Genes.csv
[Genes] Selected 1555 genes with Prop == 1
[Features] Using 1573 common features.


  df["Adjuvant Chemo"] = df["Adjuvant Chemo"].replace({"OBS":0, "ACT":1})
  df["Adjuvant Chemo"] = df["Adjuvant Chemo"].replace({"OBS":0, "ACT":1})
  df["Adjuvant Chemo"] = df["Adjuvant Chemo"].replace({"OBS":0, "ACT":1})


Input dim: 1573


[I 2025-10-19 19:17:47,670] A new study created in RDB with name: deepsurv_cox_hpo_overfit_reducer
Bottle v0.13.4 server starting up (using WSGIRefServer())...
Listening on http://0.0.0.0:38099/
Hit Ctrl-C to quit.



Optuna Dashboard: https://38099-gpu-a100-s-2cm35yjghnl2l-b.asia-southeast1-0.prod.colab.dev
Starting optimization: 100 trials × up to 512 epochs


[W 2025-10-19 19:18:28,622] Trial 0 failed with parameters: {'arch': '128-128', 'dropout': 0.27473748411882515, 'input_dropout': 0.18355586841671384, 'wd': 4.982752357076453e-06, 'apply_final_wd': 1, 'use_l1': 1, 'l1': 1.240616395320886e-07, 'lr': 7.475992999956501e-05, 'sched': 'none', 'epochs': 128, 'batch_size': 128, 'grad_clip': 8.275576133048151, 'noise_std': 0.03046137691733707} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/optuna/study/_optimize.py", line 201, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "/tmp/ipython-input-3928505025.py", line 536, in objective
    _ = full_risk_set_step_reg(
        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-3928505025.py", line 467, in full_risk_set_step_reg
    loss_full.backward()
  File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 647, in backward
    torch.autograd.backward(
  Fi

KeyboardInterrupt: 

In [None]:
# ==== Ready-to-run: Retrain on Train+Val with given best params, then eval on Test ====

# --- Use the reported best params & epoch (handles missing 'l1') ---
best_hp = {
    'arch': '512',
    'dropout': 0.385912835673872,
    'input_dropout': 0.061771953713056,  # not used in final fit below
    'wd': 3.3415205446076416e-06,
    'apply_final_wd': 0,                 # not used in final fit below
    'use_l1': 0,
    'lr': 6.945321990789577e-05,
    'sched': 'cosine',
    'epochs': 192,
    'batch_size': 32,
    'grad_clip': 8.251469896233349,      # not used in final fit below
    'noise_std': 0.09246940755563608     # not used in final fit below
}
best_n_epochs = 26
l1_final = float(best_hp.get("l1", 0.0))  # <- fix: default L1 when use_l1==0

# --- Build Train+Val, restandardize; transform Test ---
trainval_df = pd.concat([train_df, valid_df], axis=0, ignore_index=True)
trainval_df = trainval_df.sort_values(by=["OS_MONTHS","OS_STATUS"], ascending=[False, False]).reset_index(drop=True)

X_trv = trainval_df[feat_candidates].values.astype(np.float32)
y_trv_time = trainval_df["OS_MONTHS"].values.astype(np.float32)
y_trv_event = trainval_df["OS_STATUS"].values.astype(int)

trv_medians = np.nanmedian(X_trv, axis=0)
X_trv = np.where(np.isnan(X_trv), trv_medians, X_trv)
scaler_trv = StandardScaler().fit(X_trv)
X_trv = scaler_trv.transform(X_trv).astype(np.float32)

X_test = test_df[feat_candidates].values.astype(np.float32)
X_test = np.where(np.isnan(X_test), trv_medians, X_test)
X_test = scaler_trv.transform(X_test).astype(np.float32)
y_te_time = test_df["OS_MONTHS"].values.astype(np.float32)
y_te_event = test_df["OS_STATUS"].values.astype(int)

# --- Datasets & Loaders ---
BATCH_SIZE = 64
trv_ds = SurvivalDataset(X_trv, y_trv_time, y_trv_event)
te_ds  = SurvivalDataset(X_test, y_te_time, y_te_event)

trv_sampler = EventBalancedBatchSampler(y_trv_event, BATCH_SIZE, seed=7)
trv_loader  = DataLoader(trv_ds, batch_sampler=trv_sampler, num_workers=0)
trv_eval_loader = DataLoader(trv_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
te_loader  = DataLoader(te_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# --- Model (arch parsed from string), Optim, Scheduler ---
def layers_from_arch(arch_str: str):
    return [int(x) for x in arch_str.split("-") if x.strip()]

in_dim = X_trv.shape[1]
best_layers = layers_from_arch(best_hp["arch"])

model_final = DeepSurvMLP(in_dim, best_layers, dropout=best_hp["dropout"]).to(device)
opt_final = make_optimizer(model_final, lr=best_hp["lr"], wd=best_hp["wd"])
sched_final = torch.optim.lr_scheduler.CosineAnnealingLR(opt_final, T_max=best_n_epochs)

# --- Warmups (reuse constants from earlier cell if present; else define sane defaults) ---
try:
    WARMUP_EPOCHS_DROPOUT
except NameError:
    WARMUP_EPOCHS_DROPOUT = 30
    WARMUP_EPOCHS_WD = 30
    WARMUP_EPOCHS_L1 = 30
    DROPOUT_START = 0.15
    WD_START = 0.0

# --- Train ---
for epoch in range(best_n_epochs):
    frac_d = min(1.0, epoch / float(WARMUP_EPOCHS_DROPOUT))
    frac_w = min(1.0, epoch / float(WARMUP_EPOCHS_WD))
    set_dropout_p(model_final, DROPOUT_START + (best_hp['dropout'] - DROPOUT_START) * frac_d)
    set_weight_decay(opt_final, WD_START + (best_hp['wd'] - WD_START) * frac_w)

    stats = train_one_epoch(model_final, opt_final, trv_loader, device, l1_lambda=l1_final,
                            epoch=epoch, warmup_epochs=WARMUP_EPOCHS_L1)
    _ = full_risk_set_step(model_final, opt_final, trv_ds, device, l1_lambda=l1_final, warm=stats['warm'])
    sched_final.step()

# --- Evaluate ---
trainval_ci = evaluate_ci(model_final, trv_eval_loader, device)
test_ci = evaluate_ci(model_final, te_loader, device)
print(f"\n[Final] Train+Val CI: {trainval_ci:.4f}")
print(f"[Final] Test CI:      {test_ci:.4f}")


[Final] Train+Val CI: 0.7981
[Final] Test CI:      0.6103


Multiobjective

In [7]:
# ============================================================
# Colab-ready SINGLE CELL
# DeepSurv + Multi-Objective Optuna (Val CI ↑, Gap ↓, Params ↓)
# with:
#   • ONLY clinical vars (given list) + genes with Prop==1 from Genes.csv
#   • Train/Val sorted by OS_MONTHS & OS_STATUS (desc)
#   • Train-only standardization applied to Val; after HPO, restandardize on Train+Val
#   • Treatment × Gene interactions (for Top-K_inter genes)
#   • Stabilized IPTW (propensity weights) to address ACT imbalance
#   • Event-balanced batches; dropout & WD warm-ups; gradient clipping; LR schedulers
#   • Optuna Dashboard (link printed)
#   • Final retraining on Train+Val with selected hyperparams + interactions + IPTW; Test CI reported
# ============================================================

# ---------- Installs (Colab) ----------
!pip -q install optuna optuna-dashboard scikit-survival portpicker

# ---------- (Optional) Mount Google Drive ----------
from google.colab import drive
drive.mount('/content/drive')

# ---------- Imports ----------
import os, math, copy, warnings, random, gc, time, threading
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Sampler

# Prefer torchsurv if available; fallback to custom Efron
try:
    from torchsurv.loss.cox import neg_partial_log_likelihood
    _HAS_TORCHSURV = True
except Exception:
    _HAS_TORCHSURV = False

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

from sksurv.metrics import concordance_index_censored
from sksurv.util import Surv
from sksurv.linear_model import CoxPHSurvivalAnalysis

import optuna
from optuna.samplers import NSGAIISampler

import portpicker
from google.colab import output

warnings.filterwarnings("ignore", message="Ties in event time detected; using efron's method to handle ties.")
torch.manual_seed(0); np.random.seed(0); random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ============================================================
# Cox Losses
# ============================================================
def _cox_negloglik_efron(pred, event, time):
    eta = pred.reshape(-1)
    e = event.to(torch.float32).reshape(-1)
    t = time.reshape(-1)
    order = torch.argsort(t, descending=True)
    t = t[order]; e = e[order]; eta = eta[order]
    exp_eta = torch.exp(eta)
    cum_exp = torch.cumsum(exp_eta, dim=0)
    uniq_mask = torch.ones_like(t, dtype=torch.bool)
    uniq_mask[1:] = t[1:] != t[:-1]
    idxs = torch.nonzero(uniq_mask, as_tuple=False).reshape(-1)
    idxs = torch.cat([idxs, torch.tensor([len(t)], device=t.device)])
    nll = torch.tensor(0.0, device=t.device)
    for k in range(len(idxs)-1):
        start, end = idxs[k].item(), idxs[k+1].item()
        e_slice = e[start:end]
        d = int(e_slice.sum().item())
        if d == 0: continue
        eta_events = eta[start:end][e_slice.bool()]
        exp_events = torch.exp(eta[start:end])[e_slice.bool()]
        s_eta = eta_events.sum()
        risk_sum = cum_exp[end-1]
        s_exp = exp_events.sum()
        eps = 1e-12
        log_terms = 0.0
        for j in range(d):
            log_terms = log_terms + torch.log(risk_sum - (j / d) * s_exp + eps)
        nll = nll - (s_eta - log_terms)
    return nll / t.numel()

def cox_negloglik(pred, event, time):
    if _HAS_TORCHSURV:
        return neg_partial_log_likelihood(pred, event, time, reduction='mean')
    return _cox_negloglik_efron(pred, event, time)

# Weighted Breslow negative partial log-likelihood (IPTW)
def cox_negloglik_breslow_weighted(pred, event, time, weight=None):
    eta = pred.reshape(-1)
    e = event.to(torch.float32).reshape(-1)
    t = time.reshape(-1)
    if weight is None:
        weight = torch.ones_like(eta)
    else:
        weight = weight.reshape(-1)

    order = torch.argsort(t, descending=True)
    t = t[order]; e = e[order]; eta = eta[order]; w = weight[order]
    exp_eta = torch.exp(eta)
    w_exp = w * exp_eta
    cum_w_exp = torch.cumsum(w_exp, dim=0)

    uniq = torch.ones_like(t, dtype=torch.bool)
    uniq[1:] = t[1:] != t[:-1]
    idxs = torch.nonzero(uniq, as_tuple=False).reshape(-1)
    idxs = torch.cat([idxs, torch.tensor([len(t)], device=t.device)])

    nll = torch.tensor(0.0, device=t.device)
    eps = 1e-12
    for k in range(len(idxs)-1):
        s, eidx = idxs[k].item(), idxs[k+1].item()
        emask = e[s:eidx] > 0.5
        if emask.sum() == 0:
            continue
        w_events = w[s:eidx][emask]
        eta_events = eta[s:eidx][emask]
        s_eta = (w_events * eta_events).sum()
        denom = cum_w_exp[eidx-1]
        nll = nll - (s_eta - w_events.sum() * torch.log(denom + eps))
    return nll / (w.sum() + 1e-9)

def cox_loss(pred, event, time, weight=None):
    """Use weighted Breslow if weights provided; else Efron/torchsurv."""
    if weight is None:
        return cox_negloglik(pred, event, time)
    return cox_negloglik_breslow_weighted(pred, event, time, weight)

# ============================================================
# Model, Dataset (with weights), Sampler, Utils
# ============================================================
class DeepSurvMLP(nn.Module):
    def __init__(self, in_features, hidden_layers, dropout=0.0, activation=nn.ReLU()):
        super().__init__()
        layers, d = [], in_features
        for units in hidden_layers:
            layers += [nn.Linear(d, units), activation]
            if dropout > 0: layers.append(nn.Dropout(dropout))
            d = units
        layers.append(nn.Linear(d, 1))
        self.model = nn.Sequential(*layers)
    def forward(self, x): return self.model(x)

class SurvivalDataset(Dataset):
    def __init__(self, features, time_vals, events, weights=None):
        self.x = torch.tensor(features, dtype=torch.float32)
        self.time = torch.tensor(time_vals, dtype=torch.float32)
        self.event = torch.tensor(events.astype(bool), dtype=torch.bool)
        if weights is None:
            self.w = torch.ones(len(self.x), dtype=torch.float32)
        else:
            self.w = torch.tensor(weights, dtype=torch.float32)
    def __len__(self): return len(self.x)
    def __getitem__(self, idx): return self.x[idx], self.time[idx], self.event[idx], self.w[idx]

class EventBalancedBatchSampler(Sampler):
    def __init__(self, events_numpy, batch_size, seed=0):
        events = np.asarray(events_numpy).astype(bool)
        self.pos_idx = np.where(events)[0]
        self.neg_idx = np.where(~events)[0]
        assert len(self.pos_idx) > 0, "No events in training set — cannot balance batches."
        self.bs = int(batch_size)
        self.rng = np.random.default_rng(seed)
    def __iter__(self):
        pos = self.rng.permutation(self.pos_idx)
        neg = self.rng.permutation(self.neg_idx)
        n_total = len(pos) + len(neg)
        n_batches = math.ceil(n_total / self.bs)
        pi = ni = 0
        for _ in range(n_batches):
            take_pos = 1 if pi < len(pos) else 0
            avail_neg = max(0, len(neg) - ni)
            take_neg = min(self.bs - take_pos, avail_neg)
            need = self.bs - (take_pos + take_neg)
            extra_pos = min(need, max(0, len(pos) - (pi + take_pos)))
            take_pos += extra_pos
            batch = np.concatenate([pos[pi:pi+take_pos], neg[ni:ni+take_neg]])
            pi += take_pos; ni += take_neg
            if batch.size == 0: break
            self.rng.shuffle(batch)
            yield batch.tolist()
    def __len__(self):
        return math.ceil((len(self.pos_idx) + len(self.neg_idx)) / self.bs)

def make_optimizer_groups(model, lr, wd, apply_final_wd=False):
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
    last_linear = linears[-1] if len(linears) > 0 else None
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad: continue
        if name.endswith('bias'):
            no_decay.append(p); continue
        if (last_linear is not None) and (p is last_linear.weight) and not apply_final_wd:
            no_decay.append(p); continue
        decay.append(p)
    return optim.AdamW(
        [{'params': decay, 'weight_decay': float(wd)},
         {'params': no_decay, 'weight_decay': 0.0}],
        lr=float(lr)
    )

def set_dropout_p(model, p):
    for m in model.modules():
        if isinstance(m, nn.Dropout): m.p = float(p)

def set_weight_decay(optimizer, wd):
    for g in optimizer.param_groups: g['weight_decay'] = float(wd)

def l1_penalty_first_layer(model):
    for m in model.modules():
        if isinstance(m, nn.Linear): return m.weight.abs().sum()
    return torch.tensor(0.0, device=next(model.parameters()).device)

@torch.no_grad()
def evaluate_ci(model, dataloader, device):
    model.eval()
    preds, times, events = [], [], []
    for batch in dataloader:
        if len(batch) == 4:
            x, t, e, _ = batch
        else:
            x, t, e = batch
        y = torch.clamp(model(x.to(device)), -20, 20)
        preds.append(y.cpu().numpy().ravel())
        times.append(t.numpy()); events.append(e.numpy())
    preds = np.concatenate(preds)
    times = np.concatenate(times); events = np.concatenate(events)
    if np.isnan(preds).any(): return -np.inf
    return concordance_index_censored(events.astype(bool), times, preds)[0]

def evaluate_ci_grouped(model, X, t, e, group_mask):
    """Compute C-index within group_mask==1 and group_mask==0."""
    model.eval()
    with torch.no_grad():
        preds = model(torch.tensor(X, dtype=torch.float32, device=device)).cpu().numpy().ravel()
    res = {}
    for label, mask in [("ACT=1", group_mask.astype(bool)), ("ACT=0", ~group_mask.astype(bool))]:
        if mask.sum() >= 3:  # need a few samples
            ci = concordance_index_censored(e[mask].astype(bool), t[mask], preds[mask])[0]
            res[label] = float(ci)
        else:
            res[label] = np.nan
    return res

# ============================================================
# Data loading & preprocessing
# ============================================================
TRAIN_CSV = "/content/drive/MyDrive/affyfRMATrain.csv"
VALID_CSV = "/content/drive/MyDrive/affyfRMAValidation.csv"
TEST_CSV  = "/content/drive/MyDrive/affyfRMATest.csv"

GENES_CSV = "/mnt/data/Genes.csv"
if not os.path.exists(GENES_CSV):
    if os.path.exists("/content/Genes.csv"):
        GENES_CSV = "/content/Genes.csv"
    elif os.path.exists("/content/drive/MyDrive/Genes.csv"):
        GENES_CSV = "/content/drive/MyDrive/Genes.csv"
print("Genes.csv path:", GENES_CSV)

CLINICAL_VARS = [
    "Adjuvant Chemo","Age","IS_MALE",
    "Stage_IA","Stage_IB","Stage_II","Stage_III",
    "Histology_Adenocarcinoma","Histology_Large Cell Carcinoma","Histology_Squamous Cell Carcinoma",
    "Race_African American","Race_Asian","Race_Caucasian","Race_Native Hawaiian or Other Pacific Islander","Race_Unknown",
    "Smoked?_No","Smoked?_Unknown","Smoked?_Yes"
]

def load_genes_list(genes_csv):
    g = pd.read_csv(genes_csv)
    if "Prop" not in g.columns or "Gene" not in g.columns:
        raise ValueError(f"Genes.csv must have columns 'Gene' and 'Prop'. Found: {list(g.columns)}")
    g["Prop"] = pd.to_numeric(g["Prop"], errors="coerce").fillna(0)
    genes = g.loc[g["Prop"] == 1, "Gene"].astype(str).tolist()
    print(f"[Genes] Selected {len(genes)} genes with Prop == 1")
    return genes

def coerce_survival_cols(df):
    if df["OS_STATUS"].dtype == object:
        df["OS_STATUS"] = df["OS_STATUS"].replace({"DECEASED":1,"LIVING":0,"Dead":1,"Alive":0}).astype(int)
    else:
        df["OS_STATUS"] = pd.to_numeric(df["OS_STATUS"], errors="coerce").fillna(0).astype(int)
    df["OS_MONTHS"] = pd.to_numeric(df["OS_MONTHS"], errors="coerce").fillna(0.0).astype(float)
    return df

def preprocess_split(df, clinical_vars, gene_names):
    if "Adjuvant Chemo" in df.columns:
        df["Adjuvant Chemo"] = df["Adjuvant Chemo"].replace({"OBS":0, "ACT":1})
    for col in ["Adjuvant Chemo","IS_MALE"]:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0).astype(int)
    df = coerce_survival_cols(df)
    keep_cols = [c for c in clinical_vars if c in df.columns] + [g for g in gene_names if g in df.columns]
    missing_clin = [c for c in clinical_vars if c not in df.columns]
    if missing_clin:
        print(f"[WARN] Missing clinical columns: {missing_clin}")
    cols = ["OS_STATUS","OS_MONTHS"] + keep_cols
    return df[cols].copy()

# Load CSVs
train_raw = pd.read_csv(TRAIN_CSV)
valid_raw = pd.read_csv(VALID_CSV)
test_raw  = pd.read_csv(TEST_CSV)

GENE_LIST = load_genes_list(GENES_CSV)

# Reduce to requested columns on each split
train_df = preprocess_split(train_raw, CLINICAL_VARS, GENE_LIST)
valid_df = preprocess_split(valid_raw, CLINICAL_VARS, GENE_LIST)
test_df  = preprocess_split(test_raw,  CLINICAL_VARS, GENE_LIST)

# Intersect features that exist everywhere
feat_candidates = [c for c in (CLINICAL_VARS + GENE_LIST)
                   if c in train_df.columns and c in valid_df.columns and c in test_df.columns]
CLIN_FEATS = [c for c in CLINICAL_VARS if c in feat_candidates]
GENE_FEATS = [g for g in GENE_LIST if g in feat_candidates]
CLIN_FEATS_PRETX = [c for c in CLIN_FEATS if c != "Adjuvant Chemo"]  # PS excludes treatment itself
print(f"[Features] Using {len(feat_candidates)} common features → Clinical={len(CLIN_FEATS)}, Genes={len(GENE_FEATS)}")

# Sort Train/Val by event time & status (desc)
train_df = train_df.sort_values(by=["OS_MONTHS","OS_STATUS"], ascending=[False, False]).reset_index(drop=True)
valid_df = valid_df.sort_values(by=["OS_MONTHS","OS_STATUS"], ascending=[False, False]).reset_index(drop=True)

# ---- Train-only gene ranking (univariate Cox CI on TRAIN) ----
def rank_genes_univariate(train_df, gene_cols):
    y = Surv.from_arrays(event=train_df["OS_STATUS"].astype(bool).values,
                         time=train_df["OS_MONTHS"].values.astype(float))
    ranks = []
    for g in gene_cols:
        Xg = train_df[[g]].to_numpy(dtype=np.float32)
        try:
            model = CoxPHSurvivalAnalysis(alpha=1e-12)
            model.fit(Xg, y)
            pred = model.predict(Xg)
            ci = concordance_index_censored(y["event"], y["time"], pred)[0]
            ranks.append((g, float(ci)))
        except Exception:
            ranks.append((g, 0.5))
    ranks.sort(key=lambda z: z[1], reverse=True)
    return [g for g, _ in ranks]

GENE_RANK = rank_genes_univariate(train_df, GENE_FEATS)
MAX_GENES = len(GENE_RANK)
print(f"[Gene Ranking] Ranked {MAX_GENES} genes on TRAIN")

# ============================================================
# Feature construction (main effects + interactions) & IPTW
# ============================================================
def build_features_with_interactions(df, main_genes, inter_genes, act_col="Adjuvant Chemo"):
    base_cols = CLIN_FEATS + list(main_genes)  # keep ACT main effect via CLIN_FEATS
    X_base = df[base_cols].to_numpy(dtype=np.float32)
    A = df[act_col].to_numpy(dtype=np.float32).reshape(-1, 1)
    if len(inter_genes) > 0:
        X_int = df[list(inter_genes)].to_numpy(dtype=np.float32) * A
        X = np.concatenate([X_base, X_int], axis=1)
        names = base_cols + [f"{g}*ACT" for g in inter_genes]
    else:
        X = X_base
        names = base_cols
    return X, names

def compute_iptw(df, covariate_cols, act_col="Adjuvant Chemo",
                 ps_clip=(0.05, 0.95), w_clip=(0.1, 10.0),
                 ref_prevalence=None, model=None):
    A = df[act_col].astype(int).values
    X = df[covariate_cols].astype(float).values
    if model is None:
        model = LogisticRegression(max_iter=2000, solver="lbfgs", class_weight="balanced")
        model.fit(X, A)
    ps = model.predict_proba(X)[:, 1]
    ps = np.clip(ps, ps_clip[0], ps_clip[1])
    if ref_prevalence is None:
        ref_prevalence = A.mean()
    w = np.where(A == 1, ref_prevalence / ps, (1 - ref_prevalence) / (1 - ps))
    w = np.clip(w, w_clip[0], w_clip[1])
    return w.astype(np.float32), model, float(ref_prevalence)

# ============================================================
# Training helpers (with input dropout/noise, weights, warm-ups)
# ============================================================
def _apply_input_dropout(x, p):
    if p <= 0.0: return x
    keep = 1.0 - p
    mask = torch.bernoulli(torch.full_like(x, keep))
    return x * mask / max(keep, 1e-6)

def train_one_epoch_reg(model, optimizer, dataloader, device,
                        l1_lambda=0.0, epoch=0, warmup_epochs=20,
                        input_dropout=0.0, noise_std=0.0, grad_clip=5.0):
    model.train()
    warm = min(1.0, (epoch + 1) / float(warmup_epochs))
    loss_sum, w_sum = 0.0, 0.0
    for x, t, e, w in dataloader:
        if e.sum().item() == 0:
            continue
        x, t, e, w = x.to(device), t.to(device), e.to(device), w.to(device)
        if input_dropout > 0.0: x = _apply_input_dropout(x, input_dropout)
        if noise_std > 0.0: x = x + noise_std * torch.randn_like(x)
        optimizer.zero_grad(set_to_none=True)
        out = torch.clamp(model(x), -20, 20)
        loss = cox_loss(out, e, t, weight=w)
        if l1_lambda > 0:
            loss = loss + (l1_lambda * warm) * l1_penalty_first_layer(model)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), float(grad_clip))
        optimizer.step()
        loss_sum += loss.item() * float(w.sum().item())
        w_sum += float(w.sum().item())
    return {'avg_loss': (loss_sum / max(w_sum, 1e-9)), 'warm': warm}

def full_risk_set_step_reg(model, optimizer, ds, device,
                           l1_lambda=0.0, warm=1.0,
                           input_dropout=0.0, noise_std=0.0, grad_clip=5.0):
    model.train()
    X_all = ds.x.to(device); t_all = ds.time.to(device); e_all = ds.event.to(device); w_all = ds.w.to(device)
    XX = X_all
    if input_dropout > 0.0: XX = _apply_input_dropout(XX, input_dropout)
    if noise_std > 0.0: XX = XX + noise_std * torch.randn_like(XX)
    optimizer.zero_grad(set_to_none=True)
    out_all = torch.clamp(model(XX), -20, 20)
    loss_full = cox_loss(out_all, e_all, t_all, weight=w_all)
    if l1_lambda > 0:
        loss_full = loss_full + (l1_lambda * warm) * l1_penalty_first_layer(model)
    loss_full.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), float(grad_clip))
    optimizer.step()
    return float(loss_full.detach().cpu().item())

def count_params(in_dim, layers):
    params, d = 0, in_dim
    for h in layers:
        params += d*h + h
        d = h
    params += d*1 + 1
    return int(params)

# ============================================================
# Optuna: Multi-Objective (Val CI ↑, Train–Val GAP ↓, Params ↓)
# + Feature budget (Top-K genes & Top-K interactions)
# ============================================================
ARCH_CHOICES = (
    "16","32","64","128","256","512",
    "32-16","64-32","64-64","128-64","128-128","256-128","256-256","512-256",
    "64-32-16","128-64-32","256-128-64"
)
def layers_from_arch(arch_str: str):
    return [int(x) for x in arch_str.split("-") if x.strip()]

# ---------- STATIC SPACES (no dynamic value space); clamp later ----------
def suggest_hparams(trial):
    # Fixed main-gene choices (dedup + bounded by MAX_GENES)
    base_main = [32, 64, 128, 256, 512, 800, MAX_GENES]
    TOPK_MAIN_CHOICES = tuple(sorted({k for k in base_main if k <= MAX_GENES}))
    top_k_genes = trial.suggest_categorical("top_k_genes", TOPK_MAIN_CHOICES)

    # Fixed interaction choices; NOT conditioned on top_k_genes
    base_inter = [0, 16, 32, 64, 128, 256, 512]
    TOPK_INTER_CHOICES = tuple(sorted({k for k in base_inter if k <= MAX_GENES}))
    top_k_inter_raw = trial.suggest_categorical("top_k_inter", TOPK_INTER_CHOICES)

    # net & regularizers
    arch = trial.suggest_categorical("arch", ARCH_CHOICES)
    dropout = trial.suggest_float("dropout", 0.10, 0.70)
    input_dropout = trial.suggest_float("input_dropout", 0.00, 0.30)
    noise_std = trial.suggest_float("noise_std", 0.0, 0.10)

    wd = trial.suggest_float("wd", 1e-6, 1e-1, log=True)
    apply_final_wd = trial.suggest_categorical("apply_final_wd", (0, 1))
    use_l1 = trial.suggest_categorical("use_l1", (0, 1))
    l1 = 0.0 if use_l1 == 0 else trial.suggest_float("l1", 1e-8, 3e-3, log=True)

    lr = trial.suggest_float("lr", 1e-5, 5e-4, log=True)
    sched = trial.suggest_categorical("sched", ("cosine", "cawr", "none"))
    if sched == "cawr":
        cawr_T0 = trial.suggest_int("cawr_T0", 16, 80, step=8)
        cawr_Tmult = trial.suggest_categorical("cawr_Tmult", (1, 2, 3))
    else:
        cawr_T0 = cawr_Tmult = None

    epochs = trial.suggest_int("epochs", 64, 384, step=32)
    batch_size = trial.suggest_categorical("batch_size", (32, 64, 128))
    grad_clip = trial.suggest_float("grad_clip", 1.0, 10.0)

    # Clamp interactions to <= main set (no dynamic space; enforced post-hoc)
    top_k_inter = int(min(top_k_inter_raw, top_k_genes))

    return dict(
        arch=arch, top_k_genes=int(top_k_genes), top_k_inter=int(top_k_inter),
        dropout=dropout, input_dropout=input_dropout, noise_std=noise_std,
        wd=wd, apply_final_wd=int(apply_final_wd), use_l1=int(use_l1), l1=float(l1),
        lr=lr, sched=sched, cawr_T0=cawr_T0, cawr_Tmult=cawr_Tmult,
        epochs=int(epochs), batch_size=int(batch_size), grad_clip=float(grad_clip)
    )

# Warm-ups and caps
MAX_EPOCHS_CAP = 384
WARMUP_EPOCHS_L1 = 30
WARMUP_EPOCHS_DROPOUT = 30
WARMUP_EPOCHS_WD = 30
DROPOUT_START = 0.15
WD_START = 0.0

def objective(trial):
    hp = suggest_hparams(trial)
    layers = layers_from_arch(hp["arch"])

    # --- Trial-specific features (main + interactions)
    genes_main = GENE_RANK[:hp["top_k_genes"]]
    # interactions must be a subset of the selected mains
    genes_inter = genes_main[:hp["top_k_inter"]]

    Xtr_raw, feat_names = build_features_with_interactions(train_df, genes_main, genes_inter, act_col="Adjuvant Chemo")
    Xva_raw, _          = build_features_with_interactions(valid_df, genes_main, genes_inter, act_col="Adjuvant Chemo")

    # --- TRAIN-only impute & standardize
    med = np.nanmedian(Xtr_raw, axis=0)
    Xtr = np.where(np.isnan(Xtr_raw), med, Xtr_raw)
    Xva = np.where(np.isnan(Xva_raw), med, Xva_raw)
    sc = StandardScaler().fit(Xtr)
    Xtr = sc.transform(Xtr).astype(np.float32)
    Xva = sc.transform(Xva).astype(np.float32)

    ytr_t = train_df["OS_MONTHS"].to_numpy(np.float32)
    ytr_e = train_df["OS_STATUS"].to_numpy(int)
    yva_t = valid_df["OS_MONTHS"].to_numpy(np.float32)
    yva_e = valid_df["OS_STATUS"].to_numpy(int)

    # --- IPTW on clinical pre-treatment covariates ONLY (exclude ACT)
    w_tr, ps_model, pi_tr = compute_iptw(train_df, covariate_cols=CLIN_FEATS_PRETX, act_col="Adjuvant Chemo")
    w_va, _, _ = compute_iptw(valid_df, covariate_cols=CLIN_FEATS_PRETX, act_col="Adjuvant Chemo",
                              ref_prevalence=pi_tr, model=ps_model)

    # --- Datasets/loaders
    bs = hp["batch_size"]
    tr_ds = SurvivalDataset(Xtr, ytr_t, ytr_e, weights=w_tr)
    va_ds = SurvivalDataset(Xva, yva_t, yva_e, weights=w_va)
    tr_sampler = EventBalancedBatchSampler(ytr_e, bs, seed=42)
    tr_loader = DataLoader(tr_ds, batch_sampler=tr_sampler, num_workers=0)
    tr_eval_loader = DataLoader(tr_ds, batch_size=bs, shuffle=False, num_workers=0)
    va_loader = DataLoader(va_ds, batch_size=bs, shuffle=False, num_workers=0)

    # --- Model/opt/sched
    in_dim_trial = Xtr.shape[1]
    model = DeepSurvMLP(in_dim_trial, layers, dropout=hp["dropout"]).to(device)
    opt = make_optimizer_groups(model, lr=hp["lr"], wd=hp["wd"], apply_final_wd=bool(hp["apply_final_wd"]))

    epochs = int(min(hp["epochs"], MAX_EPOCHS_CAP))
    if hp["sched"] == "cosine":
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
        def sched_step(i): sched.step()
    elif hp["sched"] == "cawr":
        sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            opt, T_0=int(hp["cawr_T0"]), T_mult=int(hp["cawr_Tmult"]))
        def sched_step(i): sched.step(i+1)
    else:
        sched = None
        def sched_step(i): pass

    # --- Manual early-stopping on Val CI (MO: no trial.report/prune)
    PATIENCE = 40
    MIN_DELTA = 1e-4
    no_improve = 0
    best_val_ci = -np.inf
    best_tr_ci_at_best = float("nan")
    best_epoch = 0

    for ep in range(epochs):
        # warm-ups
        frac_d = min(1.0, ep / float(WARMUP_EPOCHS_DROPOUT))
        frac_w = min(1.0, ep / float(WARMUP_EPOCHS_WD))
        set_dropout_p(model, DROPOUT_START + (hp['dropout'] - DROPOUT_START) * frac_d)
        set_weight_decay(opt, WD_START + (hp['wd'] - WD_START) * frac_w)

        # one epoch + full-risk-set correction (with regularizers + weights)
        st = train_one_epoch_reg(
            model, opt, tr_loader, device,
            l1_lambda=float(hp.get("l1", 0.0)), epoch=ep, warmup_epochs=WARMUP_EPOCHS_L1,
            input_dropout=float(hp.get("input_dropout", 0.0)),
            noise_std=float(hp.get("noise_std", 0.0)),
            grad_clip=float(hp.get("grad_clip", 5.0))
        )
        _ = full_risk_set_step_reg(
            model, opt, tr_ds, device,
            l1_lambda=float(hp.get("l1", 0.0)), warm=st['warm'],
            input_dropout=float(hp.get("input_dropout", 0.0)),
            noise_std=float(hp.get("noise_std", 0.0)),
            grad_clip=float(hp.get("grad_clip", 5.0))
        )
        sched_step(ep)

        # eval both splits so we can compute the gap at the best Val epoch
        va_ci = evaluate_ci(model, va_loader, device)
        tr_ci = evaluate_ci(model, tr_eval_loader, device)

        if va_ci > best_val_ci + MIN_DELTA:
            best_val_ci = va_ci
            best_tr_ci_at_best = tr_ci
            best_epoch = ep + 1
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= PATIENCE:
                break

    gap = max(0.0, best_tr_ci_at_best - best_val_ci)
    param_cnt = count_params(in_dim_trial, layers)

    # annotate & cleanup
    trial.set_user_attr("best_epoch", int(best_epoch))
    trial.set_user_attr("n_features", int(in_dim_trial))
    del model, opt, sched
    if torch.cuda.is_available(): torch.cuda.empty_cache()
    gc.collect()

    # Multi-objective return: (Val CI ↑, Gap ↓, Params ↓)
    return float(best_val_ci), float(gap), int(param_cnt)

# ---- Optuna Study: NSGA-II (no pruner for multi-objective) ----
storage = "sqlite:///deepsurv_optuna.db"
study_name = f"deepsurv_cox_mo_gap_size_interactions_static_M{MAX_GENES}"  # new name to avoid old dynamic-space schema

sampler = NSGAIISampler(seed=42, population_size=24)

study = optuna.create_study(
    directions=["maximize", "minimize", "minimize"],
    study_name=study_name,
    storage=storage,
    load_if_exists=True,
    sampler=sampler
)

# ---- Launch Optuna Dashboard (proxied URL printed) ----
try:
    from optuna_dashboard import run_server
    PORT = portpicker.pick_unused_port()
    def _start_dashboard():
        run_server(storage, host="0.0.0.0", port=PORT)  # no unsupported kwargs
    t = threading.Thread(target=_start_dashboard, daemon=True)
    t.start(); time.sleep(2)
    dash_url = output.eval_js(f"google.colab.kernel.proxyPort({PORT}, {{'cache': false}})")
    print("Optuna Dashboard:", dash_url)
except Exception as ex:
    print("[Optuna Dashboard] Could not start dashboard automatically.", ex)
    print("Run locally:  optuna-dashboard sqlite:///deepsurv_optuna.db")

# ---- Optimize ----
N_TRIALS = 100
print(f"Starting multi-objective optimization: {N_TRIALS} trials")
study.optimize(objective, n_trials=N_TRIALS, gc_after_trial=True)

# ---- Choose a robust solution from Pareto front ----
pareto = study.best_trials
best_val = max(tr.values[0] for tr in pareto)
TOL = 0.005  # within 0.5% absolute C-index of best
cands = [tr for tr in pareto if (best_val - tr.values[0]) <= TOL]
cands.sort(key=lambda tr: (tr.values[1], tr.values[2]))  # gap, then params
chosen = cands[0]
print("\n[Chosen Pareto] Val CI=%.4f | Gap=%.4f | Params=%d" % (chosen.values[0], chosen.values[1], chosen.values[2]))
print("[Chosen Params]", chosen.params)

# ============================================================
# Final training on Train+Val with chosen hyperparams + interactions + IPTW
# ============================================================
best_hp = chosen.params
best_layers = layers_from_arch(best_hp["arch"])
k_main = int(best_hp["top_k_genes"])
k_int  = int(best_hp["top_k_inter"])
genes_main = GENE_RANK[:k_main]
# interactions must be a subset of the selected main genes:
genes_inter = genes_main[:k_int]

print(f"[Final] Using features: {len(CLIN_FEATS)} clinical + {k_main} genes (main) + {k_int} interactions")

# Build Train+Val & Test features (same construction)
trainval_df = pd.concat([train_df, valid_df], axis=0, ignore_index=True)
trainval_df = trainval_df.sort_values(by=["OS_MONTHS","OS_STATUS"], ascending=[False, False]).reset_index(drop=True)

X_trv_raw, feat_names = build_features_with_interactions(trainval_df, genes_main, genes_inter, act_col="Adjuvant Chemo")
X_te_raw,  _          = build_features_with_interactions(test_df,      genes_main, genes_inter, act_col="Adjuvant Chemo")

# Impute + standardize on Train+Val; apply to Test
med_trv = np.nanmedian(X_trv_raw, axis=0)
X_trv = np.where(np.isnan(X_trv_raw), med_trv, X_trv_raw)
X_te  = np.where(np.isnan(X_te_raw),  med_trv, X_te_raw)
sc_trv = StandardScaler().fit(X_trv)
X_trv = sc_trv.transform(X_trv).astype(np.float32)
X_te  = sc_trv.transform(X_te).astype(np.float32)

y_trv_t = trainval_df["OS_MONTHS"].to_numpy(np.float32)
y_trv_e = trainval_df["OS_STATUS"].to_numpy(int)
y_te_t  = test_df["OS_MONTHS"].to_numpy(np.float32)
y_te_e  = test_df["OS_STATUS"].to_numpy(int)

# IPTW on Train+Val, apply to Test with same prevalence
w_trv, ps_model_fin, pi_fin = compute_iptw(trainval_df, covariate_cols=CLIN_FEATS_PRETX, act_col="Adjuvant Chemo")
w_te, _, _ = compute_iptw(test_df, covariate_cols=CLIN_FEATS_PRETX, act_col="Adjuvant Chemo",
                          ref_prevalence=pi_fin, model=ps_model_fin)

# Loaders
bs_fin = int(best_hp["batch_size"])
ds_trv = SurvivalDataset(X_trv, y_trv_t, y_trv_e, weights=w_trv)
ds_te  = SurvivalDataset(X_te,  y_te_t,  y_te_e,  weights=w_te)
sam_trv = EventBalancedBatchSampler(y_trv_e, bs_fin, seed=7)
dl_trv  = DataLoader(ds_trv, batch_sampler=sam_trv, num_workers=0)
dl_trv_eval = DataLoader(ds_trv, batch_size=bs_fin, shuffle=False, num_workers=0)
dl_te   = DataLoader(ds_te,  batch_size=bs_fin, shuffle=False, num_workers=0)

# Model / optimizer / scheduler
in_dim_final = X_trv.shape[1]
model_final = DeepSurvMLP(in_dim_final, best_layers, dropout=float(best_hp["dropout"])).to(device)
opt_final = make_optimizer_groups(model_final, lr=float(best_hp["lr"]), wd=float(best_hp["wd"]),
                                  apply_final_wd=bool(best_hp["apply_final_wd"]))
epochs_fin = int(min(best_hp["epochs"], MAX_EPOCHS_CAP))
if best_hp["sched"] == "cosine":
    sched_final = torch.optim.lr_scheduler.CosineAnnealingLR(opt_final, T_max=epochs_fin)
    def sched_step(i): sched_final.step()
elif best_hp["sched"] == "cawr":
    sched_final = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        opt_final, T_0=int(best_hp["cawr_T0"]), T_mult=int(best_hp["cawr_Tmult"]))
    def sched_step(i): sched_final.step(i+1)
else:
    sched_final = None
    def sched_step(i): pass

# Train
for ep in range(epochs_fin):
    frac_d = min(1.0, ep / float(WARMUP_EPOCHS_DROPOUT))
    frac_w = min(1.0, ep / float(WARMUP_EPOCHS_WD))
    set_dropout_p(model_final, DROPOUT_START + (float(best_hp['dropout']) - DROPOUT_START) * frac_d)
    set_weight_decay(opt_final, WD_START + (float(best_hp['wd']) - WD_START) * frac_w)

    st = train_one_epoch_reg(
        model_final, opt_final, dl_trv, device,
        l1_lambda=float(best_hp.get("l1", 0.0)), epoch=ep, warmup_epochs=WARMUP_EPOCHS_L1,
        input_dropout=float(best_hp.get("input_dropout", 0.0)),
        noise_std=float(best_hp.get("noise_std", 0.0)),
        grad_clip=float(best_hp.get("grad_clip", 5.0))
    )
    _ = full_risk_set_step_reg(
        model_final, opt_final, ds_trv, device,
        l1_lambda=float(best_hp.get("l1", 0.0)), warm=st['warm'],
        input_dropout=float(best_hp.get("input_dropout", 0.0)),
        noise_std=float(best_hp.get("noise_std", 0.0)),
        grad_clip=float(best_hp.get("grad_clip", 5.0))
    )
    sched_step(ep)

# Evaluate
trainval_ci = evaluate_ci(model_final, dl_trv_eval, device)
test_ci = evaluate_ci(model_final, dl_te, device)
print(f"\n[Final] Train+Val CI: {trainval_ci:.4f}")
print(f"[Final] Test CI:      {test_ci:.4f}")

# Per-arm C-indices (helpful for treatment recommendation sanity checks)
act_trv = trainval_df["Adjuvant Chemo"].to_numpy(int)
act_te  = test_df["Adjuvant Chemo"].to_numpy(int)
trv_grouped = evaluate_ci_grouped(model_final, X_trv, y_trv_t, y_trv_e, act_trv == 1)
te_grouped  = evaluate_ci_grouped(model_final, X_te,  y_te_t,  y_te_e,  act_te == 1)
print("[Train+Val] CI by arm:", trv_grouped)
print("[Test]      CI by arm:", te_grouped)

# Save artifacts
OUT_DIR = "/content/drive/MyDrive/deepsurv_results_optuna_interactions_iptw"
os.makedirs(OUT_DIR, exist_ok=True)
torch.save(model_final.state_dict(), os.path.join(OUT_DIR, "deepsurv_best.pt"))
with open(os.path.join(OUT_DIR, "chosen_params.txt"), "w") as f:
    f.write(str(best_hp))
with open(os.path.join(OUT_DIR, "features_used.txt"), "w") as f:
    f.write("\n".join(feat_names))
print("Saved final model and parameters to:", OUT_DIR)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Device: cuda
Genes.csv path: /content/drive/MyDrive/Genes.csv
[Genes] Selected 1555 genes with Prop == 1
[Features] Using 1573 common features → Clinical=18, Genes=1555


  df["Adjuvant Chemo"] = df["Adjuvant Chemo"].replace({"OBS":0, "ACT":1})
  df["Adjuvant Chemo"] = df["Adjuvant Chemo"].replace({"OBS":0, "ACT":1})
  df["Adjuvant Chemo"] = df["Adjuvant Chemo"].replace({"OBS":0, "ACT":1})
[I 2025-10-19 19:57:30,215] A new study created in RDB with name: deepsurv_cox_mo_gap_size_interactions_static_M1555


[Gene Ranking] Ranked 1555 genes on TRAIN


Bottle v0.13.4 server starting up (using WSGIRefServer())...
Listening on http://0.0.0.0:36243/
Hit Ctrl-C to quit.



Optuna Dashboard: https://36243-gpu-a100-s-2cm35yjghnl2l-b.asia-southeast1-0.prod.colab.dev
Starting multi-objective optimization: 100 trials


[I 2025-10-19 19:59:10,153] Trial 0 finished with values: [0.6152409964133052, 0.08393829126825181, 70657.0] and parameters: {'top_k_genes': 64, 'top_k_inter': 128, 'arch': '256-128', 'dropout': 0.20231447421237492, 'input_dropout': 0.019515477895583853, 'noise_std': 0.09488855372533334, 'wd': 0.0673224892077534, 'apply_final_wd': 0, 'use_l1': 1, 'l1': 2.5749486860432668e-06, 'lr': 1.6119044727609182e-05, 'sched': 'none', 'epochs': 128, 'batch_size': 32, 'grad_clip': 5.920392514089517}.
[I 2025-10-19 20:00:11,063] Trial 1 finished with values: [0.6322409472804992, 0.06179018547983173, 5249.0] and parameters: {'top_k_genes': 64, 'top_k_inter': 512, 'arch': '32-16', 'dropout': 0.6178620555253561, 'input_dropout': 0.18698943804826737, 'noise_std': 0.03308980248526492, 'wd': 2.078699690689779e-06, 'apply_final_wd': 1, 'use_l1': 0, 'lr': 0.0003216235469207422, 'sched': 'none', 'epochs': 320, 'batch_size': 64, 'grad_clip': 5.704595464437946}.
[I 2025-10-19 20:01:49,567] Trial 2 finished with


[Chosen Pareto] Val CI=0.6805 | Gap=0.1777 | Params=813057
[Chosen Params] {'top_k_genes': 800, 'top_k_inter': 512, 'arch': '512-256', 'dropout': 0.2002251447405699, 'input_dropout': 0.1052745037656236, 'noise_std': 0.011775108289014114, 'wd': 0.00018737227240512905, 'apply_final_wd': 0, 'use_l1': 0, 'lr': 0.0001552120833503202, 'sched': 'cosine', 'epochs': 256, 'batch_size': 128, 'grad_clip': 4.180170052234475}
[Final] Using features: 18 clinical + 800 genes (main) + 512 interactions

[Final] Train+Val CI: 0.8692
[Final] Test CI:      0.6233
[Train+Val] CI by arm: {'ACT=1': 0.6878566537320246, 'ACT=0': 0.9276072855401425}
[Test]      CI by arm: {'ACT=1': 0.5584642233856894, 'ACT=0': 0.6499927818680525}
Saved final model and parameters to: /content/drive/MyDrive/deepsurv_results_optuna_interactions_iptw


In [3]:
# ============================================================
# Colab-ready SINGLE CELL
# DeepSurv + Multi-Objective Optuna (Val CI ↑, Gap ↓, Params ↓)
# with capacity budgets, conservative Pareto selector, column-dropout,
# grouped L1 (stronger on interactions), and WD applied to last layer.
# ============================================================

# ---------- Installs (Colab) ----------
!pip -q install optuna optuna-dashboard scikit-survival portpicker

# ---------- (Optional) Mount Google Drive ----------
from google.colab import drive
drive.mount('/content/drive')

# ---------- Imports ----------
import os, math, copy, warnings, random, gc, time, threading
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Sampler

# Prefer torchsurv if available; fallback to custom Efron
try:
    from torchsurv.loss.cox import neg_partial_log_likelihood
    _HAS_TORCHSURV = True
except Exception:
    _HAS_TORCHSURV = False

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

from sksurv.metrics import concordance_index_censored
from sksurv.util import Surv
from sksurv.linear_model import CoxPHSurvivalAnalysis

import optuna
from optuna.samplers import NSGAIISampler

import portpicker
from google.colab import output

warnings.filterwarnings("ignore", message="Ties in event time detected; using efron's method to handle ties.")
torch.manual_seed(0); np.random.seed(0); random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ============================================================
# Cox Losses
# ============================================================
def _cox_negloglik_efron(pred, event, time):
    eta = pred.reshape(-1)
    e = event.to(torch.float32).reshape(-1)
    t = time.reshape(-1)
    order = torch.argsort(t, descending=True)
    t = t[order]; e = e[order]; eta = eta[order]
    exp_eta = torch.exp(eta)
    cum_exp = torch.cumsum(exp_eta, dim=0)
    uniq_mask = torch.ones_like(t, dtype=torch.bool)
    uniq_mask[1:] = t[1:] != t[:-1]
    idxs = torch.nonzero(uniq_mask, as_tuple=False).reshape(-1)
    idxs = torch.cat([idxs, torch.tensor([len(t)], device=t.device)])
    nll = torch.tensor(0.0, device=t.device)
    for k in range(len(idxs)-1):
        start, end = idxs[k].item(), idxs[k+1].item()
        e_slice = e[start:end]
        d = int(e_slice.sum().item())
        if d == 0: continue
        eta_events = eta[start:end][e_slice.bool()]
        exp_events = torch.exp(eta[start:end])[e_slice.bool()]
        s_eta = eta_events.sum()
        risk_sum = cum_exp[end-1]
        s_exp = exp_events.sum()
        eps = 1e-12
        log_terms = 0.0
        for j in range(d):
            log_terms = log_terms + torch.log(risk_sum - (j / d) * s_exp + eps)
        nll = nll - (s_eta - log_terms)
    return nll / t.numel()

def cox_negloglik(pred, event, time):
    if _HAS_TORCHSURV:
        return neg_partial_log_likelihood(pred, event, time, reduction='mean')
    return _cox_negloglik_efron(pred, event, time)

# Weighted Breslow negative partial log-likelihood (IPTW)
def cox_negloglik_breslow_weighted(pred, event, time, weight=None):
    eta = pred.reshape(-1)
    e = event.to(torch.float32).reshape(-1)
    t = time.reshape(-1)
    if weight is None:
        weight = torch.ones_like(eta)
    else:
        weight = weight.reshape(-1)

    order = torch.argsort(t, descending=True)
    t = t[order]; e = e[order]; eta = eta[order]; w = weight[order]
    exp_eta = torch.exp(eta)
    w_exp = w * exp_eta
    cum_w_exp = torch.cumsum(w_exp, dim=0)

    uniq = torch.ones_like(t, dtype=torch.bool)
    uniq[1:] = t[1:] != t[:-1]
    idxs = torch.nonzero(uniq, as_tuple=False).reshape(-1)
    idxs = torch.cat([idxs, torch.tensor([len(t)], device=t.device)])

    nll = torch.tensor(0.0, device=t.device)
    eps = 1e-12
    for k in range(len(idxs)-1):
        s, eidx = idxs[k].item(), idxs[k+1].item()
        emask = e[s:eidx] > 0.5
        if emask.sum() == 0:
            continue
        w_events = w[s:eidx][emask]
        eta_events = eta[s:eidx][emask]
        s_eta = (w_events * eta_events).sum()
        denom = cum_w_exp[eidx-1]
        nll = nll - (s_eta - w_events.sum() * torch.log(denom + eps))
    return nll / (w.sum() + 1e-9)

def cox_loss(pred, event, time, weight=None):
    """Use weighted Breslow if weights provided; else Efron/torchsurv."""
    if weight is None:
        return cox_negloglik(pred, event, time)
    return cox_negloglik_breslow_weighted(pred, event, time, weight)

# ============================================================
# Model, Dataset (with weights), Sampler, Utils
# ============================================================
class DeepSurvMLP(nn.Module):
    def __init__(self, in_features, hidden_layers, dropout=0.0, activation=nn.ReLU()):
        super().__init__()
        layers, d = [], in_features
        for units in hidden_layers:
            layers += [nn.Linear(d, units), activation]
            if dropout > 0: layers.append(nn.Dropout(dropout))
            d = units
        layers.append(nn.Linear(d, 1))
        self.model = nn.Sequential(*layers)
    def forward(self, x): return self.model(x)

class SurvivalDataset(Dataset):
    def __init__(self, features, time_vals, events, weights=None):
        self.x = torch.tensor(features, dtype=torch.float32)
        self.time = torch.tensor(time_vals, dtype=torch.float32)
        self.event = torch.tensor(events.astype(bool), dtype=torch.bool)
        if weights is None:
            self.w = torch.ones(len(self.x), dtype=torch.float32)
        else:
            self.w = torch.tensor(weights, dtype=torch.float32)
    def __len__(self): return len(self.x)
    def __getitem__(self, idx): return self.x[idx], self.time[idx], self.event[idx], self.w[idx]

class EventBalancedBatchSampler(Sampler):
    def __init__(self, events_numpy, batch_size, seed=0):
        events = np.asarray(events_numpy).astype(bool)
        self.pos_idx = np.where(events)[0]
        self.neg_idx = np.where(~events)[0]
        assert len(self.pos_idx) > 0, "No events in training set — cannot balance batches."
        self.bs = int(batch_size)
        self.rng = np.random.default_rng(seed)
    def __iter__(self):
        pos = self.rng.permutation(self.pos_idx)
        neg = self.rng.permutation(self.neg_idx)
        n_total = len(pos) + len(neg)
        n_batches = math.ceil(n_total / self.bs)
        pi = ni = 0
        for _ in range(n_batches):
            take_pos = 1 if pi < len(pos) else 0
            avail_neg = max(0, len(neg) - ni)
            take_neg = min(self.bs - take_pos, avail_neg)
            need = self.bs - (take_pos + take_neg)
            extra_pos = min(need, max(0, len(pos) - (pi + take_pos)))
            take_pos += extra_pos
            batch = np.concatenate([pos[pi:pi+take_pos], neg[ni:ni+take_neg]])
            pi += take_pos; ni += take_neg
            if batch.size == 0: break
            self.rng.shuffle(batch)
            yield batch.tolist()
    def __len__(self):
        return math.ceil((len(self.pos_idx) + len(self.neg_idx)) / self.bs)

def make_optimizer_groups(model, lr, wd, apply_final_wd=True):
    # apply_final_wd now defaults to True
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
    last_linear = linears[-1] if len(linears) > 0 else None
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad: continue
        if name.endswith('bias'):
            no_decay.append(p); continue
        if (last_linear is not None) and (p is last_linear.weight) and not apply_final_wd:
            no_decay.append(p); continue
        decay.append(p)
    return optim.AdamW(
        [{'params': decay, 'weight_decay': float(wd)},
         {'params': no_decay, 'weight_decay': 0.0}],
        lr=float(lr)
    )

def set_dropout_p(model, p):
    for m in model.modules():
        if isinstance(m, nn.Dropout): m.p = float(p)

def set_weight_decay(optimizer, wd):
    for g in optimizer.param_groups: g['weight_decay'] = float(wd)

def l1_penalty_first_layer(model):
    for m in model.modules():
        if isinstance(m, nn.Linear): return m.weight.abs().sum()
    return torch.tensor(0.0, device=next(model.parameters()).device)

# ---- NEW: grouped L1 that shrinks interaction columns more than mains
def l1_first_layer_grouped(model, inter_start_idx, lam_main=0.0, lam_inter=0.0):
    W = None
    for m in model.modules():
        if isinstance(m, nn.Linear):
            W = m.weight   # [hidden, in_dim]
            break
    if W is None or (lam_main == 0 and lam_inter == 0):
        return torch.tensor(0.0, device=next(model.parameters()).device)
    main_sum = W[:, :inter_start_idx].abs().sum()
    inter_sum = W[:, inter_start_idx:].abs().sum()
    return lam_main * main_sum + lam_inter * inter_sum

# ---- NEW: column-wise dropout (drop entire gene/interaction columns)
def column_dropout(x, p, col_start):
    """Drop whole columns from col_start onward with prob p (shared across batch)."""
    if p <= 0.0: return x
    keep = 1.0 - float(p)
    B, D = x.shape
    device = x.device
    if col_start >= D: return x
    # One Bernoulli per column, scale by 1/keep to keep expectation
    m = torch.bernoulli(torch.full((D - col_start,), keep, device=device)) / max(keep, 1e-6)
    x = x.clone()
    x[:, col_start:] = x[:, col_start:] * m
    return x

@torch.no_grad()
def evaluate_ci(model, dataloader, device):
    model.eval()
    preds, times, events = [], [], []
    for batch in dataloader:
        if len(batch) == 4:
            x, t, e, _ = batch
        else:
            x, t, e = batch
        y = torch.clamp(model(x.to(device)), -20, 20)
        preds.append(y.cpu().numpy().ravel())
        times.append(t.numpy()); events.append(e.numpy())
    preds = np.concatenate(preds)
    times = np.concatenate(times); events = np.concatenate(events)
    if np.isnan(preds).any(): return -np.inf
    return concordance_index_censored(events.astype(bool), times, preds)[0]

def evaluate_ci_grouped(model, X, t, e, group_mask):
    """Compute C-index within group_mask==1 and group_mask==0."""
    model.eval()
    with torch.no_grad():
        preds = model(torch.tensor(X, dtype=torch.float32, device=device)).cpu().numpy().ravel()
    res = {}
    for label, mask in [("ACT=1", group_mask.astype(bool)), ("ACT=0", ~group_mask.astype(bool))]:
        if mask.sum() >= 3:
            ci = concordance_index_censored(e[mask].astype(bool), t[mask], preds[mask])[0]
            res[label] = float(ci)
        else:
            res[label] = np.nan
    return res

# ============================================================
# Data loading & preprocessing
# ============================================================
TRAIN_CSV = "/content/drive/MyDrive/affyfRMATrain.csv"
VALID_CSV = "/content/drive/MyDrive/affyfRMAValidation.csv"
TEST_CSV  = "/content/drive/MyDrive/affyfRMATest.csv"

GENES_CSV = "/mnt/data/Genes.csv"
if not os.path.exists(GENES_CSV):
    if os.path.exists("/content/Genes.csv"):
        GENES_CSV = "/content/Genes.csv"
    elif os.path.exists("/content/drive/MyDrive/Genes.csv"):
        GENES_CSV = "/content/drive/MyDrive/Genes.csv"
print("Genes.csv path:", GENES_CSV)

CLINICAL_VARS = [
    "Adjuvant Chemo","Age","IS_MALE",
    "Stage_IA","Stage_IB","Stage_II","Stage_III",
    "Histology_Adenocarcinoma","Histology_Large Cell Carcinoma","Histology_Squamous Cell Carcinoma",
    "Race_African American","Race_Asian","Race_Caucasian","Race_Native Hawaiian or Other Pacific Islander","Race_Unknown",
    "Smoked?_No","Smoked?_Unknown","Smoked?_Yes"
]

def load_genes_list(genes_csv):
    g = pd.read_csv(genes_csv)
    if "Prop" not in g.columns or "Gene" not in g.columns:
        raise ValueError(f"Genes.csv must have columns 'Gene' and 'Prop'. Found: {list(g.columns)}")
    g["Prop"] = pd.to_numeric(g["Prop"], errors="coerce").fillna(0)
    genes = g.loc[g["Prop"] == 1, "Gene"].astype(str).tolist()
    print(f"[Genes] Selected {len(genes)} genes with Prop == 1")
    return genes

def coerce_survival_cols(df):
    if df["OS_STATUS"].dtype == object:
        df["OS_STATUS"] = df["OS_STATUS"].replace({"DECEASED":1,"LIVING":0,"Dead":1,"Alive":0}).astype(int)
    else:
        df["OS_STATUS"] = pd.to_numeric(df["OS_STATUS"], errors="coerce").fillna(0).astype(int)
    df["OS_MONTHS"] = pd.to_numeric(df["OS_MONTHS"], errors="coerce").fillna(0.0).astype(float)
    return df

def preprocess_split(df, clinical_vars, gene_names):
    if "Adjuvant Chemo" in df.columns:
        df["Adjuvant Chemo"] = df["Adjuvant Chemo"].replace({"OBS":0, "ACT":1})
    for col in ["Adjuvant Chemo","IS_MALE"]:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0).astype(int)
    df = coerce_survival_cols(df)
    keep_cols = [c for c in clinical_vars if c in df.columns] + [g for g in gene_names if g in df.columns]
    missing_clin = [c for c in clinical_vars if c not in df.columns]
    if missing_clin:
        print(f"[WARN] Missing clinical columns: {missing_clin}")
    cols = ["OS_STATUS","OS_MONTHS"] + keep_cols
    return df[cols].copy()

# Load CSVs
train_raw = pd.read_csv(TRAIN_CSV)
valid_raw = pd.read_csv(VALID_CSV)
test_raw  = pd.read_csv(TEST_CSV)

GENE_LIST = load_genes_list(GENES_CSV)

# Reduce to requested columns on each split
train_df = preprocess_split(train_raw, CLINICAL_VARS, GENE_LIST)
valid_df = preprocess_split(valid_raw, CLINICAL_VARS, GENE_LIST)
test_df  = preprocess_split(test_raw,  CLINICAL_VARS, GENE_LIST)

# Intersect features that exist everywhere
feat_candidates = [c for c in (CLINICAL_VARS + GENE_LIST)
                   if c in train_df.columns and c in valid_df.columns and c in test_df.columns]
CLIN_FEATS = [c for c in CLINICAL_VARS if c in feat_candidates]
GENE_FEATS = [g for g in GENE_LIST if g in feat_candidates]
CLIN_FEATS_PRETX = [c for c in CLIN_FEATS if c != "Adjuvant Chemo"]  # PS excludes treatment itself
print(f"[Features] Using {len(feat_candidates)} common features → Clinical={len(CLIN_FEATS)}, Genes={len(GENE_FEATS)}")

# Sort Train/Val by event time & status (desc)
train_df = train_df.sort_values(by=["OS_MONTHS","OS_STATUS"], ascending=[False, False]).reset_index(drop=True)
valid_df = valid_df.sort_values(by=["OS_MONTHS","OS_STATUS"], ascending=[False, False]).reset_index(drop=True)

# ---- Train-only gene ranking (univariate Cox CI on TRAIN) ----
def rank_genes_univariate(train_df, gene_cols):
    y = Surv.from_arrays(event=train_df["OS_STATUS"].astype(bool).values,
                         time=train_df["OS_MONTHS"].values.astype(float))
    ranks = []
    for g in gene_cols:
        Xg = train_df[[g]].to_numpy(dtype=np.float32)
        try:
            model = CoxPHSurvivalAnalysis(alpha=1e-12)
            model.fit(Xg, y)
            pred = model.predict(Xg)
            ci = concordance_index_censored(y["event"], y["time"], pred)[0]
            ranks.append((g, float(ci)))
        except Exception:
            ranks.append((g, 0.5))
    ranks.sort(key=lambda z: z[1], reverse=True)
    return [g for g, _ in ranks]

GENE_RANK = rank_genes_univariate(train_df, GENE_FEATS)
MAX_GENES = len(GENE_RANK)
print(f"[Gene Ranking] Ranked {MAX_GENES} genes on TRAIN")

# ============================================================
# Feature construction (main effects + interactions) & IPTW
# ============================================================
def build_features_with_interactions(df, main_genes, inter_genes, act_col="Adjuvant Chemo"):
    base_cols = CLIN_FEATS + list(main_genes)  # keep ACT main effect via CLIN_FEATS
    X_base = df[base_cols].to_numpy(dtype=np.float32)
    A = df[act_col].to_numpy(dtype=np.float32).reshape(-1, 1)
    if len(inter_genes) > 0:
        X_int = df[list(inter_genes)].to_numpy(dtype=np.float32) * A
        X = np.concatenate([X_base, X_int], axis=1)
        names = base_cols + [f"{g}*ACT" for g in inter_genes]
    else:
        X = X_base
        names = base_cols
    return X, names

def compute_iptw(df, covariate_cols, act_col="Adjuvant Chemo",
                 ps_clip=(0.05, 0.95), w_clip=(0.1, 10.0),
                 ref_prevalence=None, model=None):
    A = df[act_col].astype(int).values
    X = df[covariate_cols].astype(float).values
    if model is None:
        model = LogisticRegression(max_iter=2000, solver="lbfgs", class_weight="balanced")
        model.fit(X, A)
    ps = model.predict_proba(X)[:, 1]
    ps = np.clip(ps, ps_clip[0], ps_clip[1])
    if ref_prevalence is None:
        ref_prevalence = A.mean()
    w = np.where(A == 1, ref_prevalence / ps, (1 - ref_prevalence) / (1 - ps))
    w = np.clip(w, w_clip[0], w_clip[1])
    return w.astype(np.float32), model, float(ref_prevalence)

# ============================================================
# Training helpers (with input dropout/noise, weights, warm-ups)
# ============================================================
def _apply_input_dropout(x, p):
    if p <= 0.0: return x
    keep = 1.0 - p
    mask = torch.bernoulli(torch.full_like(x, keep))
    return x * mask / max(keep, 1e-6)

def train_one_epoch_reg(model, optimizer, dataloader, device,
                        l1_lambda=0.0, epoch=0, warmup_epochs=20,
                        input_dropout=0.0, noise_std=0.0, grad_clip=5.0,
                        col_dropout_p=0.0, col_start=0,
                        inter_start_idx=None, lam_main=0.0, lam_inter=0.0):
    model.train()
    warm = min(1.0, (epoch + 1) / float(warmup_epochs))
    loss_sum, w_sum = 0.0, 0.0
    for x, t, e, w in dataloader:
        if e.sum().item() == 0:
            continue
        x, t, e, w = x.to(device), t.to(device), e.to(device), w.to(device)
        if input_dropout > 0.0: x = _apply_input_dropout(x, input_dropout)
        if col_dropout_p > 0.0 and col_start > 0: x = column_dropout(x, col_dropout_p, col_start)
        if noise_std > 0.0: x = x + noise_std * torch.randn_like(x)
        optimizer.zero_grad(set_to_none=True)
        out = torch.clamp(model(x), -20, 20)
        loss = cox_loss(out, e, t, weight=w)
        if l1_lambda > 0:
            if inter_start_idx is None:
                loss = loss + (l1_lambda * warm) * l1_penalty_first_layer(model)
            else:
                loss = loss + (l1_lambda * warm) * l1_first_layer_grouped(model, inter_start_idx,
                                                                          lam_main=lam_main, lam_inter=lam_inter)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), float(grad_clip))
        optimizer.step()
        loss_sum += loss.item() * float(w.sum().item())
        w_sum += float(w.sum().item())
    return {'avg_loss': (loss_sum / max(w_sum, 1e-9)), 'warm': warm}

def full_risk_set_step_reg(model, optimizer, ds, device,
                           l1_lambda=0.0, warm=1.0,
                           input_dropout=0.0, noise_std=0.0, grad_clip=5.0,
                           col_dropout_p=0.0, col_start=0,
                           inter_start_idx=None, lam_main=0.0, lam_inter=0.0):
    model.train()
    X_all = ds.x.to(device); t_all = ds.time.to(device); e_all = ds.event.to(device); w_all = ds.w.to(device)
    XX = X_all
    if input_dropout > 0.0: XX = _apply_input_dropout(XX, input_dropout)
    if col_dropout_p > 0.0 and col_start > 0: XX = column_dropout(XX, col_dropout_p, col_start)
    if noise_std > 0.0: XX = XX + noise_std * torch.randn_like(XX)
    optimizer.zero_grad(set_to_none=True)
    out_all = torch.clamp(model(XX), -20, 20)
    loss_full = cox_loss(out_all, e_all, t_all, weight=w_all)
    if l1_lambda > 0:
        if inter_start_idx is None:
            loss_full = loss_full + (l1_lambda * warm) * l1_penalty_first_layer(model)
        else:
            loss_full = loss_full + (l1_lambda * warm) * l1_first_layer_grouped(model, inter_start_idx,
                                                                                lam_main=lam_main, lam_inter=lam_inter)
    loss_full.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), float(grad_clip))
    optimizer.step()
    return float(loss_full.detach().cpu().item())

def count_params(in_dim, layers):
    params, d = 0, in_dim
    for h in layers:
        params += d*h + h
        d = h
    params += d*1 + 1
    return int(params)

# ============================================================
# Optuna: Multi-Objective (Val CI ↑, Train–Val GAP ↓, Params ↓)
# + Budgets (feature + params), conservative Pareto selection
# ============================================================
# Smaller, conservative net choices (you can broaden later if needed)
ARCH_CHOICES = ("16","32","64","32-16","64-32","64-64")
def layers_from_arch(arch_str: str):
    return [int(x) for x in arch_str.split("-") if x.strip()]

def suggest_hparams(trial):
    # Static spaces; budgets will clamp per-trial
    base_main  = [16, 32, 64, 96, 128, 192, 256, 384, 512, 800, MAX_GENES]
    TOPK_MAIN_CHOICES = tuple(sorted({k for k in base_main if k <= MAX_GENES}))
    top_k_genes = trial.suggest_categorical("top_k_genes", TOPK_MAIN_CHOICES)

    base_inter = [0, 8, 16, 32, 64, 96, 128, 256, 512]
    TOPK_INTER_CHOICES = tuple(sorted({k for k in base_inter if k <= MAX_GENES}))
    top_k_inter_raw = trial.suggest_categorical("top_k_inter", TOPK_INTER_CHOICES)

    arch = trial.suggest_categorical("arch", ARCH_CHOICES)
    dropout = trial.suggest_float("dropout", 0.10, 0.50)
    input_dropout = trial.suggest_float("input_dropout", 0.00, 0.15)  # keep small; use column-dropout primarily
    noise_std = trial.suggest_float("noise_std", 0.0, 0.08)

    wd = trial.suggest_float("wd", 1e-6, 3e-3, log=True)
    use_l1 = trial.suggest_categorical("use_l1", (0, 1))
    l1 = 0.0 if use_l1 == 0 else trial.suggest_float("l1", 1e-7, 1e-3, log=True)

    lr = trial.suggest_float("lr", 1e-5, 4e-4, log=True)
    sched = trial.suggest_categorical("sched", ("cosine", "cawr", "none"))
    if sched == "cawr":
        cawr_T0 = trial.suggest_int("cawr_T0", 16, 72, step=8)
        cawr_Tmult = trial.suggest_categorical("cawr_Tmult", (1, 2, 3))
    else:
        cawr_T0 = cawr_Tmult = None

    epochs = trial.suggest_int("epochs", 96, 256, step=32)
    batch_size = trial.suggest_categorical("batch_size", (32, 64, 128))
    grad_clip = trial.suggest_float("grad_clip", 2.0, 9.0)

    # Clamp interactions to <= main set (static space; enforced again post-hoc)
    top_k_inter = int(min(top_k_inter_raw, top_k_genes))

    hp = dict(
        arch=arch, top_k_genes=int(top_k_genes), top_k_inter=int(top_k_inter),
        dropout=dropout, input_dropout=input_dropout, noise_std=noise_std,
        wd=wd, apply_final_wd=1,  # <-- force WD on last layer
        use_l1=int(use_l1), l1=float(l1),
        lr=lr, sched=sched, cawr_T0=cawr_T0, cawr_Tmult=cawr_Tmult,
        epochs=int(epochs), batch_size=int(batch_size), grad_clip=float(grad_clip)
    )
    return hp

# Warm-ups and caps
MAX_EPOCHS_CAP = 256
WARMUP_EPOCHS_L1 = 20
WARMUP_EPOCHS_DROPOUT = 20
WARMUP_EPOCHS_WD = 20
DROPOUT_START = 0.10
WD_START = 0.0

# ====== NEW: capacity budgets tied to sample size ======
N_EVENTS_TR = int(train_df["OS_STATUS"].sum())
FEAT_EVENT_FRACTION = 0.50         # tune 0.35–0.60
FEAT_BUDGET = max(24, int(FEAT_EVENT_FRACTION * N_EVENTS_TR))   # total inputs incl. clinical
PARAM_BUDGET = 120_000             # cap total parameters; tune 80k–150k
print(f"[Budgets] events(train)={N_EVENTS_TR} → feature budget≤{FEAT_BUDGET} total inputs, param budget≤{PARAM_BUDGET:,}")

# Column-dropout prob (genes + interactions). Keep moderate.
COL_DROPOUT_P = 0.30
# Grouped L1 strength split (interactions shrunk more than mains)
L1_LAM_MAIN = 0.2
L1_LAM_INTER = 1.0

def objective(trial):
    hp = suggest_hparams(trial)
    layers = layers_from_arch(hp["arch"])

    # --- Trial-specific features (main + interactions) with budgets
    MAX_NONCLIN = max(8, FEAT_BUDGET - len(CLIN_FEATS))
    k_main = int(min(hp["top_k_genes"], MAX_NONCLIN))
    k_int  = int(min(hp["top_k_inter"], k_main, MAX_NONCLIN - k_main))

    genes_main  = GENE_RANK[:k_main]
    genes_inter = genes_main[:k_int]

    Xtr_raw, feat_names = build_features_with_interactions(train_df, genes_main, genes_inter, act_col="Adjuvant Chemo")
    Xva_raw, _          = build_features_with_interactions(valid_df, genes_main, genes_inter, act_col="Adjuvant Chemo")

    # If params exceed cap, shrink interactions first, then mains
    in_dim_trial = Xtr_raw.shape[1]
    while count_params(in_dim_trial, layers) > PARAM_BUDGET and (k_int > 0 or k_main > 8):
        if k_int > 0:
            k_int = max(0, k_int // 2)
        else:
            k_main = max(8, int(k_main * 0.8))
        genes_main  = GENE_RANK[:k_main]
        genes_inter = genes_main[:min(k_int, k_main, MAX_NONCLIN - k_main)]
        Xtr_raw, feat_names = build_features_with_interactions(train_df, genes_main, genes_inter, act_col="Adjuvant Chemo")
        Xva_raw, _          = build_features_with_interactions(valid_df, genes_main, genes_inter, act_col="Adjuvant Chemo")
        in_dim_trial = Xtr_raw.shape[1]

    # --- TRAIN-only impute & standardize
    med = np.nanmedian(Xtr_raw, axis=0)
    Xtr = np.where(np.isnan(Xtr_raw), med, Xtr_raw)
    Xva = np.where(np.isnan(Xva_raw), med, Xva_raw)
    sc = StandardScaler().fit(Xtr)
    Xtr = sc.transform(Xtr).astype(np.float32)
    Xva = sc.transform(Xva).astype(np.float32)

    ytr_t = train_df["OS_MONTHS"].to_numpy(np.float32)
    ytr_e = train_df["OS_STATUS"].to_numpy(int)
    yva_t = valid_df["OS_MONTHS"].to_numpy(np.float32)
    yva_e = valid_df["OS_STATUS"].to_numpy(int)

    # --- IPTW on clinical pre-treatment covariates ONLY (exclude ACT)
    w_tr, ps_model, pi_tr = compute_iptw(train_df, covariate_cols=CLIN_FEATS_PRETX, act_col="Adjuvant Chemo")
    w_va, _, _ = compute_iptw(valid_df, covariate_cols=CLIN_FEATS_PRETX, act_col="Adjuvant Chemo",
                              ref_prevalence=pi_tr, model=ps_model)

    # --- Datasets/loaders
    bs = hp["batch_size"]
    tr_ds = SurvivalDataset(Xtr, ytr_t, ytr_e, weights=w_tr)
    va_ds = SurvivalDataset(Xva, yva_t, yva_e, weights=w_va)
    tr_sampler = EventBalancedBatchSampler(ytr_e, bs, seed=42)
    tr_loader = DataLoader(tr_ds, batch_sampler=tr_sampler, num_workers=0)
    tr_eval_loader = DataLoader(tr_ds, batch_size=bs, shuffle=False, num_workers=0)
    va_loader = DataLoader(va_ds, batch_size=bs, shuffle=False, num_workers=0)

    # --- Model/opt/sched
    in_dim_trial = Xtr.shape[1]
    model = DeepSurvMLP(in_dim_trial, layers, dropout=hp["dropout"]).to(device)
    opt = make_optimizer_groups(model, lr=hp["lr"], wd=hp["wd"], apply_final_wd=True)

    epochs = int(min(hp["epochs"], MAX_EPOCHS_CAP))
    if hp["sched"] == "cosine":
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
        def sched_step(i): sched.step()
    elif hp["sched"] == "cawr":
        sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            opt, T_0=int(hp["cawr_T0"]), T_mult=int(hp["cawr_Tmult"]))
        def sched_step(i): sched.step(i+1)
    else:
        sched = None
        def sched_step(i): pass

    # --- Early-stopping on Val CI (tighter patience)
    PATIENCE = 16
    MIN_DELTA = 1e-4
    no_improve = 0
    best_val_ci = -np.inf
    best_tr_ci_at_best = float("nan")
    best_epoch = 0

    # indices for regularizers/dropout
    col_start = len(CLIN_FEATS)
    inter_start_idx = len(CLIN_FEATS) + k_main  # main genes end here; interactions start here

    for ep in range(epochs):
        # warm-ups
        frac_d = min(1.0, ep / float(WARMUP_EPOCHS_DROPOUT))
        frac_w = min(1.0, ep / float(WARMUP_EPOCHS_WD))
        set_dropout_p(model, DROPOUT_START + (hp['dropout'] - DROPOUT_START) * frac_d)
        set_weight_decay(opt, WD_START + (hp['wd'] - WD_START) * frac_w)

        st = train_one_epoch_reg(
            model, opt, tr_loader, device,
            l1_lambda=float(hp.get("l1", 0.0)), epoch=ep, warmup_epochs=WARMUP_EPOCHS_L1,
            input_dropout=float(hp.get("input_dropout", 0.0)),
            noise_std=float(hp.get("noise_std", 0.0)),
            grad_clip=float(hp.get("grad_clip", 5.0)),
            col_dropout_p=COL_DROPOUT_P, col_start=col_start,
            inter_start_idx=inter_start_idx if hp.get("use_l1", 0)==1 else None,
            lam_main=L1_LAM_MAIN, lam_inter=L1_LAM_INTER
        )
        _ = full_risk_set_step_reg(
            model, opt, tr_ds, device,
            l1_lambda=float(hp.get("l1", 0.0)), warm=st['warm'],
            input_dropout=float(hp.get("input_dropout", 0.0)),
            noise_std=float(hp.get("noise_std", 0.0)),
            grad_clip=float(hp.get("grad_clip", 5.0)),
            col_dropout_p=COL_DROPOUT_P, col_start=col_start,
            inter_start_idx=inter_start_idx if hp.get("use_l1", 0)==1 else None,
            lam_main=L1_LAM_MAIN, lam_inter=L1_LAM_INTER
        )
        sched_step(ep)

        # eval both splits so we can compute the gap at the best Val epoch
        va_ci = evaluate_ci(model, va_loader, device)
        tr_ci = evaluate_ci(model, tr_eval_loader, device)

        if va_ci > best_val_ci + MIN_DELTA:
            best_val_ci = va_ci
            best_tr_ci_at_best = tr_ci
            best_epoch = ep + 1
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= PATIENCE:
                break

    gap = max(0.0, best_tr_ci_at_best - best_val_ci)
    param_cnt = count_params(in_dim_trial, layers)

    # annotate & cleanup
    trial.set_user_attr("best_epoch", int(best_epoch))
    trial.set_user_attr("n_features", int(in_dim_trial))
    trial.set_user_attr("k_main", int(k_main))
    trial.set_user_attr("k_int", int(k_int))
    del model, opt, sched
    if torch.cuda.is_available(): torch.cuda.empty_cache()
    gc.collect()

    # Multi-objective return: (Val CI ↑, Gap ↓, Params ↓)
    return float(best_val_ci), float(gap), int(param_cnt)

# ---- Optuna Study: NSGA-II (no pruner for multi-objective) ----
storage = "sqlite:///deepsurv_optuna.db"
study_name = f"deepsurv_cox_mo_gap_size_interactions_bounded_M{MAX_GENES}"

sampler = NSGAIISampler(seed=42, population_size=24)

study = optuna.create_study(
    directions=["maximize", "minimize", "minimize"],
    study_name=study_name,
    storage=storage,
    load_if_exists=True,
    sampler=sampler
)

# ---- Launch Optuna Dashboard (proxied URL printed) ----
try:
    from optuna_dashboard import run_server
    PORT = portpicker.pick_unused_port()
    def _start_dashboard():
        run_server(storage, host="0.0.0.0", port=PORT)
    t = threading.Thread(target=_start_dashboard, daemon=True)
    t.start(); time.sleep(2)
    dash_url = output.eval_js(f"google.colab.kernel.proxyPort({PORT}, {{'cache': false}})")
    print("Optuna Dashboard:", dash_url)
except Exception as ex:
    print("[Optuna Dashboard] Could not start dashboard automatically.", ex)
    print("Run locally:  optuna-dashboard sqlite:///deepsurv_optuna.db")

# ---- Optimize ----
N_TRIALS = 100
print(f"Starting multi-objective optimization: {N_TRIALS} trials")
study.optimize(objective, n_trials=N_TRIALS, gc_after_trial=True)

# ---- Choose a robust solution from Pareto front (conservative) ----
pareto = study.best_trials
best_val = max(tr.values[0] for tr in pareto)
TOL = 0.020  # within 2pp absolute C-index of best
cands = [tr for tr in pareto if (best_val - tr.values[0]) <= TOL and tr.values[2] <= PARAM_BUDGET]
if not cands:
    cands = [tr for tr in pareto if tr.values[2] <= PARAM_BUDGET] or pareto
cands.sort(key=lambda tr: (tr.values[1], tr.values[2]))  # gap, then params
chosen = cands[0]

print("\n[Chosen Pareto] Val CI=%.4f | Gap=%.4f | Params=%d" % (chosen.values[0], chosen.values[1], chosen.values[2]))
print("[Chosen Params]", chosen.params)
print("[Chosen Attrs] k_main=%s k_int=%s n_features=%s" %
      (str(chosen.user_attrs.get("k_main")), str(chosen.user_attrs.get("k_int")), str(chosen.user_attrs.get("n_features"))))

# ============================================================
# Final training on Train+Val with chosen hyperparams + interactions + IPTW
# (re-apply budgets and safe clamping)
# ============================================================
best_hp = chosen.params
# force WD on last layer
best_hp["apply_final_wd"] = 1

# Recompute k_main/k_int under budgets
MAX_NONCLIN = max(8, FEAT_BUDGET - len(CLIN_FEATS))
k_main = int(min(best_hp["top_k_genes"], MAX_NONCLIN))
k_int  = int(min(best_hp["top_k_inter"], k_main, MAX_NONCLIN - k_main))
genes_main  = GENE_RANK[:k_main]
genes_inter = genes_main[:k_int]

# Assemble Train+Val & Test
trainval_df = pd.concat([train_df, valid_df], axis=0, ignore_index=True)
trainval_df = trainval_df.sort_values(by=["OS_MONTHS","OS_STATUS"], ascending=[False, False]).reset_index(drop=True)

X_trv_raw, feat_names = build_features_with_interactions(trainval_df, genes_main, genes_inter, act_col="Adjuvant Chemo")
X_te_raw,  _          = build_features_with_interactions(test_df,      genes_main, genes_inter, act_col="Adjuvant Chemo")

# If params exceed cap for the chosen arch, shrink further (interactions first)
in_dim_final = X_trv_raw.shape[1]
best_layers = layers_from_arch(best_hp["arch"])
while count_params(in_dim_final, best_layers) > PARAM_BUDGET and (k_int > 0 or k_main > 8):
    if k_int > 0:
        k_int = max(0, k_int // 2)
    else:
        k_main = max(8, int(k_main * 0.8))
    genes_main  = GENE_RANK[:k_main]
    genes_inter = genes_main[:min(k_int, k_main, MAX_NONCLIN - k_main)]
    X_trv_raw, feat_names = build_features_with_interactions(trainval_df, genes_main, genes_inter, act_col="Adjuvant Chemo")
    X_te_raw,  _          = build_features_with_interactions(test_df,      genes_main, genes_inter, act_col="Adjuvant Chemo")
    in_dim_final = X_trv_raw.shape[1]

print(f"[Final] Using features: {len(CLIN_FEATS)} clinical + {k_main} genes (main) + {k_int} interactions")

# Impute + standardize on Train+Val; apply to Test
med_trv = np.nanmedian(X_trv_raw, axis=0)
X_trv = np.where(np.isnan(X_trv_raw), med_trv, X_trv_raw)
X_te  = np.where(np.isnan(X_te_raw),  med_trv, X_te_raw)
sc_trv = StandardScaler().fit(X_trv)
X_trv = sc_trv.transform(X_trv).astype(np.float32)
X_te  = sc_trv.transform(X_te).astype(np.float32)

y_trv_t = trainval_df["OS_MONTHS"].to_numpy(np.float32)
y_trv_e = trainval_df["OS_STATUS"].to_numpy(int)
y_te_t  = test_df["OS_MONTHS"].to_numpy(np.float32)
y_te_e  = test_df["OS_STATUS"].to_numpy(int)

# IPTW on Train+Val, apply to Test with same prevalence
w_trv, ps_model_fin, pi_fin = compute_iptw(trainval_df, covariate_cols=CLIN_FEATS_PRETX, act_col="Adjuvant Chemo")
w_te, _, _ = compute_iptw(test_df, covariate_cols=CLIN_FEATS_PRETX, act_col="Adjuvant Chemo",
                          ref_prevalence=pi_fin, model=ps_model_fin)

# Loaders
bs_fin = int(best_hp["batch_size"])
ds_trv = SurvivalDataset(X_trv, y_trv_t, y_trv_e, weights=w_trv)
ds_te  = SurvivalDataset(X_te,  y_te_t,  y_te_e,  weights=w_te)
sam_trv = EventBalancedBatchSampler(y_trv_e, bs_fin, seed=7)
dl_trv  = DataLoader(ds_trv, batch_sampler=sam_trv, num_workers=0)
dl_trv_eval = DataLoader(ds_trv, batch_size=bs_fin, shuffle=False, num_workers=0)
dl_te   = DataLoader(ds_te,  batch_size=bs_fin, shuffle=False, num_workers=0)

# Model / optimizer / scheduler
model_final = DeepSurvMLP(in_dim_final, best_layers, dropout=float(best_hp["dropout"])).to(device)
opt_final = make_optimizer_groups(model_final, lr=float(best_hp["lr"]), wd=float(best_hp["wd"]),
                                  apply_final_wd=True)
epochs_fin = int(min(best_hp["epochs"], MAX_EPOCHS_CAP))
if best_hp["sched"] == "cosine":
    sched_final = torch.optim.lr_scheduler.CosineAnnealingLR(opt_final, T_max=epochs_fin)
    def sched_step(i): sched_final.step()
elif best_hp["sched"] == "cawr":
    sched_final = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        opt_final, T_0=int(best_hp["cawr_T0"]), T_mult=int(best_hp["cawr_Tmult"]))
    def sched_step(i): sched_final.step(i+1)
else:
    sched_final = None
    def sched_step(i): pass

# Indices for regularizers/dropout in final training
col_start_final = len(CLIN_FEATS)
inter_start_idx_final = len(CLIN_FEATS) + k_main

# Train
PATIENCE_FIN = 16
MIN_DELTA = 1e-4
no_improve = 0
best_trv_ci = -np.inf

for ep in range(epochs_fin):
    frac_d = min(1.0, ep / float(WARMUP_EPOCHS_DROPOUT))
    frac_w = min(1.0, ep / float(WARMUP_EPOCHS_WD))
    set_dropout_p(model_final, DROPOUT_START + (float(best_hp['dropout']) - DROPOUT_START) * frac_d)
    set_weight_decay(opt_final, WD_START + (float(best_hp['wd']) - WD_START) * frac_w)

    st = train_one_epoch_reg(
        model_final, opt_final, dl_trv, device,
        l1_lambda=float(best_hp.get("l1", 0.0)), epoch=ep, warmup_epochs=WARMUP_EPOCHS_L1,
        input_dropout=float(best_hp.get("input_dropout", 0.0)),
        noise_std=float(best_hp.get("noise_std", 0.0)),
        grad_clip=float(best_hp.get("grad_clip", 5.0)),
        col_dropout_p=COL_DROPOUT_P, col_start=col_start_final,
        inter_start_idx=inter_start_idx_final if best_hp.get("use_l1", 0)==1 else None,
        lam_main=L1_LAM_MAIN, lam_inter=L1_LAM_INTER
    )
    _ = full_risk_set_step_reg(
        model_final, opt_final, ds_trv, device,
        l1_lambda=float(best_hp.get("l1", 0.0)), warm=st['warm'],
        input_dropout=float(best_hp.get("input_dropout", 0.0)),
        noise_std=float(best_hp.get("noise_std", 0.0)),
        grad_clip=float(best_hp.get("grad_clip", 5.0)),
        col_dropout_p=COL_DROPOUT_P, col_start=col_start_final,
        inter_start_idx=inter_start_idx_final if best_hp.get("use_l1", 0)==1 else None,
        lam_main=L1_LAM_MAIN, lam_inter=L1_LAM_INTER
    )
    sched_step(ep)

    # mild ES on Train+Val CI to avoid late overfit
    trv_ci_now = evaluate_ci(model_final, dl_trv_eval, device)
    if trv_ci_now > best_trv_ci + MIN_DELTA:
        best_trv_ci = trv_ci_now
        no_improve = 0
    else:
        no_improve += 1
        if no_improve >= PATIENCE_FIN:
            break

# Evaluate
trainval_ci = evaluate_ci(model_final, dl_trv_eval, device)
test_ci = evaluate_ci(model_final, dl_te, device)
print(f"\n[Final] Train+Val CI: {trainval_ci:.4f}")
print(f"[Final] Test CI:      {test_ci:.4f}")

# Per-arm C-indices (sanity)
act_trv = trainval_df["Adjuvant Chemo"].to_numpy(int)
act_te  = test_df["Adjuvant Chemo"].to_numpy(int)
trv_grouped = evaluate_ci_grouped(model_final, X_trv, y_trv_t, y_trv_e, act_trv == 1)
te_grouped  = evaluate_ci_grouped(model_final, X_te,  y_te_t,  y_te_e,  act_te == 1)
print("[Train+Val] CI by arm:", trv_grouped)
print("[Test]      CI by arm:", te_grouped)

# Save artifacts
OUT_DIR = "/content/drive/MyDrive/deepsurv_results_optuna_interactions_iptw_bounded"
os.makedirs(OUT_DIR, exist_ok=True)
torch.save(model_final.state_dict(), os.path.join(OUT_DIR, "deepsurv_best.pt"))
with open(os.path.join(OUT_DIR, "chosen_params.txt"), "w") as f:
    f.write(str(best_hp))
with open(os.path.join(OUT_DIR, "features_used.txt"), "w") as f:
    f.write("\n".join(feat_names))
print("Saved final model and parameters to:", OUT_DIR)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Device: cuda
Genes.csv path: /content/drive/MyDrive/Genes.csv
[Genes] Selected 1555 genes with Prop == 1
[Features] Using 1573 common features → Clinical=18, Genes=1555


  df["Adjuvant Chemo"] = df["Adjuvant Chemo"].replace({"OBS":0, "ACT":1})
  df["Adjuvant Chemo"] = df["Adjuvant Chemo"].replace({"OBS":0, "ACT":1})
  df["Adjuvant Chemo"] = df["Adjuvant Chemo"].replace({"OBS":0, "ACT":1})


[Gene Ranking] Ranked 1555 genes on TRAIN
[Budgets] events(train)=348 → feature budget≤174 total inputs, param budget≤120,000


[I 2025-10-20 00:44:13,962] A new study created in RDB with name: deepsurv_cox_mo_gap_size_interactions_bounded_M1555
Bottle v0.13.4 server starting up (using WSGIRefServer())...
Listening on http://0.0.0.0:43147/
Hit Ctrl-C to quit.



Optuna Dashboard: https://43147-gpu-a100-s-ux6s319d4nv1-f.us-central1-1.prod.colab.dev
Starting multi-objective optimization: 100 trials


[I 2025-10-20 00:45:26,615] Trial 0 finished with values: [0.47919225666977844, 0.021841253440887443, 7489.0] and parameters: {'top_k_genes': 32, 'top_k_inter': 0, 'arch': '64-64', 'dropout': 0.1798695128633439, 'input_dropout': 0.07713516576204174, 'noise_std': 0.0473931655089634, 'wd': 1.4504865877614253e-06, 'use_l1': 0, 'lr': 1.2712078160994458e-05, 'sched': 'cawr', 'cawr_T0': 32, 'cawr_Tmult': 2, 'epochs': 96, 'batch_size': 128, 'grad_clip': 3.8114598712001184}.
[I 2025-10-20 00:45:46,445] Trial 1 finished with values: [0.6058074976661917, 0.027515306609713908, 13313.0] and parameters: {'top_k_genes': 192, 'top_k_inter': 128, 'arch': '64-32', 'dropout': 0.17948627261366898, 'input_dropout': 0.0008283175685403598, 'noise_std': 0.06523691427638674, 'wd': 0.00028696484378591143, 'use_l1': 1, 'l1': 1.9777828512462694e-07, 'lr': 3.7521794620277125e-05, 'sched': 'cawr', 'cawr_T0': 32, 'cawr_Tmult': 3, 'epochs': 224, 'batch_size': 64, 'grad_clip': 2.8371597215681117}.
[I 2025-10-20 00:45


[Chosen Pareto] Val CI=0.6499 | Gap=0.0000 | Params=1345
[Chosen Params] {'top_k_genes': 32, 'top_k_inter': 128, 'arch': '16', 'dropout': 0.24238907146050465, 'input_dropout': 0.11367691656965537, 'noise_std': 0.0011514790903804696, 'wd': 2.532786865696868e-06, 'use_l1': 0, 'lr': 0.00023469248087216023, 'sched': 'cosine', 'epochs': 160, 'batch_size': 32, 'grad_clip': 4.789533140781614}
[Chosen Attrs] k_main=32 k_int=32 n_features=82
[Final] Using features: 18 clinical + 32 genes (main) + 32 interactions

[Final] Train+Val CI: 0.6798
[Final] Test CI:      0.6027
[Train+Val] CI by arm: {'ACT=1': 0.5908468386213194, 'ACT=0': 0.705384074670558}
[Test]      CI by arm: {'ACT=1': 0.5933682373472949, 'ACT=0': 0.617439006785044}
Saved final model and parameters to: /content/drive/MyDrive/deepsurv_results_optuna_interactions_iptw_bounded
