<a href="https://colab.research.google.com/github/osun24/nsclc-adj-chemo/blob/main/TorchSurv_DeepSurv_with_Optuna.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Install necessary packages
!pip install torchsurv scikit-survival

# Import required packages
import os
import time
import datetime
import itertools
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sksurv.metrics import concordance_index_censored

# (Optional) Mount Google Drive if you plan to load/save files there
from google.colab import drive
drive.mount('/content/drive')

from optuna_dashboard import run_server
import threading, time, portpicker
from google.colab import output

PORT = portpicker.pick_unused_port()

def _serve():
    run_server("sqlite:///deepsurv_optuna.db", host="0.0.0.0", port=PORT)

threading.Thread(target=_serve, daemon=True).start()
time.sleep(2)
print("Dashboard:", output.eval_js(f"google.colab.kernel.proxyPort({PORT}, {{'cache': false}})"))


Collecting torchsurv
  Downloading torchsurv-0.1.5-py3-none-any.whl.metadata (15 kB)
Collecting torchmetrics (from torchsurv)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics->torchsurv)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading torchsurv-0.1.5-py3-none-any.whl (52 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.8/52.8 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m24.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Installing collected packages: lightning-utilities, torchmetrics, torchsurv
Successfully installed lightning-utilities-0.15.2 torchmetrics-1.8.2 torchsurv-0.1.5
Drive already mounted at /content/drive; to attempt to forcibly remount, cal

In [3]:
# ============================================================
# Colab-ready single cell: DeepSurv + Optuna HPO + Dashboard
# - Keeps ONLY requested clinical vars + genes with Prop==1
# - Sorts Train/Val by OS_MONTHS/OS_STATUS (desc)
# - Standardizes using TRAIN-only (applies to VAL); after HPO
#   restandardizes on TRAIN+VAL and evaluates TEST C-index
# - Optuna + Successive Halving pruner
# - Optuna Dashboard (proxied URL printed)  [fixed run_server args]
# - Encodes architectures as strings to avoid Optuna warning
# ============================================================

# ---------- Installs (Colab) ----------
!pip -q install optuna optuna-dashboard scikit-survival portpicker

# ---------- (Optional) Mount Google Drive ----------
from google.colab import drive
drive.mount('/content/drive')

# ---------- Imports ----------
import os, math, copy, warnings, random, gc, time, threading
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Sampler

# Cox loss: Prefer torchsurv if available; fallback to Efron implementation
try:
    from torchsurv.loss.cox import neg_partial_log_likelihood
    _HAS_TORCHSURV = True
except Exception:
    _HAS_TORCHSURV = False

from sklearn.preprocessing import StandardScaler
from sksurv.metrics import concordance_index_censored

import optuna
from optuna.pruners import SuccessiveHalvingPruner

warnings.filterwarnings("ignore", message="Ties in event time detected; using efron's method to handle ties.")

# Reproducibility
torch.manual_seed(0); np.random.seed(0); random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ============================================================
# Cox loss fallback (Efron) if torchsurv isn't available
# ============================================================
def _cox_negloglik_efron(pred, event, time):
    eta = pred.reshape(-1)
    e = event.to(torch.float32).reshape(-1)
    t = time.reshape(-1)

    order = torch.argsort(t, descending=True)
    t = t[order]; e = e[order]; eta = eta[order]
    exp_eta = torch.exp(eta)
    cum_exp = torch.cumsum(exp_eta, dim=0)

    uniq_mask = torch.ones_like(t, dtype=torch.bool)
    uniq_mask[1:] = t[1:] != t[:-1]
    idxs = torch.nonzero(uniq_mask, as_tuple=False).reshape(-1)
    idxs = torch.cat([idxs, torch.tensor([len(t)], device=t.device)])

    nll = torch.tensor(0.0, device=t.device)
    for k in range(len(idxs)-1):
        start, end = idxs[k].item(), idxs[k+1].item()
        e_slice = e[start:end]
        d = int(e_slice.sum().item())
        if d == 0: continue
        eta_events = eta[start:end][e_slice.bool()]
        exp_events = exp_eta[start:end][e_slice.bool()]
        s_eta = eta_events.sum()
        risk_sum = cum_exp[end-1]
        s_exp = exp_events.sum()
        eps = 1e-12
        log_terms = 0.0
        for j in range(d):
            log_terms = log_terms + torch.log(risk_sum - (j / d) * s_exp + eps)
        nll = nll - (s_eta - log_terms)
    return nll / t.numel()

def cox_negloglik(pred, event, time):
    if _HAS_TORCHSURV:
        return neg_partial_log_likelihood(pred, event, time, reduction='mean')
    return _cox_negloglik_efron(pred, event, time)

# ============================================================
# Model, Dataset, Sampler, Utilities
# ============================================================
class DeepSurvMLP(nn.Module):
    def __init__(self, in_features, hidden_layers, dropout=0.0, activation=nn.ReLU()):
        super().__init__()
        layers, d = [], in_features
        for units in hidden_layers:
            layers += [nn.Linear(d, units), activation]
            if dropout > 0: layers.append(nn.Dropout(dropout))
            d = units
        layers.append(nn.Linear(d, 1))
        self.model = nn.Sequential(*layers)
    def forward(self, x): return self.model(x)

class SurvivalDataset(Dataset):
    def __init__(self, features, time_vals, events):
        self.x = torch.tensor(features, dtype=torch.float32)
        self.time = torch.tensor(time_vals, dtype=torch.float32)
        self.event = torch.tensor(events.astype(bool), dtype=torch.bool)
    def __len__(self): return len(self.x)
    def __getitem__(self, idx): return self.x[idx], self.time[idx], self.event[idx]

class EventBalancedBatchSampler(Sampler):
    def __init__(self, events_numpy, batch_size, seed=0):
        events = np.asarray(events_numpy).astype(bool)
        self.pos_idx = np.where(events)[0]
        self.neg_idx = np.where(~events)[0]
        assert len(self.pos_idx) > 0, "No events in training set — cannot balance batches."
        self.bs = int(batch_size)
        self.rng = np.random.default_rng(seed)
    def __iter__(self):
        pos = self.rng.permutation(self.pos_idx)
        neg = self.rng.permutation(self.neg_idx)
        n_total = len(pos) + len(neg)
        n_batches = math.ceil(n_total / self.bs)
        pi = ni = 0
        for _ in range(n_batches):
            take_pos = 1 if pi < len(pos) else 0
            avail_neg = max(0, len(neg) - ni)
            take_neg = min(self.bs - take_pos, avail_neg)
            need = self.bs - (take_pos + take_neg)
            extra_pos = min(need, max(0, len(pos) - (pi + take_pos)))
            take_pos += extra_pos
            batch = np.concatenate([pos[pi:pi+take_pos], neg[ni:ni+take_neg]])
            pi += take_pos; ni += take_neg
            if batch.size == 0: break
            self.rng.shuffle(batch)
            yield batch.tolist()
    def __len__(self):
        return math.ceil((len(self.pos_idx) + len(self.neg_idx)) / self.bs)

def make_optimizer(model, lr, wd):
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
    last_linear = linears[-1] if len(linears) > 0 else None
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad: continue
        if name.endswith('bias'):
            no_decay.append(p); continue
        if (last_linear is not None) and (p is last_linear.weight):
            no_decay.append(p); continue
        decay.append(p)
    param_groups = [{'params': decay, 'weight_decay': wd},
                    {'params': no_decay, 'weight_decay': 0.0}]
    return optim.AdamW(param_groups, lr=lr)

def set_dropout_p(model, p):
    for m in model.modules():
        if isinstance(m, nn.Dropout): m.p = float(p)

def set_weight_decay(optimizer, wd):
    for g in optimizer.param_groups: g['weight_decay'] = float(wd)

def l1_penalty_first_layer(model):
    for m in model.modules():
        if isinstance(m, nn.Linear): return m.weight.abs().sum()
    return torch.tensor(0.0, device=next(model.parameters()).device)

@torch.no_grad()
def evaluate_ci(model, dataloader, device):
    model.eval()
    preds, times, events = [], [], []
    for x, t, e in dataloader:
        y = torch.clamp(model(x.to(device)), -20, 20)
        preds.append(y.cpu().numpy().ravel())
        times.append(t.numpy()); events.append(e.numpy())
    preds = np.concatenate(preds)
    times = np.concatenate(times); events = np.concatenate(events)
    if np.isnan(preds).any(): return -np.inf
    return concordance_index_censored(events.astype(bool), times, preds)[0]

def train_one_epoch(model, optimizer, dataloader, device, l1_lambda=0.0, epoch=0, warmup_epochs=20):
    model.train()
    warm = min(1.0, (epoch + 1) / float(warmup_epochs))
    loss_sum, n_seen = 0.0, 0
    for x, t, e in dataloader:
        if e.sum().item() == 0:  # safety (shouldn't happen with balanced sampler)
            continue
        x, t, e = x.to(device), t.to(device), e.to(device)
        optimizer.zero_grad(set_to_none=True)
        out = torch.clamp(model(x), -20, 20)
        loss = cox_negloglik(out, e, t)
        if l1_lambda > 0:
            loss = loss + (l1_lambda * warm) * l1_penalty_first_layer(model)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        loss_sum += loss.item() * x.size(0)
        n_seen += x.size(0)
    return {'avg_loss': loss_sum / max(n_seen, 1), 'warm': warm}

def full_risk_set_step(model, optimizer, ds, device, l1_lambda=0.0, warm=1.0):
    model.train()
    X_all = ds.x.to(device); t_all = ds.time.to(device); e_all = ds.event.to(device)
    optimizer.zero_grad(set_to_none=True)
    out_all = torch.clamp(model(X_all), -20, 20)
    loss_full = cox_negloglik(out_all, e_all, t_all)
    if l1_lambda > 0:
        loss_full = loss_full + (l1_lambda * warm) * l1_penalty_first_layer(model)
    loss_full.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
    optimizer.step()
    return float(loss_full.detach().cpu().item())

# ============================================================
# Data loading & preprocessing
# ============================================================
# Default CSV paths (adjust if needed)
TRAIN_CSV = "/content/drive/MyDrive/affyfRMATrain.csv"
VALID_CSV = "/content/drive/MyDrive/affyfRMAValidation.csv"
TEST_CSV  = "/content/drive/MyDrive/affyfRMATest.csv"

# Genes list path (uploaded or Drive)
GENES_CSV = "/mnt/data/Genes.csv"
if not os.path.exists(GENES_CSV):
    if os.path.exists("/content/Genes.csv"):
        GENES_CSV = "/content/Genes.csv"
    elif os.path.exists("/content/drive/MyDrive/Genes.csv"):
        GENES_CSV = "/content/drive/MyDrive/Genes.csv"
print("Genes.csv path:", GENES_CSV)

# Clinical variables to KEEP (exact names)
CLINICAL_VARS = [
    "Adjuvant Chemo","Age","IS_MALE",
    "Stage_IA","Stage_IB","Stage_II","Stage_III",
    "Histology_Adenocarcinoma","Histology_Large Cell Carcinoma","Histology_Squamous Cell Carcinoma",
    "Race_African American","Race_Asian","Race_Caucasian","Race_Native Hawaiian or Other Pacific Islander","Race_Unknown",
    "Smoked?_No","Smoked?_Unknown","Smoked?_Yes"
]

def load_genes_list(genes_csv):
    g = pd.read_csv(genes_csv)
    if not {"Gene","Prop"}.issubset(set(g.columns)):
        raise ValueError(f"Genes.csv must contain 'Gene' and 'Prop' columns. Found: {list(g.columns)}")
    genes = g.loc[g["Prop"] == 1, "Gene"].astype(str).tolist()
    print(f"[Genes] Selected {len(genes)} genes with Prop == 1")
    return genes

def coerce_survival_cols(df):
    # Map to integers {0,1}
    if df["OS_STATUS"].dtype == object:
        df["OS_STATUS"] = df["OS_STATUS"].replace({"DECEASED":1,"LIVING":0,"Dead":1,"Alive":0}).astype(int)
    else:
        df["OS_STATUS"] = pd.to_numeric(df["OS_STATUS"], errors="coerce").fillna(0).astype(int)
    df["OS_MONTHS"] = pd.to_numeric(df["OS_MONTHS"], errors="coerce").fillna(0.0).astype(float)
    return df

def preprocess_split(df, clinical_vars, gene_names):
    if "Adjuvant Chemo" in df.columns:
        df["Adjuvant Chemo"] = df["Adjuvant Chemo"].replace({"OBS":0, "ACT":1})
    for col in ["Adjuvant Chemo","IS_MALE"]:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0).astype(int)
    df = coerce_survival_cols(df)
    keep_cols = [c for c in clinical_vars if c in df.columns] + [g for g in gene_names if g in df.columns]
    missing_clin = [c for c in clinical_vars if c not in df.columns]
    if missing_clin:
        print(f"[WARN] Missing clinical columns: {missing_clin}")
    if len(keep_cols) == 0:
        raise ValueError("No feature columns found after filtering clinical+genes.")
    cols = ["OS_STATUS","OS_MONTHS"] + keep_cols
    return df[cols].copy()

