<a href="https://colab.research.google.com/github/osun24/nsclc-adj-chemo/blob/main/TorchSurv_DeepSurv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install necessary packages
!pip install torchsurv scikit-survival

# Import required packages
import os
import time
import datetime
import itertools
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sksurv.metrics import concordance_index_censored

# (Optional) Mount Google Drive if you plan to load/save files there
from google.colab import drive
drive.mount('/content/drive')


Collecting torchsurv
  Downloading torchsurv-0.1.5-py3-none-any.whl.metadata (15 kB)
Collecting scikit-survival
  Downloading scikit_survival-0.25.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.2 kB)
Collecting torchmetrics (from torchsurv)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting ecos (from scikit-survival)
  Downloading ecos-2.0.14-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.0 kB)
Collecting osqp<1.0.0,>=0.6.3 (from scikit-survival)
  Downloading osqp-0.6.7.post3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.9 kB)
Collecting qdldl (from osqp<1.0.0,>=0.6.3->scikit-survival)
  Downloading qdldl-0.1.7.post5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.7 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics->torchsurv)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloadin

In [None]:
import os
import sys
import math
import datetime
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchsurv.loss.cox import neg_partial_log_likelihood
from sksurv.metrics import concordance_index_censored
import warnings
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import copy
import random

warnings.filterwarnings("ignore", message="Ties in event time detected; using efron's method to handle ties.")
torch.manual_seed(0); np.random.seed(0); random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ---------------- utils ----------------
class Tee:
    def __init__(self, *files): self.files = files
    def write(self, data):
        for f in self.files: f.write(data)
    def flush(self):
        for f in self.files: f.flush()

def loguniform(rng, lo, hi):
    return float(np.exp(rng.uniform(np.log(lo), np.log(hi))))

def sample_hparams(rng):
    hidden_options = [[8], [16], [32], [16, 8], [32, 16]]
    layers = hidden_options[rng.integers(len(hidden_options))]
    dropout = float(rng.uniform(0.5, 0.7))             # narrower, slightly higher
    lr = loguniform(rng, 3e-5, 1e-4)
    wd = loguniform(rng, 2e-3, 5e-1)                   # stronger L2 via AdamW
    l1 = loguniform(rng, 5e-5, 5e-3)                   # modest L1
    return {'layers': layers, 'dropout': dropout, 'lr': lr, 'wd': wd, 'l1': l1}

# ---------------- model & data ----------------
class DeepSurvMLP(nn.Module):
    def __init__(self, in_features, hidden_layers, dropout=0.0, activation=nn.ReLU()):
        super().__init__()
        layers, d = [], in_features
        for units in hidden_layers:
            layers += [nn.Linear(d, units), activation]
            if dropout > 0: layers.append(nn.Dropout(dropout))
            d = units
        layers.append(nn.Linear(d, 1))
        self.model = nn.Sequential(*layers)
    def forward(self, x): return self.model(x)

class SurvivalDataset(Dataset):
    def __init__(self, features, time_vals, events):
        self.x = torch.tensor(features, dtype=torch.float32)
        self.time = torch.tensor(time_vals, dtype=torch.float32)
        self.event = torch.tensor(events, dtype=torch.bool)
    def __len__(self): return len(self.x)
    def __getitem__(self, idx): return self.x[idx], self.time[idx], self.event[idx]

def l1_penalty(model):
    return sum(p.abs().sum() for n,p in model.named_parameters()
               if p.requires_grad and p.dim() > 1)

def train_one_epoch(model, optimizer, dataloader, device, l1_lambda=0.0):
    model.train()
    for x, t, e in dataloader:
        if e.sum().item() == 0:  # skip non-informative batch
            continue
        x, t, e = x.to(device), t.to(device), e.to(device)
        optimizer.zero_grad()
        out = torch.clamp(model(x), -20, 20)
        loss = neg_partial_log_likelihood(out, e, t, reduction='mean')
        if l1_lambda > 0: loss = loss + l1_lambda * l1_penalty(model)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

def evaluate_ci(model, dataloader, device):
    model.eval()
    preds, times, events = [], [], []
    with torch.no_grad():
        for x, t, e in dataloader:
            x = x.to(device)
            y = torch.clamp(model(x), -20, 20)
            preds.append(y.cpu().numpy().ravel())
            times.append(t.numpy()); events.append(e.numpy())
    preds = np.concatenate(preds)
    if np.isnan(preds).any():
        print("Warning: NaN predictions detected, returning -inf for concordance index")
        return -np.inf
    times = np.concatenate(times); events = np.concatenate(events)
    return concordance_index_censored(events.astype(bool), times, preds)[0]

def count_params(in_dim, hidden_layers):
    params, d = 0, in_dim
    for h in hidden_layers:
        params += d*h + h  # weights + bias
        d = h
    params += d*1 + 1     # final layer
    return params

def collect_val_arrays(model, dataloader, device):
    model.eval()
    preds, times, events = [], [], []
    with torch.no_grad():
        for x, t, e in dataloader:
            x = x.to(device)
            y = torch.clamp(model(x), -20, 20).squeeze(1)
            preds.append(y.cpu().numpy())
            times.append(t.numpy()); events.append(e.numpy())
    return np.concatenate(preds), np.concatenate(times), np.concatenate(events)

def bootstrap_se_cindex(preds, times, events, B=200, seed=123):
    rng = np.random.default_rng(seed)
    n = len(times)
    vals = []
    for _ in range(B):
        idx = rng.integers(0, n, n)
        vals.append(concordance_index_censored(events[idx].astype(bool), times[idx], preds[idx])[0])
    return np.std(vals, ddof=1)

# ---------------- random search with successive halving ----------------
def main():
    # logging
    original_stdout = sys.stdout
    log_path = "/content/drive/MyDrive/deepsurv_8-28-25_training_log.txt"
    with open(log_path, "w") as log_file:
        sys.stdout = Tee(original_stdout, log_file)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        current_date = datetime.datetime.now().strftime("%Y%m%d")
        output_dir = "/content/drive/MyDrive/deepsurv_results"
        os.makedirs(output_dir, exist_ok=True)

        # ----- load & prep data -----
        train_df = pd.read_csv("/content/drive/MyDrive/affyTrain.csv")
        valid_df = pd.read_csv("/content/drive/MyDrive/affyValidation.csv")
        for df in (train_df, valid_df):
            if 'Adjuvant Chemo' in df.columns:
                df['Adjuvant Chemo'] = df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})
        binary_columns = ['Adjuvant Chemo','IS_MALE']
        for col in binary_columns:
            if col in train_df.columns: train_df[col] = train_df[col].astype(int)
            if col in valid_df.columns: valid_df[col] = valid_df[col].astype(int)

        survival_cols = ['OS_STATUS','OS_MONTHS']
        feature_cols = [c for c in train_df.columns if c not in survival_cols]

        X_train = train_df[feature_cols].values.astype(np.float32)
        X_valid = valid_df[feature_cols].values.astype(np.float32)
        scaler = StandardScaler().fit(X_train)
        X_train = scaler.transform(X_train).astype(np.float32)
        X_valid = scaler.transform(X_valid).astype(np.float32)

        y_train_time = train_df['OS_MONTHS'].values.astype(np.float32)
        y_train_event = train_df['OS_STATUS'].values.astype(np.float32)
        y_valid_time = valid_df['OS_MONTHS'].values.astype(np.float32)
        y_valid_event = valid_df['OS_STATUS'].values.astype(np.float32)

        train_ds = SurvivalDataset(X_train, y_train_time, y_train_event)
        valid_ds = SurvivalDataset(X_valid, y_valid_time, y_valid_event)

        batch_size = 32
        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  drop_last=True)
        train_eval_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=False)
        valid_loader      = DataLoader(valid_ds, batch_size=batch_size, shuffle=False)

        # ----- search budget & rungs -----
        rng = np.random.default_rng(42)
        num_trials = 80                     # budget (adjust as needed)
        rungs = [8, 24, 48, 96]                # epochs at which we prune
        eta = 3                            # keep top 1/eta each rung
        patience, min_delta = 12, 1e-4      # early stopping inside trials
        PRINT_EVERY = 1

        # ----- initialize trials -----
        trials = []
        for tid in range(num_trials):
            hp = sample_hparams(rng)
            model = DeepSurvMLP(X_train.shape[1], hp['layers'], dropout=hp['dropout']).to(device)
            optimizer = optim.AdamW(model.parameters(), lr=hp['lr'], weight_decay=hp['wd'])
            trials.append({
                'id': tid, 'hp': hp, 'model': model, 'opt': optimizer,
                'best_ci': -np.inf, 'best_state': copy.deepcopy(model.state_dict()),
                'hist_train_ci': [], 'hist_val_ci': [], 'epochs_done': 0,
                'alive': True, 'no_improve': 0, 'last_val_ci': -np.inf
            })

        # ----- successive halving loop -----
        for rung_idx, rung_ep in enumerate(rungs, start=1):
            print(f"\n=== Rung {rung_idx}/{len(rungs)} → target {rung_ep} epochs ===")
            # train all alive trials up to rung_ep (with early stopping)
            for tr in trials:
                if not tr['alive']: continue
                target = rung_ep
                while tr['epochs_done'] < target:
                    train_one_epoch(tr['model'], tr['opt'], train_loader, device, l1_lambda=tr['hp']['l1'])
                    tr['epochs_done'] += 1
                    # diagnostics each epoch
                    tr_ci = evaluate_ci(tr['model'], train_eval_loader, device)
                    va_ci = evaluate_ci(tr['model'], valid_loader, device)
                    tr['hist_train_ci'].append(tr_ci)
                    tr['hist_val_ci'].append(va_ci)

                    if tr['epochs_done'] % PRINT_EVERY == 0:
                      print(f"[Trial {tr['id']:02d} | Rung {rung_idx}/{len(rungs)} | "
                            f"Epoch {tr['epochs_done']:3d}] Train CI {tr_ci:.4f} | "
                            f"Val CI {va_ci:.4f} | Best {tr['best_ci']:.4f}")

                    # track best
                    if va_ci > tr['best_ci'] + min_delta:
                        tr['best_ci'] = va_ci
                        tr['best_state'] = copy.deepcopy(tr['model'].state_dict())
                        tr['no_improve'] = 0
                    else:
                        tr['no_improve'] += 1
                        if tr['no_improve'] >= patience:
                            print(f"Trial {tr['id']} early-stopped at epoch {tr['epochs_done']} (best Val CI={tr['best_ci']:.4f})")
                            break

            # prune: keep top 1/eta among alive
            alive = [tr for tr in trials if tr['alive']]
            alive.sort(key=lambda z: z['best_ci'], reverse=True)
            keep_n = max(1, math.ceil(len(alive) / eta))
            survivors = set(tr['id'] for tr in alive[:keep_n])

            print(f"Alive before prune: {len(alive)}; keeping top {keep_n}")
            for tr in alive:
                if tr['id'] not in survivors:
                    tr['alive'] = False
                    # free memory
                    del tr['model']; tr['model'] = None
                    tr['opt'] = None
                    if torch.cuda.is_available(): torch.cuda.empty_cache()

        # ----- 1-SE model selection -----
        raw_best = max(trials, key=lambda z: z['best_ci'])

        tmp_model = DeepSurvMLP(X_train.shape[1], raw_best['hp']['layers'],
                                dropout=raw_best['hp']['dropout']).to(device)
        tmp_model.load_state_dict(raw_best['best_state'])
        val_preds, val_times, val_events = collect_val_arrays(tmp_model, valid_loader, device)
        se_best = bootstrap_se_cindex(val_preds, val_times, val_events, B=200, seed=123)
        threshold = raw_best['best_ci'] - se_best

        candidates = [tr for tr in trials if tr['best_state'] is not None and tr['best_ci'] >= threshold]

        def simplicity_score(tr):
            # prefer fewer params; if tie, prefer larger regularization & dropout
            return (count_params(X_train.shape[1], tr['hp']['layers']),
                    -(tr['hp']['dropout']),
                    -(tr['hp']['wd'] + tr['hp']['l1']))

        selected = min(candidates, key=simplicity_score)

        best_trial = selected
        best_hp = best_trial['hp']
        print("\n[1-SE] Raw best CI: {:.4f}, SE: {:.4f}, threshold: {:.4f}".format(
              raw_best['best_ci'], se_best, threshold))
        print("[1-SE] Selected Hyperparameters (simplest within 1-SE):")
        print(best_hp)
        print("Selected Validation CI:", best_trial['best_ci'])

        # ----- save plots per trial -----
        out_date = current_date
        for tr in trials:
            cfg = f"trial{tr['id']}_layers-{'-'.join(map(str, tr['hp']['layers']))}_drop{tr['hp']['dropout']:.2f}_lr{tr['hp']['lr']:.2e}_wd{tr['hp']['wd']:.2e}_l1{tr['hp']['l1']:.2e}"
            epochs = range(1, len(tr['hist_train_ci'])+1)
            plt.figure()
            plt.plot(epochs, tr['hist_train_ci'], label='Train CI')
            plt.plot(epochs, tr['hist_val_ci'], label='Val CI')
            plt.xlabel('Epoch'); plt.ylabel('Concordance Index'); plt.legend(); plt.grid(True, alpha=0.3)
            plt.title(cfg)
            plt.ylim(0.4, 1.0)
            plot_path = os.path.join(output_dir, f"{out_date}_ci_{cfg}.png")
            plt.savefig(plot_path, dpi=150, bbox_inches='tight'); plt.close()
            print(f"Saved CI plot to {plot_path}")

        # ----- save results table -----
        results = []
        for tr in trials:
            row = {'trial_id': tr['id'], 'val_ci': tr['best_ci'], 'epochs_trained': len(tr['hist_val_ci']),
                   'alive_final': tr['alive']}
            row.update({ 'layers': '-'.join(map(str, tr['hp']['layers'])),
                         'dropout': tr['hp']['dropout'], 'lr': tr['hp']['lr'],
                         'weight_decay(L2)': tr['hp']['wd'], 'l1_lambda': tr['hp']['l1']})
            results.append(row)
        df = pd.DataFrame(results).sort_values('val_ci', ascending=False)
        csv_path = os.path.join(output_dir, f"{out_date}_deepsurv_randomSH_results.csv")
        df.to_csv(csv_path, index=False)
        print(f"Hyperparameter search results saved to {csv_path}")

        # ----- save best model -----
        best_model_path = os.path.join(output_dir, f"{out_date}_best_deepsurv_model.pth")
        torch.save(best_trial['best_state'], best_model_path)
        print(f"Best model saved to {best_model_path}")

        # ----- test evaluation -----
        test_df = pd.read_csv("/content/drive/MyDrive/affyTest.csv")
        if 'Adjuvant Chemo' in test_df.columns:
            test_df['Adjuvant Chemo'] = test_df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})
        for col in binary_columns:
            if col in test_df.columns: test_df[col] = test_df[col].astype(int)
        X_test = scaler.transform(test_df[feature_cols].values.astype(np.float32)).astype(np.float32)
        y_test_time = test_df['OS_MONTHS'].values.astype(np.float32)
        y_test_event = test_df['OS_STATUS'].values.astype(np.float32)

        test_loader = DataLoader(SurvivalDataset(X_test, y_test_time, y_test_event),
                                 batch_size=batch_size, shuffle=False)

        # rebuild & load best
        final_model = DeepSurvMLP(X_train.shape[1], best_hp['layers'], dropout=best_hp['dropout']).to(device)
        final_model.load_state_dict(torch.load(best_model_path, map_location=device))
        test_ci = evaluate_ci(final_model, test_loader, device)
        print(f"Test CI: {test_ci:.4f}")

        sys.stdout.flush()

    sys.stdout = original_stdout
    print("Training completed. Check your log file at:", log_path)