# Load CSVs
train_raw = pd.read_csv(TRAIN_CSV)
valid_raw = pd.read_csv(VALID_CSV)
test_raw  = pd.read_csv(TEST_CSV)

# Load genes (Prop==1)
GENE_LIST = load_genes_list(GENES_CSV)

# Reduce to requested columns per split
train_df = preprocess_split(train_raw, CLINICAL_VARS, GENE_LIST)
valid_df = preprocess_split(valid_raw, CLINICAL_VARS, GENE_LIST)
test_df  = preprocess_split(test_raw,  CLINICAL_VARS, GENE_LIST)

# Ensure consistent columns across splits (intersection)
feat_candidates = [c for c in (CLINICAL_VARS + GENE_LIST)
                   if c in train_df.columns and c in valid_df.columns and c in test_df.columns]
if len(feat_candidates) == 0:
    raise ValueError("After filtering, no common features across train/val/test.")
print(f"[Features] Using {len(feat_candidates)} common features.")

# Sort Train/Val by event time & status (descending)
train_df = train_df.sort_values(by=["OS_MONTHS","OS_STATUS"], ascending=[False, False]).reset_index(drop=True)
valid_df = valid_df.sort_values(by=["OS_MONTHS","OS_STATUS"], ascending=[False, False]).reset_index(drop=True)