if __name__ == "__main__":
    main()

  df['Adjuvant Chemo'] = df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})



=== Rung 1/4 → target 8 epochs ===
[Trial 00 | Rung 1/4 | Epoch   1] Train CI 0.6955 | Val CI 0.5622 | Best -inf
[Trial 00 | Rung 1/4 | Epoch   2] Train CI 0.7655 | Val CI 0.5921 | Best 0.5622
[Trial 00 | Rung 1/4 | Epoch   3] Train CI 0.8029 | Val CI 0.6285 | Best 0.5921
[Trial 00 | Rung 1/4 | Epoch   4] Train CI 0.8350 | Val CI 0.6225 | Best 0.6285
[Trial 00 | Rung 1/4 | Epoch   5] Train CI 0.8495 | Val CI 0.6243 | Best 0.6285
[Trial 00 | Rung 1/4 | Epoch   6] Train CI 0.8713 | Val CI 0.6256 | Best 0.6285
[Trial 00 | Rung 1/4 | Epoch   7] Train CI 0.8788 | Val CI 0.6494 | Best 0.6285
[Trial 00 | Rung 1/4 | Epoch   8] Train CI 0.8979 | Val CI 0.6210 | Best 0.6494
[Trial 01 | Rung 1/4 | Epoch   1] Train CI 0.6131 | Val CI 0.5905 | Best -inf
[Trial 01 | Rung 1/4 | Epoch   2] Train CI 0.6578 | Val CI 0.6106 | Best 0.5905
[Trial 01 | Rung 1/4 | Epoch   3] Train CI 0.6653 | Val CI 0.6018 | Best 0.6106
[Trial 01 | Rung 1/4 | Epoch   4] Train CI 0.6599 | Val CI 0.6329 | Best 0.6106
[Trial 0

  test_df['Adjuvant Chemo'] = test_df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})


Test CI: 0.6206
Training completed. Check your log file at: /content/drive/MyDrive/deepsurv_8-28-25_training_log.txt


In [None]:
import os
import sys
import math
import datetime
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchsurv.loss.cox import neg_partial_log_likelihood
from sksurv.metrics import concordance_index_censored
import warnings
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import copy
import random

warnings.filterwarnings("ignore", message="Ties in event time detected; using efron's method to handle ties.")
torch.manual_seed(0); np.random.seed(0); random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ---------------- utils ----------------
class Tee:
    def __init__(self, *files): self.files = files
    def write(self, data):
        for f in self.files: f.write(data)
    def flush(self):
        for f in self.files: f.flush()

def loguniform(rng, lo, hi):
    return float(np.exp(rng.uniform(np.log(lo), np.log(hi))))

def sample_hparams(rng):
    # Architectures: keep the same small winners in play
    hidden_options = [[16], [32], [32, 16], [16,8,4], [32,16,8], [64]]
    layers = hidden_options[rng.integers(len(hidden_options))]

    # More dropout (but not extreme)
    dropout = float(rng.uniform(0.5, 0.65))

    # Slightly wider LR band (strong reg can tolerate a touch more LR)
    lr = loguniform(rng, 2e-5, 1.5e-4)

    r = rng.random()
    if r < 0.20:
        wd = float(rng.choice(wd_anchors))
        l1 = float(rng.choice(l1_anchors))
    elif r < 0.80:
        # heavy
        wd = loguniform(rng, 5e-2, 3e-1)   # 0.05–0.30 (match the comment)
        l1 = loguniform(rng, 7e-4, 5e-3)   # 7e-4–5e-3
    else:
        # extreme
        wd = loguniform(rng, 3e-1, 6e-1)   # 0.30–0.60
        l1 = loguniform(rng, 5e-3, 2e-2)   # 5e-3–2e-2

    return {'layers': layers, 'dropout': dropout, 'lr': lr, 'wd': wd, 'l1': l1}

# ---------------- model & data ----------------
class DeepSurvMLP(nn.Module):
    def __init__(self, in_features, hidden_layers, dropout=0.0, activation=nn.ReLU()):
        super().__init__()
        layers, d = [], in_features
        for units in hidden_layers:
            layers += [nn.Linear(d, units), activation]
            if dropout > 0: layers.append(nn.Dropout(dropout))
            d = units
        layers.append(nn.Linear(d, 1))
        self.model = nn.Sequential(*layers)
    def forward(self, x): return self.model(x)

class SurvivalDataset(Dataset):
    def __init__(self, features, time_vals, events):
        self.x = torch.tensor(features, dtype=torch.float32)
        self.time = torch.tensor(time_vals, dtype=torch.float32)
        self.event = torch.tensor(events, dtype=torch.bool)
    def __len__(self): return len(self.x)
    def __getitem__(self, idx): return self.x[idx], self.time[idx], self.event[idx]

# ---- Param groups: L2 only on non-bias, non-final weights ----
def make_optimizer(model, lr, wd):
    # find last Linear
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
    last_linear = linears[-1] if len(linears) > 0 else None

    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # No WD on biases
        if name.endswith('bias'):
            no_decay.append(p)
            continue
        # No WD on final layer weight (keep ranker free)
        if (last_linear is not None) and (p is last_linear.weight):
            no_decay.append(p)
            continue
        # Everything else decays (hidden & input weights)
        decay.append(p)

    param_groups = [
        {'params': decay, 'weight_decay': wd},
        {'params': no_decay, 'weight_decay': 0.0},
    ]
    return optim.AdamW(param_groups, lr=lr)

# L1 ONLY on the first (input) Linear layer
def l1_penalty_first_layer(model):
    for m in model.modules():
        if isinstance(m, nn.Linear):
            return m.weight.abs().sum()
    return torch.tensor(0.0, device=next(model.parameters()).device)

def train_one_epoch(model, optimizer, dataloader, device, l1_lambda=0.0, epoch=0, warmup_epochs=20):
    model.train()
    warm = min(1.0, (epoch + 1) / float(warmup_epochs))  # linear warmup of L1
    loss_sum, n_seen = 0.0, 0
    skipped, total_batches = 0, 0
    for x, t, e in dataloader:
        total_batches += 1
        if e.sum().item() == 0:  # skip non-informative batch for Cox
            skipped += 1
            continue
        x, t, e = x.to(device), t.to(device), e.to(device)
        optimizer.zero_grad()
        out = torch.clamp(model(x), -20, 20)
        loss = neg_partial_log_likelihood(out, e, t, reduction='mean')
        if l1_lambda > 0:
            loss = loss + (l1_lambda * warm) * l1_penalty_first_layer(model)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        loss_sum += loss.item() * x.size(0)
        n_seen += x.size(0)
    avg_loss = loss_sum / max(n_seen, 1)
    skip_frac = skipped / max(total_batches, 1)
    return {'avg_loss': avg_loss, 'skip_frac': skip_frac}

def evaluate_ci(model, dataloader, device):
    model.eval()
    preds, times, events = [], [], []
    with torch.no_grad():
        for x, t, e in dataloader:
            x = x.to(device)
            y = torch.clamp(model(x), -20, 20)
            preds.append(y.cpu().numpy().ravel())
            times.append(t.numpy()); events.append(e.numpy())
    preds = np.concatenate(preds)
    if np.isnan(preds).any():
        print("Warning: NaN predictions detected, returning -inf for concordance index")
        return -np.inf
    times = np.concatenate(times); events = np.concatenate(events)
    return concordance_index_censored(events.astype(bool), times, preds)[0]

def count_params(in_dim, hidden_layers):
    params, d = 0, in_dim
    for h in hidden_layers:
        params += d*h + h  # weights + bias
        d = h
    params += d*1 + 1     # final layer
    return params

# ---------------- random search with successive halving ----------------
def main():
    # logging
    original_stdout = sys.stdout
    log_path = "/content/drive/MyDrive/deepsurv_8-30-25_training_log.txt"
    with open(log_path, "w") as log_file:
        sys.stdout = Tee(original_stdout, log_file)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        current_date = datetime.datetime.now().strftime("%Y%m%d")
        output_dir = "/content/drive/MyDrive/deepsurv_results"
        os.makedirs(output_dir, exist_ok=True)

        # ----- load & prep data -----
        train_df = pd.read_csv("/content/drive/MyDrive/affyTrain.csv")
        valid_df = pd.read_csv("/content/drive/MyDrive/affyValidation.csv")
        for df in (train_df, valid_df):
            if 'Adjuvant Chemo' in df.columns:
                df['Adjuvant Chemo'] = df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})
        binary_columns = ['Adjuvant Chemo','IS_MALE']
        for col in binary_columns:
            if col in train_df.columns: train_df[col] = train_df[col].astype(int)
            if col in valid_df.columns: valid_df[col] = valid_df[col].astype(int)

        survival_cols = ['OS_STATUS','OS_MONTHS']
        feature_cols = [c for c in train_df.columns if c not in survival_cols]

        X_train = train_df[feature_cols].values.astype(np.float32)
        X_valid = valid_df[feature_cols].values.astype(np.float32)
        scaler = StandardScaler().fit(X_train)
        X_train = scaler.transform(X_train).astype(np.float32)
        X_valid = scaler.transform(X_valid).astype(np.float32)

        y_train_time = train_df['OS_MONTHS'].values.astype(np.float32)
        y_train_event = train_df['OS_STATUS'].values.astype(np.float32)
        y_valid_time = valid_df['OS_MONTHS'].values.astype(np.float32)
        y_valid_event = valid_df['OS_STATUS'].values.astype(np.float32)

        train_ds = SurvivalDataset(X_train, y_train_time, y_train_event)
        valid_ds = SurvivalDataset(X_valid, y_valid_time, y_valid_event)

        BATCH_SIZE = 64  # slightly larger to stabilize gradients under strong reg
        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  drop_last=True)
        train_eval_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=False)
        valid_loader      = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)

        # ----- search budget & rungs -----
        rng = np.random.default_rng(42)
        num_trials = 100
        rungs = [16, 64, 160, 320]  # longer rungs for heavy reg
        eta = 3
        min_delta = 1e-4
        PRINT_EVERY = 1
        L1_WARMUP_EPOCHS = 30

        # ----- initialize trials -----
        trials = []
        for tid in range(num_trials):
            hp = sample_hparams(rng)
            model = DeepSurvMLP(X_train.shape[1], hp['layers'], dropout=hp['dropout']).to(device)
            optimizer = make_optimizer(model, lr=hp['lr'], wd=hp['wd'])
            trials.append({
                'id': tid, 'hp': hp, 'model': model, 'opt': optimizer, 'sched': None, 'sched_target': 0,
                'best_ci': -np.inf, 'best_state': copy.deepcopy(model.state_dict()),
                'hist_train_ci': [], 'hist_val_ci': [], 'hist_loss': [], 'hist_skip': [],
                'epochs_done': 0, 'alive': True, 'no_improve': 0
            })

        # ----- successive halving loop -----
        prev_target = 0
        for rung_idx, rung_ep in enumerate(rungs, start=1):
            print(f"\n=== Rung {rung_idx}/{len(rungs)} → target {rung_ep} epochs ===")

            # rung-relative patience (no early stop in rung 1)
            span = rung_ep - prev_target
            es_patience = float('inf') if rung_idx == 1 else max(10, int(0.25 * span))

            # train all alive trials up to rung_ep
            for tr in trials:
                if not tr['alive']: continue

                # (re)create cosine scheduler for this rung
                steps_remaining = rung_ep - tr['epochs_done']
                if steps_remaining <= 0:
                    continue
                tr['sched'] = torch.optim.lr_scheduler.CosineAnnealingLR(tr['opt'], T_max=steps_remaining)
                tr['sched_target'] = rung_ep

                while tr['epochs_done'] < rung_ep:
                    stats = train_one_epoch(
                        tr['model'], tr['opt'], train_loader, device,
                        l1_lambda=tr['hp']['l1'], epoch=tr['epochs_done'], warmup_epochs=L1_WARMUP_EPOCHS
                    )
                    tr['epochs_done'] += 1

                    # diagnostics each epoch
                    tr_ci = evaluate_ci(tr['model'], train_eval_loader, device)
                    va_ci = evaluate_ci(tr['model'], valid_loader, device)
                    tr['hist_train_ci'].append(tr_ci)
                    tr['hist_val_ci'].append(va_ci)
                    tr['hist_loss'].append(stats['avg_loss'])
                    tr['hist_skip'].append(stats['skip_frac'])

                    if tr['sched'] is not None:
                        tr['sched'].step()

                    if tr['epochs_done'] % PRINT_EVERY == 0:
                        print(f"[Trial {tr['id']:02d} | Rung {rung_idx}/{len(rungs)} | "
                              f"Epoch {tr['epochs_done']:3d}] "
                              f"Loss {stats['avg_loss']:.4f} | Skip% {100*stats['skip_frac']:.1f} | "
                              f"Train CI {tr_ci:.4f} | Val CI {va_ci:.4f} | Best {tr['best_ci']:.4f}")

                    # track best
                    if va_ci > tr['best_ci'] + min_delta:
                        tr['best_ci'] = va_ci
                        tr['best_state'] = copy.deepcopy(tr['model'].state_dict())
                        tr['no_improve'] = 0
                    else:
                        tr['no_improve'] += 1
                        if tr['no_improve'] >= es_patience:
                            print(f"Trial {tr['id']} early-stopped at epoch {tr['epochs_done']} "
                                  f"(best Val CI={tr['best_ci']:.4f})")
                            break

            # prune: keep top 1/eta among alive
            alive = [tr for tr in trials if tr['alive']]
            alive.sort(key=lambda z: z['best_ci'], reverse=True)
            keep_n = max(1, math.ceil(len(alive) / eta))
            survivors = set(tr['id'] for tr in alive[:keep_n])

            print(f"Alive before prune: {len(alive)}; keeping top {keep_n}")
            for tr in alive:
                if tr['id'] not in survivors:
                    tr['alive'] = False
                    # free memory
                    del tr['model']; tr['model'] = None
                    tr['opt'] = None; tr['sched'] = None
                    if torch.cuda.is_available(): torch.cuda.empty_cache()

            prev_target = rung_ep

        # ----- BEST model selection (no 1-SE) -----
        best_trial = max(trials, key=lambda z: z['best_ci'])
        best_hp = best_trial['hp']
        print("\nBest Hyperparameters (by Val CI):")
        print(best_hp)
        print("Best Validation CI:", best_trial['best_ci'])

        # ----- save plots per trial -----
        out_date = current_date
        for tr in trials:
            cfg = (f"trial{tr['id']}_layers-{'-'.join(map(str, tr['hp']['layers']))}"
                   f"_drop{tr['hp']['dropout']:.2f}_lr{tr['hp']['lr']:.2e}"
                   f"_wd{tr['hp']['wd']:.2e}_l1{tr['hp']['l1']:.2e}")
            epochs = range(1, len(tr['hist_train_ci'])+1)
            plt.figure()
            plt.plot(epochs, tr['hist_train_ci'], label='Train CI')
            plt.plot(epochs, tr['hist_val_ci'], label='Val CI')
            plt.xlabel('Epoch'); plt.ylabel('Concordance Index'); plt.legend(); plt.grid(True, alpha=0.3)
            plt.title(cfg)
            plt.ylim(0.4, 1.0)  # constant y-scale across plots
            plot_path = os.path.join(output_dir, f"{out_date}_ci_{cfg}.png")
            plt.savefig(plot_path, dpi=150, bbox_inches='tight'); plt.close()
            print(f"Saved CI plot to {plot_path}")

        # ----- save results table -----
        results = []
        for tr in trials:
            row = {
                'trial_id': tr['id'], 'val_ci': tr['best_ci'],
                'epochs_trained': len(tr['hist_val_ci']), 'alive_final': tr['alive'],
                'avg_loss_last': tr['hist_loss'][-1] if tr['hist_loss'] else np.nan,
                'skip_frac_last': tr['hist_skip'][-1] if tr['hist_skip'] else np.nan
            }
            row.update({
                'layers': '-'.join(map(str, tr['hp']['layers'])),
                'dropout': tr['hp']['dropout'], 'lr': tr['hp']['lr'],
                'weight_decay(L2)': tr['hp']['wd'], 'l1_lambda': tr['hp']['l1'],
                'param_count': count_params(X_train.shape[1], tr['hp']['layers'])
            })
            results.append(row)
        df = pd.DataFrame(results).sort_values('val_ci', ascending=False)
        csv_path = os.path.join(output_dir, f"{out_date}_deepsurv_randomSH_results.csv")
        df.to_csv(csv_path, index=False)
        print(f"Hyperparameter search results saved to {csv_path}")

        # ----- save best model -----
        best_model_path = os.path.join(output_dir, f"{out_date}_best_deepsurv_model.pth")
        torch.save(best_trial['best_state'], best_model_path)
        print(f"Best model saved to {best_model_path}")

        # ----- test evaluation -----
        test_df = pd.read_csv("/content/drive/MyDrive/affyTest.csv")
        if 'Adjuvant Chemo' in test_df.columns:
            test_df['Adjuvant Chemo'] = test_df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})
        for col in binary_columns:
            if col in test_df.columns: test_df[col] = test_df[col].astype(int)
        X_test = scaler.transform(test_df[feature_cols].values.astype(np.float32)).astype(np.float32)
        y_test_time = test_df['OS_MONTHS'].values.astype(np.float32)
        y_test_event = test_df['OS_STATUS'].values.astype(np.float32)

        test_loader = DataLoader(SurvivalDataset(X_test, y_test_time, y_test_event),
                                 batch_size=BATCH_SIZE, shuffle=False)

        # rebuild & load best
        final_model = DeepSurvMLP(X_train.shape[1], best_hp['layers'], dropout=best_hp['dropout']).to(device)
        final_model.load_state_dict(torch.load(best_model_path, map_location=device))
        test_ci = evaluate_ci(final_model, test_loader, device)
        print(f"Test CI: {test_ci:.4f}")

        sys.stdout.flush()

    sys.stdout = original_stdout
    print("Training completed. Check your log file at:", log_path)