# Build arrays & TRAIN-only standardization (apply to VAL)
X_train = train_df[feat_candidates].values.astype(np.float32)
X_valid = valid_df[feat_candidates].values.astype(np.float32)

train_medians = np.nanmedian(X_train, axis=0)
X_train = np.where(np.isnan(X_train), train_medians, X_train)
X_valid = np.where(np.isnan(X_valid), train_medians, X_valid)

scaler_tv = StandardScaler().fit(X_train)
X_train = scaler_tv.transform(X_train).astype(np.float32)
X_valid = scaler_tv.transform(X_valid).astype(np.float32)

ytr_time = train_df["OS_MONTHS"].values.astype(np.float32)
ytr_event = train_df["OS_STATUS"].values.astype(int)
yva_time = valid_df["OS_MONTHS"].values.astype(np.float32)
yva_event = valid_df["OS_STATUS"].values.astype(int)

BATCH_SIZE = 64
train_ds = SurvivalDataset(X_train, ytr_time, ytr_event)
valid_ds = SurvivalDataset(X_valid, yva_time, yva_event)

train_sampler = EventBalancedBatchSampler(ytr_event, BATCH_SIZE, seed=42)
train_loader  = DataLoader(train_ds, batch_sampler=train_sampler, num_workers=0)
train_eval_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
valid_loader      = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

in_dim = X_train.shape[1]
print("Input dim:", in_dim)

# ============================================================
# Optuna Objective & Study  (arch encoded as strings -> parsed)
# ============================================================# ============================================================
# Optuna Objective & Study — Expanded Space to Reduce Overfitting
#   * Architectures include smaller, narrower and bottlenecks
#   * Stronger regularizers: higher dropout, input dropout, L1/L2
#   * Optionally apply WD to final layer
#   * Batch size, grad clip, scheduler, epochs per trial
#   * Input Gaussian noise
# ============================================================

# Wider set of architectures (strings -> parsed lists)
ARCH_CHOICES = (
    # very small
    "16", "32",
    # small / medium singles
    "64", "128", "256",
    # conservative big single
    "512",
    # 2-layer bottlenecks and symmetric
    "32-16", "32-32", "64-32", "64-64",
    "128-64", "128-128", "256-128", "256-256",
    "512-256", "512-512",
    # 3-layer (narrowing)
    "64-32-16", "128-64-32", "256-128-64"
)

def layers_from_arch(arch_str: str):
    return [int(x) for x in arch_str.split("-") if x.strip()]

def suggest_hparams(trial):
    arch = trial.suggest_categorical("arch", ARCH_CHOICES)

    # Regularization knobs
    dropout = trial.suggest_float("dropout", 0.10, 0.70)             # ↑ upper bound
    input_dropout = trial.suggest_float("input_dropout", 0.00, 0.30) # feature dropout before first layer

    # L2 (weight decay): allow much stronger; optionally apply to final layer too
    wd = trial.suggest_float("wd", 1e-6, 1e-1, log=True)
    apply_final_wd = trial.suggest_categorical("apply_final_wd", (0, 1))

    # L1: allow disabling OR stronger values
    use_l1 = trial.suggest_categorical("use_l1", (0, 1))
    l1 = 0.0 if use_l1 == 0 else trial.suggest_float("l1", 1e-8, 3e-3, log=True)

    # Optim & schedule
    lr = trial.suggest_float("lr", 1e-5, 5e-4, log=True)
    sched = trial.suggest_categorical("sched", ("cosine", "cawr", "none"))
    if sched == "cawr":
        cawr_T0 = trial.suggest_int("cawr_T0", 16, 80, step=8)
        cawr_Tmult = trial.suggest_categorical("cawr_Tmult", (1, 2, 3))
    else:
        cawr_T0, cawr_Tmult = None, None

    # Training controls
    epochs = trial.suggest_int("epochs", 64, 512, step=32)           # per-trial budget
    batch_size = trial.suggest_categorical("batch_size", (32, 64, 128))
    grad_clip = trial.suggest_float("grad_clip", 1.0, 10.0)          # stability & regularization

    # Data regularization
    noise_std = trial.suggest_float("noise_std", 0.0, 0.10)          # Gaussian feature noise

    return {
        "arch": arch,
        "dropout": dropout,
        "input_dropout": input_dropout,
        "lr": lr,
        "wd": wd,
        "apply_final_wd": apply_final_wd,
        "l1": l1,
        "sched": sched,
        "cawr_T0": cawr_T0,
        "cawr_Tmult": cawr_Tmult,
        "epochs": epochs,
        "batch_size": batch_size,
        "grad_clip": grad_clip,
        "noise_std": noise_std,
    }

# Warmups (as before)
MAX_EPOCHS_CAP = 512  # absolute cap (safety)
WARMUP_EPOCHS_L1 = 30
WARMUP_EPOCHS_DROPOUT = 30
WARMUP_EPOCHS_WD = 30
DROPOUT_START = 0.15
WD_START = 0.0

# Local helpers that add input dropout/noise and variable grad clip
def _apply_input_dropout(x, p):
    if p <= 0.0: return x
    # inverted dropout on features
    keep = 1.0 - p
    mask = torch.bernoulli(torch.full_like(x, keep))
    return x * mask / max(keep, 1e-6)