if __name__ == "__main__":
    main()

  df['Adjuvant Chemo'] = df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})



=== Rung 1/4 → target 16 epochs ===
[Trial 00 | Rung 1/4 | Epoch   1] Loss 3.7278 | Skip% 0.0 | Train CI 0.6437 | Val CI 0.5776 | Best -inf
[Trial 00 | Rung 1/4 | Epoch   2] Loss 3.7901 | Skip% 0.0 | Train CI 0.6940 | Val CI 0.6023 | Best 0.5776
[Trial 00 | Rung 1/4 | Epoch   3] Loss 3.8766 | Skip% 0.0 | Train CI 0.7375 | Val CI 0.6054 | Best 0.6023
[Trial 00 | Rung 1/4 | Epoch   4] Loss 3.9967 | Skip% 0.0 | Train CI 0.7716 | Val CI 0.6048 | Best 0.6054
[Trial 00 | Rung 1/4 | Epoch   5] Loss 4.0509 | Skip% 0.0 | Train CI 0.7838 | Val CI 0.6051 | Best 0.6054
[Trial 00 | Rung 1/4 | Epoch   6] Loss 4.1414 | Skip% 0.0 | Train CI 0.8000 | Val CI 0.6098 | Best 0.6054
[Trial 00 | Rung 1/4 | Epoch   7] Loss 4.2232 | Skip% 0.0 | Train CI 0.8148 | Val CI 0.6089 | Best 0.6098
[Trial 00 | Rung 1/4 | Epoch   8] Loss 4.3393 | Skip% 0.0 | Train CI 0.8233 | Val CI 0.6140 | Best 0.6098
[Trial 00 | Rung 1/4 | Epoch   9] Loss 4.4331 | Skip% 0.0 | Train CI 0.8310 | Val CI 0.6167 | Best 0.6140
[Trial 00 |

  test_df['Adjuvant Chemo'] = test_df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})


Test CI: 0.6642
Training completed. Check your log file at: /content/drive/MyDrive/deepsurv_8-30-25_training_log.txt


In [2]:
import os
import sys
import math
import datetime
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Sampler
from torchsurv.loss.cox import neg_partial_log_likelihood
from sksurv.metrics import concordance_index_censored
import warnings
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import copy
import random

warnings.filterwarnings("ignore", message="Ties in event time detected; using efron's method to handle ties.")
torch.manual_seed(0); np.random.seed(0); random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ---------------- utils ----------------
class Tee:
    def __init__(self, *files): self.files = files
    def write(self, data):
        for f in self.files: f.write(data)
    def flush(self):
        for f in self.files: f.flush()

def loguniform(rng, lo, hi):
    return float(np.exp(rng.uniform(np.log(lo), np.log(hi))))

# Anchors so we can occasionally pick exact values (includes 1e-3 for L1)
wd_anchors = np.array([5e-2, 1e-1, 1.5e-1, 2e-1, 3e-1, 5e-1])
l1_anchors = np.array([5e-4, 7e-4, 1e-3, 2e-3, 5e-3, 1e-2])