def train_one_epoch_reg(model, optimizer, dataloader, device,
                        l1_lambda=0.0, epoch=0, warmup_epochs=20,
                        input_dropout=0.0, noise_std=0.0, grad_clip=5.0):
    model.train()
    warm = min(1.0, (epoch + 1) / float(warmup_epochs))
    loss_sum, n_seen = 0.0, 0
    for x, t, e in dataloader:
        if e.sum().item() == 0:
            continue
        x, t, e = x.to(device), t.to(device), e.to(device)

        # data-level regularization
        if input_dropout > 0.0:
            x = _apply_input_dropout(x, input_dropout)
        if noise_std > 0.0:
            x = x + noise_std * torch.randn_like(x)

        optimizer.zero_grad(set_to_none=True)
        out = torch.clamp(model(x), -20, 20)
        loss = cox_negloglik(out, e, t)
        if l1_lambda > 0:
            loss = loss + (l1_lambda * warm) * l1_penalty_first_layer(model)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), float(grad_clip))
        optimizer.step()
        loss_sum += loss.item() * x.size(0)
        n_seen += x.size(0)
    return {'avg_loss': loss_sum / max(n_seen, 1), 'warm': warm}

def full_risk_set_step_reg(model, optimizer, ds, device,
                           l1_lambda=0.0, warm=1.0,
                           input_dropout=0.0, noise_std=0.0, grad_clip=5.0):
    model.train()
    X_all = ds.x.to(device); t_all = ds.time.to(device); e_all = ds.event.to(device)
    XX = X_all
    if input_dropout > 0.0:
        XX = _apply_input_dropout(XX, input_dropout)
    if noise_std > 0.0:
        XX = XX + noise_std * torch.randn_like(XX)

    optimizer.zero_grad(set_to_none=True)
    out_all = torch.clamp(model(XX), -20, 20)
    loss_full = cox_negloglik(out_all, e_all, t_all)
    if l1_lambda > 0:
        loss_full = loss_full + (l1_lambda * warm) * l1_penalty_first_layer(model)
    loss_full.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), float(grad_clip))
    optimizer.step()
    return float(loss_full.detach().cpu().item())

# Optimizer that can optionally apply WD to the final layer
def make_optimizer_hpo(model, lr, wd, apply_final_wd=False):
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
    last_linear = linears[-1] if len(linears) > 0 else None
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if name.endswith('bias'):
            no_decay.append(p); continue
        if (last_linear is not None) and (p is last_linear.weight) and not apply_final_wd:
            no_decay.append(p); continue
        decay.append(p)
    param_groups = [{'params': decay, 'weight_decay': float(wd)},
                    {'params': no_decay, 'weight_decay': 0.0}]
    return optim.AdamW(param_groups, lr=float(lr))

def objective(trial):
    hp = suggest_hparams(trial)
    layers = layers_from_arch(hp["arch"])

    # Build loaders with TRIAL-SPECIFIC batch size (smaller batches often regularize more)
    bs = int(hp["batch_size"])
    tr_sampler = EventBalancedBatchSampler(ytr_event, bs, seed=42)
    tr_loader = DataLoader(train_ds, batch_sampler=tr_sampler, num_workers=0)
    tr_eval_loader = DataLoader(train_ds, batch_size=bs, shuffle=False, num_workers=0)
    va_loader = DataLoader(valid_ds, batch_size=bs, shuffle=False, num_workers=0)

    # Model / optimizer
    model = DeepSurvMLP(in_dim, layers, dropout=hp["dropout"]).to(device)
    optimizer = make_optimizer_hpo(model, lr=hp["lr"], wd=hp["wd"],
                                   apply_final_wd=bool(hp["apply_final_wd"]))

    # Scheduler per trial
    epochs = int(min(hp["epochs"], MAX_EPOCHS_CAP))
    if hp["sched"] == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
        def step_sched(epoch_idx): scheduler.step()
    elif hp["sched"] == "cawr":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=int(hp["cawr_T0"]), T_mult=int(hp["cawr_Tmult"])
        )
        def step_sched(epoch_idx): scheduler.step(epoch_idx + 1)
    else:
        scheduler = None
        def step_sched(epoch_idx): pass

    best_val_ci = -np.inf
    best_epoch = 0

    for epoch in range(epochs):
        # Warm-up schedules for dropout & WD
        frac_d = min(1.0, epoch / float(WARMUP_EPOCHS_DROPOUT))
        frac_w = min(1.0, epoch / float(WARMUP_EPOCHS_WD))
        set_dropout_p(model, DROPOUT_START + (hp['dropout'] - DROPOUT_START) * frac_d)
        set_weight_decay(optimizer, WD_START + (hp['wd'] - WD_START) * frac_w)

        # One epoch + full risk-set step with extra regularizers
        stats = train_one_epoch_reg(
            model, optimizer, tr_loader, device,
            l1_lambda=hp["l1"], epoch=epoch, warmup_epochs=WARMUP_EPOCHS_L1,
            input_dropout=hp["input_dropout"], noise_std=hp["noise_std"],
            grad_clip=hp["grad_clip"]
        )
        _ = full_risk_set_step_reg(
            model, optimizer, train_ds, device,
            l1_lambda=hp["l1"], warm=stats['warm'],
            input_dropout=hp["input_dropout"], noise_std=hp["noise_std"],
            grad_clip=hp["grad_clip"]
        )

        # Evaluate
        val_ci = evaluate_ci(model, va_loader, device)
        step_sched(epoch)

        # Report for pruning
        trial.report(val_ci, step=epoch)
        if val_ci > best_val_ci:
            best_val_ci = val_ci
            best_epoch = epoch + 1

        if trial.should_prune():
            # Clean up GPU mem before pruning
            del model, optimizer, scheduler
            if torch.cuda.is_available(): torch.cuda.empty_cache()
            gc.collect()
            raise optuna.TrialPruned()

    trial.set_user_attr("best_epoch", int(best_epoch))

    # Clean up
    del model, optimizer, scheduler
    if torch.cuda.is_available(): torch.cuda.empty_cache()
    gc.collect()
    return best_val_ci

# ---- Study / Sampler / Pruner (tune TPE a bit for broader exploration) ----
storage = "sqlite:///deepsurv_optuna.db"
study_name = "deepsurv_cox_hpo_overfit_reducer"

sampler = optuna.samplers.TPESampler(
    seed=42,
    multivariate=True,
    group=True,
    n_startup_trials=40,       # explore more before exploitation
    constant_liar=True,        # better parallel behavior if you run parallel
    consider_prior=True
)

# Hyperband/SH both work; Hyperband gives a bit more flexibility across resource levels
pruner = optuna.pruners.HyperbandPruner(min_resource=16, reduction_factor=3)

study = optuna.create_study(
    direction="maximize",       # single objective (Val CI)
    study_name=study_name,
    storage=storage,
    load_if_exists=True,
    sampler=sampler,
    pruner=pruner
)


# ============================================================
# Launch Optuna Dashboard (Colab proxied URL printed)
#   (No unsupported args; previous TypeError fixed)
# ============================================================
try:
    from optuna_dashboard import run_server
    from google.colab import output
    import portpicker
    PORT = portpicker.pick_unused_port()
    def _start_dashboard():
        # NOTE: do NOT pass unsupported kwargs like reload/quiet
        run_server(storage, host="0.0.0.0", port=PORT)
    t = threading.Thread(target=_start_dashboard, daemon=True)
    t.start()
    time.sleep(2)
    dash_url = output.eval_js(f"google.colab.kernel.proxyPort({PORT}, {{'cache': false}})")
    print("Optuna Dashboard:", dash_url)
except Exception as ex:
    print("[Optuna Dashboard] Could not start dashboard automatically.", ex)
    print("You can run it locally with:  optuna-dashboard sqlite:///deepsurv_optuna.db")

# ============================================================
# Run Optimization
# ============================================================
N_TRIALS = 100  # adjust as needed
print(f"Starting optimization: {N_TRIALS} trials × up to {MAX_EPOCHS_CAP} epochs")
study.optimize(objective, n_trials=N_TRIALS, gc_after_trial=True)

print("\n[Best] Val CI:", study.best_value)
print("[Best] Params:", study.best_params)
print("[Best] Best epoch:", study.best_trial.user_attrs.get("best_epoch", MAX_EPOCHS_CAP))

# ============================================================
# Retrain on Train+Val with best hyperparams; evaluate Test
# - Combine Train+Val, sort, restandardize; apply to Test
# ============================================================
# Prepare Train+Val
trainval_df = pd.concat([train_df, valid_df], axis=0, ignore_index=True)
trainval_df = trainval_df.sort_values(by=["OS_MONTHS","OS_STATUS"], ascending=[False, False]).reset_index(drop=True)

X_trv = trainval_df[feat_candidates].values.astype(np.float32)
y_trv_time = trainval_df["OS_MONTHS"].values.astype(np.float32)
y_trv_event = trainval_df["OS_STATUS"].values.astype(int)

# Median impute by Train+Val medians (for retraining phase)
trv_medians = np.nanmedian(X_trv, axis=0)
X_trv = np.where(np.isnan(X_trv), trv_medians, X_trv)

scaler_trv = StandardScaler().fit(X_trv)
X_trv = scaler_trv.transform(X_trv).astype(np.float32)

# Test set standardized with Train+Val scaler
X_test = test_df[feat_candidates].values.astype(np.float32)
X_test = np.where(np.isnan(X_test), trv_medians, X_test)
X_test = scaler_trv.transform(X_test).astype(np.float32)
y_te_time = test_df["OS_MONTHS"].values.astype(np.float32)
y_te_event = test_df["OS_STATUS"].values.astype(int)

# Loaders
BATCH_SIZE = 64
trv_ds = SurvivalDataset(X_trv, y_trv_time, y_trv_event)
te_ds  = SurvivalDataset(X_test, y_te_time, y_te_event)