def sample_hparams(rng):
    # Architectures: keep the same small winners in play
    hidden_options = [[256], [512], [256, 512], [32, 16]]
    layers = hidden_options[rng.integers(len(hidden_options))]

    # Moderately high dropout
    dropout = float(rng.uniform(0.30, 0.65))

    # Slightly wider LR band (strong reg can handle a touch more LR)
    lr = loguniform(rng, 2e-5, 1.5e-4)

    # Mix anchors + ranges
    r = rng.random()
    if r < 0.20:
        wd = float(rng.choice(wd_anchors))
        l1 = float(rng.choice(l1_anchors))
    elif r < 0.80:
        # heavy
        wd = loguniform(rng, 5e-2, 3e-1)    # 0.05–0.30
        l1 = loguniform(rng, 7e-4, 5e-3)    # 7e-4–5e-3
    else:
        # extreme
        wd = loguniform(rng, 3e-1, 6e-1)    # 0.30–0.60
        l1 = loguniform(rng, 5e-3, 2e-2)    # 5e-3–2e-2

    return {'layers': layers, 'dropout': dropout, 'lr': lr, 'wd': wd, 'l1': l1}

# ---------------- model & data ----------------
class DeepSurvMLP(nn.Module):
    def __init__(self, in_features, hidden_layers, dropout=0.0, activation=nn.ReLU()):
        super().__init__()
        layers, d = [], in_features
        for units in hidden_layers:
            layers += [nn.Linear(d, units), activation]
            if dropout > 0: layers.append(nn.Dropout(dropout))
            d = units
        layers.append(nn.Linear(d, 1))
        self.model = nn.Sequential(*layers)
    def forward(self, x): return self.model(x)

class SurvivalDataset(Dataset):
    def __init__(self, features, time_vals, events):
        self.x = torch.tensor(features, dtype=torch.float32)
        self.time = torch.tensor(time_vals, dtype=torch.float32)
        self.event = torch.tensor(events, dtype=torch.bool)
    def __len__(self): return len(self.x)
    def __getitem__(self, idx): return self.x[idx], self.time[idx], self.event[idx]

# ---- (1) Event-balanced batch sampler: guarantee ≥1 event per batch ----
class EventBalancedBatchSampler(Sampler):
    def __init__(self, events_numpy, batch_size, seed=0):
        events = np.asarray(events_numpy).astype(bool)
        self.pos_idx = np.where(events)[0]
        self.neg_idx = np.where(~events)[0]
        assert len(self.pos_idx) > 0, "No events in training set — cannot balance batches."
        self.bs = int(batch_size)
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        # one finite pass == one epoch
        pos = self.rng.permutation(self.pos_idx)
        neg = self.rng.permutation(self.neg_idx)
        n_total = len(pos) + len(neg)
        n_batches = math.ceil(n_total / self.bs)

        pi = ni = 0
        for _ in range(n_batches):
            take_pos = 1 if pi < len(pos) else 0
            # fill the rest with negatives if we can
            avail_neg = max(0, len(neg) - ni)
            take_neg = min(self.bs - take_pos, avail_neg)

            # if we ran out of negatives, top up with extra positives
            need = self.bs - (take_pos + take_neg)
            extra_pos = min(need, max(0, len(pos) - (pi + take_pos)))
            take_pos += extra_pos

            batch = np.concatenate([
                pos[pi:pi+take_pos],
                neg[ni:ni+take_neg]
            ])
            pi += take_pos
            ni += take_neg

            if batch.size == 0:
                break
            self.rng.shuffle(batch)
            yield batch.tolist()

    def __len__(self):
        return math.ceil((len(self.pos_idx) + len(self.neg_idx)) / self.bs)

# ---- Param groups: L2 only on non-bias, non-final weights ----
def make_optimizer(model, lr, wd):
    # find last Linear
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
    last_linear = linears[-1] if len(linears) > 0 else None

    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if name.endswith('bias'):
            no_decay.append(p); continue
        if (last_linear is not None) and (p is last_linear.weight):
            no_decay.append(p); continue
        decay.append(p)

    param_groups = [
        {'params': decay, 'weight_decay': wd},
        {'params': no_decay, 'weight_decay': 0.0},
    ]
    return optim.AdamW(param_groups, lr=lr)

# ---- (2) Reg warm-up helpers (dropout + WD) ----
def set_dropout_p(model, p):
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.p = float(p)

def set_weight_decay(optimizer, wd):
    for g in optimizer.param_groups:
        g['weight_decay'] = float(wd)

# L1 ONLY on the first (input) Linear layer (already warmed up in train loop)
def l1_penalty_first_layer(model):
    for m in model.modules():
        if isinstance(m, nn.Linear):
            return m.weight.abs().sum()
    return torch.tensor(0.0, device=next(model.parameters()).device)

def train_one_epoch(model, optimizer, dataloader, device, l1_lambda=0.0, epoch=0, warmup_epochs=20):
    model.train()
    warm = min(1.0, (epoch + 1) / float(warmup_epochs))  # linear warmup of L1
    loss_sum, n_seen = 0.0, 0
    skipped, total_batches = 0, 0
    for x, t, e in dataloader:
        total_batches += 1
        # with balanced sampler, this should almost never trigger, but keep it safe:
        if e.sum().item() == 0:
            skipped += 1
            continue
        x, t, e = x.to(device), t.to(device), e.to(device)
        optimizer.zero_grad(set_to_none=True)
        out = torch.clamp(model(x), -20, 20)
        loss = neg_partial_log_likelihood(out, e, t, reduction='mean')
        if l1_lambda > 0:
            loss = loss + (l1_lambda * warm) * l1_penalty_first_layer(model)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        loss_sum += loss.item() * x.size(0)
        n_seen += x.size(0)
    avg_loss = loss_sum / max(n_seen, 1)
    skip_frac = skipped / max(total_batches, 1)
    return {'avg_loss': avg_loss, 'skip_frac': skip_frac, 'warm': warm}

# ---- (3) Full-risk-set correction step (1x per epoch) ----
def full_risk_set_step(model, optimizer, train_ds, device, l1_lambda=0.0, warm=1.0):
    model.train()
    X_all = train_ds.x.to(device)
    t_all = train_ds.time.to(device)
    e_all = train_ds.event.to(device)
    optimizer.zero_grad(set_to_none=True)
    out_all = torch.clamp(model(X_all), -20, 20)
    loss_full = neg_partial_log_likelihood(out_all, e_all, t_all, reduction='mean')
    if l1_lambda > 0:
        loss_full = loss_full + (l1_lambda * warm) * l1_penalty_first_layer(model)
    loss_full.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
    optimizer.step()
    return float(loss_full.detach().cpu().item())

def evaluate_ci(model, dataloader, device):
    model.eval()
    preds, times, events = [], [], []
    with torch.no_grad():
        for x, t, e in dataloader:
            x = x.to(device)
            y = torch.clamp(model(x), -20, 20)
            preds.append(y.cpu().numpy().ravel())
            times.append(t.numpy()); events.append(e.numpy())
    preds = np.concatenate(preds)
    if np.isnan(preds).any():
        print("Warning: NaN predictions detected, returning -inf for concordance index")
        return -np.inf
    times = np.concatenate(times); events = np.concatenate(events)
    return concordance_index_censored(events.astype(bool), times, preds)[0]

def count_params(in_dim, hidden_layers):
    params, d = 0, in_dim
    for h in hidden_layers:
        params += d*h + h  # weights + bias
        d = h
    params += d*1 + 1     # final layer
    return params