trv_sampler = EventBalancedBatchSampler(y_trv_event, BATCH_SIZE, seed=7)
trv_loader  = DataLoader(trv_ds, batch_sampler=trv_sampler, num_workers=0)
trv_eval_loader = DataLoader(trv_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
te_loader  = DataLoader(te_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Build & train final model
best_hp = study.best_params
l1_final = float(best_hp.get("l1", 0.0))
best_layers = layers_from_arch(best_hp["arch"])
best_n_epochs = int(study.best_trial.user_attrs.get("best_epoch", MAX_EPOCHS_CAP))
best_n_epochs = max(16, min(best_n_epochs, MAX_EPOCHS_CAP))

model_final = DeepSurvMLP(in_dim, best_layers, dropout=best_hp["dropout"]).to(device)
opt_final = make_optimizer(model_final, lr=best_hp["lr"], wd=best_hp["wd"])
sched_final = torch.optim.lr_scheduler.CosineAnnealingLR(opt_final, T_max=best_n_epochs)

for epoch in range(best_n_epochs):
    frac_d = min(1.0, epoch / float(WARMUP_EPOCHS_DROPOUT))
    frac_w = min(1.0, epoch / float(WARMUP_EPOCHS_WD))
    set_dropout_p(model_final, DROPOUT_START + (best_hp['dropout'] - DROPOUT_START) * frac_d)
    set_weight_decay(opt_final, WD_START + (best_hp['wd'] - WD_START) * frac_w)

    stats = train_one_epoch(model_final, opt_final, trv_loader, device, l1_lambda=best_hp["l1"],
                            epoch=epoch, warmup_epochs=WARMUP_EPOCHS_L1)
    _ = full_risk_set_step(model_final, opt_final, trv_ds, device, l1_lambda=best_hp["l1"], warm=stats['warm'])
    sched_final.step()

# Evaluate
trainval_ci = evaluate_ci(model_final, trv_eval_loader, device)
test_ci = evaluate_ci(model_final, te_loader, device)
print(f"\n[Final] Train+Val CI: {trainval_ci:.4f}")
print(f"[Final] Test CI:      {test_ci:.4f}")

# (Optional) Save artifacts to Drive
OUT_DIR = "/content/drive/MyDrive/deepsurv_results_optuna"
os.makedirs(OUT_DIR, exist_ok=True)
torch.save(model_final.state_dict(), os.path.join(OUT_DIR, "deepsurv_best.pt"))
with open(os.path.join(OUT_DIR, "best_params.txt"), "w") as f:
    f.write(str(study.best_params))
print("Saved final model and best params to:", OUT_DIR)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Device: cuda
Genes.csv path: /content/drive/MyDrive/Genes.csv
[Genes] Selected 1555 genes with Prop == 1
[Features] Using 1573 common features.


  df["Adjuvant Chemo"] = df["Adjuvant Chemo"].replace({"OBS":0, "ACT":1})
  df["Adjuvant Chemo"] = df["Adjuvant Chemo"].replace({"OBS":0, "ACT":1})
  df["Adjuvant Chemo"] = df["Adjuvant Chemo"].replace({"OBS":0, "ACT":1})


Input dim: 1573


[I 2025-10-19 05:02:14,045] A new study created in RDB with name: deepsurv_cox_hpo_overfit_reducer
Bottle v0.13.4 server starting up (using WSGIRefServer())...
Listening on http://0.0.0.0:40971/
Hit Ctrl-C to quit.



Optuna Dashboard: https://40971-gpu-t4-s-104ldvpjndb22-c.asia-southeast1-1.prod.colab.dev
Starting optimization: 100 trials × up to 512 epochs


127.0.0.1 - - [19/Oct/2025 05:02:24] "GET / HTTP/1.1" 302 0
127.0.0.1 - - [19/Oct/2025 05:02:30] "GET / HTTP/1.1" 302 0
127.0.0.1 - - [19/Oct/2025 05:02:35] "GET / HTTP/1.1" 302 0
127.0.0.1 - - [19/Oct/2025 05:03:57] "GET / HTTP/1.1" 302 0
[I 2025-10-19 05:04:51,571] Trial 0 finished with value: 0.6583796000589593 and parameters: {'arch': '128-128', 'dropout': 0.27473748411882515, 'input_dropout': 0.18355586841671384, 'wd': 4.982752357076453e-06, 'apply_final_wd': 1, 'use_l1': 1, 'l1': 1.240616395320886e-07, 'lr': 7.475992999956501e-05, 'sched': 'none', 'epochs': 128, 'batch_size': 128, 'grad_clip': 8.275576133048151, 'noise_std': 0.03046137691733707}. Best is trial 0 with value: 0.6583796000589593.
[I 2025-10-19 05:06:08,099] Trial 1 finished with value: 0.6681570284478947 and parameters: {'arch': '256-256', 'dropout': 0.1530955012311517, 'input_dropout': 0.058794858725743554, 'wd': 1.6832027985721922e-06, 'apply_final_wd': 1, 'use_l1': 1, 'l1': 8.994587030462112e-07, 'lr': 3.00123018


[Best] Val CI: 0.6930673610769911
[Best] Params: {'arch': '512', 'dropout': 0.385912835673872, 'input_dropout': 0.061771953713056, 'wd': 3.3415205446076416e-06, 'apply_final_wd': 0, 'use_l1': 0, 'lr': 6.945321990789577e-05, 'sched': 'cosine', 'epochs': 192, 'batch_size': 32, 'grad_clip': 8.251469896233349, 'noise_std': 0.09246940755563608}
[Best] Best epoch: 26


KeyError: 'l1'

In [4]:
# ==== Ready-to-run: Retrain on Train+Val with given best params, then eval on Test ====

# --- Use the reported best params & epoch (handles missing 'l1') ---
best_hp = {
    'arch': '512',
    'dropout': 0.385912835673872,
    'input_dropout': 0.061771953713056,  # not used in final fit below
    'wd': 3.3415205446076416e-06,
    'apply_final_wd': 0,                 # not used in final fit below
    'use_l1': 0,
    'lr': 6.945321990789577e-05,
    'sched': 'cosine',
    'epochs': 192,
    'batch_size': 32,
    'grad_clip': 8.251469896233349,      # not used in final fit below
    'noise_std': 0.09246940755563608     # not used in final fit below
}
best_n_epochs = 26
l1_final = float(best_hp.get("l1", 0.0))  # <- fix: default L1 when use_l1==0

# --- Build Train+Val, restandardize; transform Test ---
trainval_df = pd.concat([train_df, valid_df], axis=0, ignore_index=True)
trainval_df = trainval_df.sort_values(by=["OS_MONTHS","OS_STATUS"], ascending=[False, False]).reset_index(drop=True)

X_trv = trainval_df[feat_candidates].values.astype(np.float32)
y_trv_time = trainval_df["OS_MONTHS"].values.astype(np.float32)
y_trv_event = trainval_df["OS_STATUS"].values.astype(int)

trv_medians = np.nanmedian(X_trv, axis=0)
X_trv = np.where(np.isnan(X_trv), trv_medians, X_trv)
scaler_trv = StandardScaler().fit(X_trv)
X_trv = scaler_trv.transform(X_trv).astype(np.float32)

X_test = test_df[feat_candidates].values.astype(np.float32)
X_test = np.where(np.isnan(X_test), trv_medians, X_test)
X_test = scaler_trv.transform(X_test).astype(np.float32)
y_te_time = test_df["OS_MONTHS"].values.astype(np.float32)
y_te_event = test_df["OS_STATUS"].values.astype(int)

# --- Datasets & Loaders ---
BATCH_SIZE = 64
trv_ds = SurvivalDataset(X_trv, y_trv_time, y_trv_event)
te_ds  = SurvivalDataset(X_test, y_te_time, y_te_event)

trv_sampler = EventBalancedBatchSampler(y_trv_event, BATCH_SIZE, seed=7)
trv_loader  = DataLoader(trv_ds, batch_sampler=trv_sampler, num_workers=0)
trv_eval_loader = DataLoader(trv_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
te_loader  = DataLoader(te_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# --- Model (arch parsed from string), Optim, Scheduler ---
def layers_from_arch(arch_str: str):
    return [int(x) for x in arch_str.split("-") if x.strip()]

in_dim = X_trv.shape[1]
best_layers = layers_from_arch(best_hp["arch"])

model_final = DeepSurvMLP(in_dim, best_layers, dropout=best_hp["dropout"]).to(device)
opt_final = make_optimizer(model_final, lr=best_hp["lr"], wd=best_hp["wd"])
sched_final = torch.optim.lr_scheduler.CosineAnnealingLR(opt_final, T_max=best_n_epochs)

# --- Warmups (reuse constants from earlier cell if present; else define sane defaults) ---
try:
    WARMUP_EPOCHS_DROPOUT
except NameError:
    WARMUP_EPOCHS_DROPOUT = 30
    WARMUP_EPOCHS_WD = 30
    WARMUP_EPOCHS_L1 = 30
    DROPOUT_START = 0.15
    WD_START = 0.0

# --- Train ---
for epoch in range(best_n_epochs):
    frac_d = min(1.0, epoch / float(WARMUP_EPOCHS_DROPOUT))
    frac_w = min(1.0, epoch / float(WARMUP_EPOCHS_WD))
    set_dropout_p(model_final, DROPOUT_START + (best_hp['dropout'] - DROPOUT_START) * frac_d)
    set_weight_decay(opt_final, WD_START + (best_hp['wd'] - WD_START) * frac_w)

    stats = train_one_epoch(model_final, opt_final, trv_loader, device, l1_lambda=l1_final,
                            epoch=epoch, warmup_epochs=WARMUP_EPOCHS_L1)
    _ = full_risk_set_step(model_final, opt_final, trv_ds, device, l1_lambda=l1_final, warm=stats['warm'])
    sched_final.step()

# --- Evaluate ---
trainval_ci = evaluate_ci(model_final, trv_eval_loader, device)
test_ci = evaluate_ci(model_final, te_loader, device)
print(f"\n[Final] Train+Val CI: {trainval_ci:.4f}")
print(f"[Final] Test CI:      {test_ci:.4f}")


[Final] Train+Val CI: 0.7981
[Final] Test CI:      0.6103