# ---------------- random search with successive halving ----------------
def main():
    # logging
    original_stdout = sys.stdout
    log_path = "/content/drive/MyDrive/deepsurv_8-30-25_training_log.txt"
    with open(log_path, "w") as log_file:
        sys.stdout = Tee(original_stdout, log_file)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        current_date = datetime.datetime.now().strftime("%Y%m%d")
        output_dir = "/content/drive/MyDrive/deepsurv_results"
        os.makedirs(output_dir, exist_ok=True)

        # ----- load & prep data -----
        train_df = pd.read_csv("/content/drive/MyDrive/affyTrain.csv")
        valid_df = pd.read_csv("/content/drive/MyDrive/affyValidation.csv")
        for df in (train_df, valid_df):
            if 'Adjuvant Chemo' in df.columns:
                df['Adjuvant Chemo'] = df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})
        binary_columns = ['Adjuvant Chemo','IS_MALE']
        for col in binary_columns:
            if col in train_df.columns: train_df[col] = train_df[col].astype(int)
            if col in valid_df.columns: valid_df[col] = valid_df[col].astype(int)

        survival_cols = ['OS_STATUS','OS_MONTHS']
        feature_cols = [c for c in train_df.columns if c not in survival_cols]

        X_train = train_df[feature_cols].values.astype(np.float32)
        X_valid = valid_df[feature_cols].values.astype(np.float32)
        scaler = StandardScaler().fit(X_train)
        X_train = scaler.transform(X_train).astype(np.float32)
        X_valid = scaler.transform(X_valid).astype(np.float32)

        y_train_time = train_df['OS_MONTHS'].values.astype(np.float32)
        y_train_event = train_df['OS_STATUS'].values.astype(np.float32)
        y_valid_time = valid_df['OS_MONTHS'].values.astype(np.float32)
        y_valid_event = valid_df['OS_STATUS'].values.astype(np.float32)

        train_ds = SurvivalDataset(X_train, y_train_time, y_train_event)
        valid_ds = SurvivalDataset(X_valid, y_valid_time, y_valid_event)

        BATCH_SIZE = 64

        # (1) swap in the event-balanced sampler
        train_sampler = EventBalancedBatchSampler(y_train_event, BATCH_SIZE, seed=42)
        train_loader  = DataLoader(train_ds, batch_sampler=train_sampler)
        # evaluation loaders remain standard
        train_eval_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=False)
        valid_loader      = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)

        # ----- search budget & rungs -----
        rng = np.random.default_rng(42)
        num_trials = 100
        rungs = [16, 64, 160, 320]  # keep your longer rungs
        eta = 3
        min_delta = 1e-4
        PRINT_EVERY = 1

        # Warm-ups
        L1_WARMUP_EPOCHS = 30
        WD_WARMUP_EPOCHS = 30
        DROPOUT_WARMUP_EPOCHS = 30
        DROPOUT_START = 0.15
        WD_START = 0.0

        # ΔVal-CI moving-average window
        DELTA_CI_MA_K = 10

        # ----- initialize trials -----
        trials = []
        for tid in range(num_trials):
            hp = sample_hparams(rng)
            model = DeepSurvMLP(X_train.shape[1], hp['layers'], dropout=hp['dropout']).to(device)
            optimizer = make_optimizer(model, lr=hp['lr'], wd=hp['wd'])
            trials.append({
                'id': tid, 'hp': hp, 'model': model, 'opt': optimizer, 'sched': None, 'sched_target': 0,
                'best_ci': -np.inf, 'best_state': copy.deepcopy(model.state_dict()),
                'hist_train_ci': [], 'hist_val_ci': [], 'hist_loss': [], 'hist_skip': [],
                'hist_val_ci_delta': [], 'hist_val_ci_delta_ma': [],
                'epochs_done': 0, 'alive': True, 'no_improve': 0
            })

        # ----- successive halving loop -----
        prev_target = 0
        for rung_idx, rung_ep in enumerate(rungs, start=1):
            print(f"\n=== Rung {rung_idx}/{len(rungs)} → target {rung_ep} epochs ===")

            span = rung_ep - prev_target
            es_patience = float('inf') if rung_idx == 1 else max(10, int(0.25 * span))

            for tr in trials:
                if not tr['alive']: continue
                steps_remaining = rung_ep - tr['epochs_done']
                if steps_remaining <= 0:
                    continue

                tr['sched'] = torch.optim.lr_scheduler.CosineAnnealingLR(tr['opt'], T_max=steps_remaining)
                tr['sched_target'] = rung_ep

                while tr['epochs_done'] < rung_ep:
                    # (2) apply dropout + weight-decay warm-up schedules
                    frac_d = min(1.0, tr['epochs_done'] / float(DROPOUT_WARMUP_EPOCHS))
                    frac_w = min(1.0, tr['epochs_done'] / float(WD_WARMUP_EPOCHS))
                    p_t = DROPOUT_START + (tr['hp']['dropout'] - DROPOUT_START) * frac_d
                    wd_t = WD_START + (tr['hp']['wd'] - WD_START) * frac_w
                    set_dropout_p(tr['model'], p_t)
                    set_weight_decay(tr['opt'], wd_t)

                    # normal mini-batch training epoch
                    stats = train_one_epoch(
                        tr['model'], tr['opt'], train_loader, device,
                        l1_lambda=tr['hp']['l1'], epoch=tr['epochs_done'], warmup_epochs=L1_WARMUP_EPOCHS
                    )
                    # (3) one full-risk-set correction step per epoch
                    full_loss = full_risk_set_step(
                        tr['model'], tr['opt'], train_ds, device,
                        l1_lambda=tr['hp']['l1'], warm=stats['warm']
                    )

                    tr['epochs_done'] += 1

                    # diagnostics each epoch
                    tr_ci = evaluate_ci(tr['model'], train_eval_loader, device)
                    va_ci = evaluate_ci(tr['model'], valid_loader, device)
                    tr['hist_train_ci'].append(tr_ci)
                    tr['hist_val_ci'].append(va_ci)
                    tr['hist_loss'].append(stats['avg_loss'])
                    tr['hist_skip'].append(stats['skip_frac'])

                    # ΔVal-CI and its moving average
                    if len(tr['hist_val_ci']) >= 2:
                        delta = tr['hist_val_ci'][-1] - tr['hist_val_ci'][-2]
                    else:
                        delta = 0.0
                    tr['hist_val_ci_delta'].append(delta)
                    ma = float(np.mean(tr['hist_val_ci_delta'][-DELTA_CI_MA_K:]))
                    tr['hist_val_ci_delta_ma'].append(ma)

                    if tr['sched'] is not None:
                        tr['sched'].step()

                    if tr['epochs_done'] % PRINT_EVERY == 0:
                        print(f"[Trial {tr['id']:02d} | Rung {rung_idx}/{len(rungs)} | "
                              f"Epoch {tr['epochs_done']:3d}] "
                              f"Loss {stats['avg_loss']:.4f} | FullStepLoss {full_loss:.4f} | "
                              f"Skip% {100*stats['skip_frac']:.1f} | "
                              f"Train CI {tr_ci:.4f} | Val CI {va_ci:.4f} | "
                              f"ΔVal-CI MA({DELTA_CI_MA_K}) {ma:+.4f} | Best {tr['best_ci']:.4f}")

                    # track best
                    if va_ci > tr['best_ci'] + min_delta:
                        tr['best_ci'] = va_ci
                        tr['best_state'] = copy.deepcopy(tr['model'].state_dict())
                        tr['no_improve'] = 0
                    else:
                        tr['no_improve'] += 1
                        if tr['no_improve'] >= es_patience:
                            print(f"Trial {tr['id']} early-stopped at epoch {tr['epochs_done']} "
                                  f"(best Val CI={tr['best_ci']:.4f})")
                            break

            # prune: keep top 1/eta among alive
            alive = [tr for tr in trials if tr['alive']]
            alive.sort(key=lambda z: z['best_ci'], reverse=True)
            keep_n = max(1, math.ceil(len(alive) / eta))
            survivors = set(tr['id'] for tr in alive[:keep_n])

            print(f"Alive before prune: {len(alive)}; keeping top {keep_n}")
            for tr in alive:
                if tr['id'] not in survivors:
                    tr['alive'] = False
                    # free memory
                    del tr['model']; tr['model'] = None
                    tr['opt'] = None; tr['sched'] = None
                    if torch.cuda.is_available(): torch.cuda.empty_cache()

            prev_target = rung_ep

        # ----- BEST model selection (no 1-SE) -----
        best_trial = max(trials, key=lambda z: z['best_ci'])
        best_hp = best_trial['hp']
        print("\nBest Hyperparameters (by Val CI):")
        print(best_hp)
        print("Best Validation CI:", best_trial['best_ci'])

        # ----- save plots per trial -----
        out_date = current_date
        for tr in trials:
            cfg = (f"trial{tr['id']}_layers-{'-'.join(map(str, tr['hp']['layers']))}"
                   f"_drop{tr['hp']['dropout']:.2f}_lr{tr['hp']['lr']:.2e}"
                   f"_wd{tr['hp']['wd']:.2e}_l1{tr['hp']['l1']:.2e}")
            epochs = range(1, len(tr['hist_train_ci'])+1)
            plt.figure()
            plt.plot(epochs, tr['hist_train_ci'], label='Train CI')
            plt.plot(epochs, tr['hist_val_ci'], label='Val CI')
            # Optional: plot ΔVal-CI MA on a secondary axis as trend
            ax = plt.gca()
            ax2 = ax.twinx()
            ax2.plot(epochs, tr['hist_val_ci_delta_ma'], linestyle='--', alpha=0.5, label=f'ΔVal-CI MA')
            ax.set_xlabel('Epoch'); ax.set_ylabel('Concordance Index')
            ax2.set_ylabel('ΔVal-CI MA')
            lines, labels = ax.get_legend_handles_labels()
            lines2, labels2 = ax2.get_legend_handles_labels()
            ax.legend(lines + lines2, labels + labels2, loc='lower right')
            plt.grid(True, alpha=0.3); plt.title(cfg)
            ax.set_ylim(0.4, 1.0)  # constant y-scale across plots
            plot_path = os.path.join(output_dir, f"{out_date}_ci_{cfg}.png")
            plt.savefig(plot_path, dpi=150, bbox_inches='tight'); plt.close()
            print(f"Saved CI plot to {plot_path}")

        # ----- save results table -----
        results = []
        for tr in trials:
            row = {
                'trial_id': tr['id'],
                'val_ci': tr['best_ci'],
                'epochs_trained': len(tr['hist_val_ci']),
                'alive_final': tr['alive'],
                'avg_loss_last': tr['hist_loss'][-1] if tr['hist_loss'] else np.nan,
                'skip_frac_last': tr['hist_skip'][-1] if tr['hist_skip'] else np.nan,
                'val_ci_ma10_last': tr['hist_val_ci_delta_ma'][-1] if tr['hist_val_ci_delta_ma'] else np.nan
            }
            row.update({
                'layers': '-'.join(map(str, tr['hp']['layers'])),
                'dropout': tr['hp']['dropout'], 'lr': tr['hp']['lr'],
                'weight_decay(L2)': tr['hp']['wd'], 'l1_lambda': tr['hp']['l1'],
                'param_count': count_params(X_train.shape[1], tr['hp']['layers'])
            })
            results.append(row)
        df = pd.DataFrame(results).sort_values('val_ci', ascending=False)
        csv_path = os.path.join(output_dir, f"{out_date}_deepsurv_randomSH_results.csv")
        df.to_csv(csv_path, index=False)
        print(f"Hyperparameter search results saved to {csv_path}")

        # ----- save best model -----
        best_model_path = os.path.join(output_dir, f"{out_date}_best_deepsurv_model.pth")
        torch.save(best_trial['best_state'], best_model_path)
        print(f"Best model saved to {best_model_path}")

        # ----- test evaluation -----
        test_df = pd.read_csv("/content/drive/MyDrive/affyTest.csv")
        if 'Adjuvant Chemo' in test_df.columns:
            test_df['Adjuvant Chemo'] = test_df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})
        for col in binary_columns:
            if col in test_df.columns: test_df[col] = test_df[col].astype(int)
        X_test = scaler.transform(test_df[feature_cols].values.astype(np.float32)).astype(np.float32)
        y_test_time = test_df['OS_MONTHS'].values.astype(np.float32)
        y_test_event = test_df['OS_STATUS'].values.astype(np.float32)

        test_loader = DataLoader(SurvivalDataset(X_test, y_test_time, y_test_event),
                                 batch_size=BATCH_SIZE, shuffle=False)

        # rebuild & load best
        final_model = DeepSurvMLP(X_train.shape[1], best_hp['layers'], dropout=best_hp['dropout']).to(device)
        final_model.load_state_dict(torch.load(best_model_path, map_location=device))
        test_ci = evaluate_ci(final_model, test_loader, device)
        print(f"Test CI: {test_ci:.4f}")

        sys.stdout.flush()

    sys.stdout = original_stdout
    print("Training completed. Check your log file at:", log_path)

if __name__ == "__main__":
    main()

  df['Adjuvant Chemo'] = df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})



=== Rung 1/4 → target 16 epochs ===
[Trial 00 | Rung 1/4 | Epoch   1] Loss 6.0549 | FullStepLoss 8.6995 | Skip% 0.0 | Train CI 0.7223 | Val CI 0.6209 | ΔVal-CI MA(10) +0.0000 | Best -inf
[Trial 00 | Rung 1/4 | Epoch   2] Loss 8.1990 | FullStepLoss 10.6327 | Skip% 0.0 | Train CI 0.7910 | Val CI 0.6280 | ΔVal-CI MA(10) +0.0035 | Best 0.6209
[Trial 00 | Rung 1/4 | Epoch   3] Loss 9.7839 | FullStepLoss 12.4337 | Skip% 0.0 | Train CI 0.8407 | Val CI 0.6168 | ΔVal-CI MA(10) -0.0014 | Best 0.6280
[Trial 00 | Rung 1/4 | Epoch   4] Loss 11.6345 | FullStepLoss 14.0757 | Skip% 0.0 | Train CI 0.8760 | Val CI 0.6210 | ΔVal-CI MA(10) +0.0000 | Best 0.6280
[Trial 00 | Rung 1/4 | Epoch   5] Loss 12.8480 | FullStepLoss 15.6341 | Skip% 0.0 | Train CI 0.8530 | Val CI 0.6243 | ΔVal-CI MA(10) +0.0007 | Best 0.6280
[Trial 00 | Rung 1/4 | Epoch   6] Loss 14.7410 | FullStepLoss 16.8849 | Skip% 0.0 | Train CI 0.8974 | Val CI 0.5843 | ΔVal-CI MA(10) -0.0061 | Best 0.6280
[Trial 00 | Rung 1/4 | Epoch   7] Loss 

  test_df['Adjuvant Chemo'] = test_df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})


Test CI: 0.6358
Training completed. Check your log file at: /content/drive/MyDrive/deepsurv_8-30-25_training_log.txt
