<a href="https://colab.research.google.com/github/osun24/nsclc-adj-chemo/blob/main/TorchSurv_DeepSurv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install necessary packages
!pip install torchsurv scikit-survival

# Import required packages
import os
import time
import datetime
import itertools
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sksurv.metrics import concordance_index_censored

# (Optional) Mount Google Drive if you plan to load/save files there
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
import sys
import math
import datetime
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchsurv.loss.cox import neg_partial_log_likelihood
from sksurv.metrics import concordance_index_censored
import warnings
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import copy
import random

warnings.filterwarnings("ignore", message="Ties in event time detected; using efron's method to handle ties.")
torch.manual_seed(0); np.random.seed(0); random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ---------------- utils ----------------
class Tee:
    def __init__(self, *files): self.files = files
    def write(self, data):
        for f in self.files: f.write(data)
    def flush(self):
        for f in self.files: f.flush()

def loguniform(rng, lo, hi):
    return float(np.exp(rng.uniform(np.log(lo), np.log(hi))))

def sample_hparams(rng):
    hidden_options = [[8], [16], [32], [16, 8], [32, 16]]
    layers = hidden_options[rng.integers(len(hidden_options))]
    dropout = float(rng.uniform(0.4, 0.7))             # narrower, slightly higher
    lr = loguniform(rng, 2e-5, 2e-4)                   # lower LR band
    wd = loguniform(rng, 2e-3, 5e-1)                   # stronger L2 via AdamW
    l1 = loguniform(rng, 5e-5, 5e-3)                   # modest L1
    return {'layers': layers, 'dropout': dropout, 'lr': lr, 'wd': wd, 'l1': l1}

# ---------------- model & data ----------------
class DeepSurvMLP(nn.Module):
    def __init__(self, in_features, hidden_layers, dropout=0.0, activation=nn.ReLU()):
        super().__init__()
        layers, d = [], in_features
        for units in hidden_layers:
            layers += [nn.Linear(d, units), activation]
            if dropout > 0: layers.append(nn.Dropout(dropout))
            d = units
        layers.append(nn.Linear(d, 1))
        self.model = nn.Sequential(*layers)
    def forward(self, x): return self.model(x)

class SurvivalDataset(Dataset):
    def __init__(self, features, time_vals, events):
        self.x = torch.tensor(features, dtype=torch.float32)
        self.time = torch.tensor(time_vals, dtype=torch.float32)
        self.event = torch.tensor(events, dtype=torch.bool)
    def __len__(self): return len(self.x)
    def __getitem__(self, idx): return self.x[idx], self.time[idx], self.event[idx]

def l1_penalty(model):
    return sum(p.abs().sum() for n,p in model.named_parameters()
               if p.requires_grad and p.dim() > 1)

def train_one_epoch(model, optimizer, dataloader, device, l1_lambda=0.0):
    model.train()
    for x, t, e in dataloader:
        if e.sum().item() == 0:  # skip non-informative batch
            continue
        x, t, e = x.to(device), t.to(device), e.to(device)
        optimizer.zero_grad()
        out = torch.clamp(model(x), -20, 20)
        loss = neg_partial_log_likelihood(out, e, t, reduction='mean')
        if l1_lambda > 0: loss = loss + l1_lambda * l1_penalty(model)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

def evaluate_ci(model, dataloader, device):
    model.eval()
    preds, times, events = [], [], []
    with torch.no_grad():
        for x, t, e in dataloader:
            x = x.to(device)
            y = torch.clamp(model(x), -20, 20)
            preds.append(y.cpu().numpy().ravel())
            times.append(t.numpy()); events.append(e.numpy())
    preds = np.concatenate(preds)
    if np.isnan(preds).any():
        print("Warning: NaN predictions detected, returning -inf for concordance index")
        return -np.inf
    times = np.concatenate(times); events = np.concatenate(events)
    return concordance_index_censored(events.astype(bool), times, preds)[0]

def count_params(in_dim, hidden_layers):
    params, d = 0, in_dim
    for h in hidden_layers:
        params += d*h + h  # weights + bias
        d = h
    params += d*1 + 1     # final layer
    return params

def collect_val_arrays(model, dataloader, device):
    model.eval()
    preds, times, events = [], [], []
    with torch.no_grad():
        for x, t, e in dataloader:
            x = x.to(device)
            y = torch.clamp(model(x), -20, 20).squeeze(1)
            preds.append(y.cpu().numpy())
            times.append(t.numpy()); events.append(e.numpy())
    return np.concatenate(preds), np.concatenate(times), np.concatenate(events)

def bootstrap_se_cindex(preds, times, events, B=200, seed=123):
    rng = np.random.default_rng(seed)
    n = len(times)
    vals = []
    for _ in range(B):
        idx = rng.integers(0, n, n)
        vals.append(concordance_index_censored(events[idx].astype(bool), times[idx], preds[idx])[0])
    return np.std(vals, ddof=1)

# ---------------- random search with successive halving ----------------
def main():
    # logging
    original_stdout = sys.stdout
    log_path = "/content/drive/MyDrive/deepsurv_8-28-25_training_log.txt"
    with open(log_path, "w") as log_file:
        sys.stdout = Tee(original_stdout, log_file)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        current_date = datetime.datetime.now().strftime("%Y%m%d")
        output_dir = "/content/drive/MyDrive/deepsurv_results"
        os.makedirs(output_dir, exist_ok=True)

        # ----- load & prep data -----
        train_df = pd.read_csv("/content/drive/MyDrive/affyTrain.csv")
        valid_df = pd.read_csv("/content/drive/MyDrive/affyValidation.csv")
        for df in (train_df, valid_df):
            if 'Adjuvant Chemo' in df.columns:
                df['Adjuvant Chemo'] = df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})
        binary_columns = ['Adjuvant Chemo','IS_MALE']
        for col in binary_columns:
            if col in train_df.columns: train_df[col] = train_df[col].astype(int)
            if col in valid_df.columns: valid_df[col] = valid_df[col].astype(int)

        survival_cols = ['OS_STATUS','OS_MONTHS']
        feature_cols = [c for c in train_df.columns if c not in survival_cols]

        X_train = train_df[feature_cols].values.astype(np.float32)
        X_valid = valid_df[feature_cols].values.astype(np.float32)
        scaler = StandardScaler().fit(X_train)
        X_train = scaler.transform(X_train).astype(np.float32)
        X_valid = scaler.transform(X_valid).astype(np.float32)

        y_train_time = train_df['OS_MONTHS'].values.astype(np.float32)
        y_train_event = train_df['OS_STATUS'].values.astype(np.float32)
        y_valid_time = valid_df['OS_MONTHS'].values.astype(np.float32)
        y_valid_event = valid_df['OS_STATUS'].values.astype(np.float32)

        train_ds = SurvivalDataset(X_train, y_train_time, y_train_event)
        valid_ds = SurvivalDataset(X_valid, y_valid_time, y_valid_event)

        batch_size = 32
        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  drop_last=True)
        train_eval_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=False)
        valid_loader      = DataLoader(valid_ds, batch_size=batch_size, shuffle=False)

        # ----- search budget & rungs -----
        rng = np.random.default_rng(42)
        num_trials = 80                     # budget (adjust as needed)
        rungs = [8, 24, 48, 96]                # epochs at which we prune
        eta = 3                            # keep top 1/eta each rung
        patience, min_delta = 12, 1e-4      # early stopping inside trials
        PRINT_EVERY = 1

        # ----- initialize trials -----
        trials = []
        for tid in range(num_trials):
            hp = sample_hparams(rng)
            model = DeepSurvMLP(X_train.shape[1], hp['layers'], dropout=hp['dropout']).to(device)
            optimizer = optim.AdamW(model.parameters(), lr=hp['lr'], weight_decay=hp['wd'])
            trials.append({
                'id': tid, 'hp': hp, 'model': model, 'opt': optimizer,
                'best_ci': -np.inf, 'best_state': copy.deepcopy(model.state_dict()),
                'hist_train_ci': [], 'hist_val_ci': [], 'epochs_done': 0,
                'alive': True, 'no_improve': 0, 'last_val_ci': -np.inf
            })

        # ----- successive halving loop -----
        for rung_idx, rung_ep in enumerate(rungs, start=1):
            print(f"\n=== Rung {rung_idx}/{len(rungs)} → target {rung_ep} epochs ===")
            # train all alive trials up to rung_ep (with early stopping)
            for tr in trials:
                if not tr['alive']: continue
                target = rung_ep
                while tr['epochs_done'] < target:
                    train_one_epoch(tr['model'], tr['opt'], train_loader, device, l1_lambda=tr['hp']['l1'])
                    tr['epochs_done'] += 1
                    # diagnostics each epoch
                    tr_ci = evaluate_ci(tr['model'], train_eval_loader, device)
                    va_ci = evaluate_ci(tr['model'], valid_loader, device)
                    tr['hist_train_ci'].append(tr_ci)
                    tr['hist_val_ci'].append(va_ci)

                    if tr['epochs_done'] % PRINT_EVERY == 0:
                      print(f"[Trial {tr['id']:02d} | Rung {rung_idx}/{len(rungs)} | "
                            f"Epoch {tr['epochs_done']:3d}] Train CI {tr_ci:.4f} | "
                            f"Val CI {va_ci:.4f} | Best {tr['best_ci']:.4f}")

                    # track best
                    if va_ci > tr['best_ci'] + min_delta:
                        tr['best_ci'] = va_ci
                        tr['best_state'] = copy.deepcopy(tr['model'].state_dict())
                        tr['no_improve'] = 0
                    else:
                        tr['no_improve'] += 1
                        if tr['no_improve'] >= patience:
                            print(f"Trial {tr['id']} early-stopped at epoch {tr['epochs_done']} (best Val CI={tr['best_ci']:.4f})")
                            break

            # prune: keep top 1/eta among alive
            alive = [tr for tr in trials if tr['alive']]
            alive.sort(key=lambda z: z['best_ci'], reverse=True)
            keep_n = max(1, math.ceil(len(alive) / eta))
            survivors = set(tr['id'] for tr in alive[:keep_n])

            print(f"Alive before prune: {len(alive)}; keeping top {keep_n}")
            for tr in alive:
                if tr['id'] not in survivors:
                    tr['alive'] = False
                    # free memory
                    del tr['model']; tr['model'] = None
                    tr['opt'] = None
                    if torch.cuda.is_available(): torch.cuda.empty_cache()

        # ----- 1-SE model selection -----
        raw_best = max(trials, key=lambda z: z['best_ci'])

        tmp_model = DeepSurvMLP(X_train.shape[1], raw_best['hp']['layers'],
                                dropout=raw_best['hp']['dropout']).to(device)
        tmp_model.load_state_dict(raw_best['best_state'])
        val_preds, val_times, val_events = collect_val_arrays(tmp_model, valid_loader, device)
        se_best = bootstrap_se_cindex(val_preds, val_times, val_events, B=200, seed=123)
        threshold = raw_best['best_ci'] - se_best

        candidates = [tr for tr in trials if tr['best_state'] is not None and tr['best_ci'] >= threshold]

        def simplicity_score(tr):
            # prefer fewer params; if tie, prefer larger regularization & dropout
            return (count_params(X_train.shape[1], tr['hp']['layers']),
                    -(tr['hp']['dropout']),
                    -(tr['hp']['wd'] + tr['hp']['l1']))

        selected = min(candidates, key=simplicity_score)

        best_trial = selected
        best_hp = best_trial['hp']
        print("\n[1-SE] Raw best CI: {:.4f}, SE: {:.4f}, threshold: {:.4f}".format(
              raw_best['best_ci'], se_best, threshold))
        print("[1-SE] Selected Hyperparameters (simplest within 1-SE):")
        print(best_hp)
        print("Selected Validation CI:", best_trial['best_ci'])

        # ----- save plots per trial -----
        out_date = current_date
        for tr in trials:
            cfg = f"trial{tr['id']}_layers-{'-'.join(map(str, tr['hp']['layers']))}_drop{tr['hp']['dropout']:.2f}_lr{tr['hp']['lr']:.2e}_wd{tr['hp']['wd']:.2e}_l1{tr['hp']['l1']:.2e}"
            epochs = range(1, len(tr['hist_train_ci'])+1)
            plt.figure()
            plt.plot(epochs, tr['hist_train_ci'], label='Train CI')
            plt.plot(epochs, tr['hist_val_ci'], label='Val CI')
            plt.xlabel('Epoch'); plt.ylabel('Concordance Index'); plt.legend(); plt.grid(True, alpha=0.3)
            plt.title(cfg)
            plt.ylim(0.4, 1.0)
            plot_path = os.path.join(output_dir, f"{out_date}_ci_{cfg}.png")
            plt.savefig(plot_path, dpi=150, bbox_inches='tight'); plt.close()
            print(f"Saved CI plot to {plot_path}")

        # ----- save results table -----
        results = []
        for tr in trials:
            row = {'trial_id': tr['id'], 'val_ci': tr['best_ci'], 'epochs_trained': len(tr['hist_val_ci']),
                   'alive_final': tr['alive']}
            row.update({ 'layers': '-'.join(map(str, tr['hp']['layers'])),
                         'dropout': tr['hp']['dropout'], 'lr': tr['hp']['lr'],
                         'weight_decay(L2)': tr['hp']['wd'], 'l1_lambda': tr['hp']['l1']})
            results.append(row)
        df = pd.DataFrame(results).sort_values('val_ci', ascending=False)
        csv_path = os.path.join(output_dir, f"{out_date}_deepsurv_randomSH_results.csv")
        df.to_csv(csv_path, index=False)
        print(f"Hyperparameter search results saved to {csv_path}")

        # ----- save best model -----
        best_model_path = os.path.join(output_dir, f"{out_date}_best_deepsurv_model.pth")
        torch.save(best_trial['best_state'], best_model_path)
        print(f"Best model saved to {best_model_path}")

        # ----- test evaluation -----
        test_df = pd.read_csv("/content/drive/MyDrive/affyTest.csv")
        if 'Adjuvant Chemo' in test_df.columns:
            test_df['Adjuvant Chemo'] = test_df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})
        for col in binary_columns:
            if col in test_df.columns: test_df[col] = test_df[col].astype(int)
        X_test = scaler.transform(test_df[feature_cols].values.astype(np.float32)).astype(np.float32)
        y_test_time = test_df['OS_MONTHS'].values.astype(np.float32)
        y_test_event = test_df['OS_STATUS'].values.astype(np.float32)

        test_loader = DataLoader(SurvivalDataset(X_test, y_test_time, y_test_event),
                                 batch_size=batch_size, shuffle=False)

        # rebuild & load best
        final_model = DeepSurvMLP(X_train.shape[1], best_hp['layers'], dropout=best_hp['dropout']).to(device)
        final_model.load_state_dict(torch.load(best_model_path, map_location=device))
        test_ci = evaluate_ci(final_model, test_loader, device)
        print(f"Test CI: {test_ci:.4f}")

        sys.stdout.flush()

    sys.stdout = original_stdout
    print("Training completed. Check your log file at:", log_path)

if __name__ == "__main__":
    main()

  df['Adjuvant Chemo'] = df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})



=== Rung 1/4 → target 8 epochs ===
[Trial 00 | Rung 1/4 | Epoch   1] Train CI 0.6955 | Val CI 0.5622 | Best -inf
[Trial 00 | Rung 1/4 | Epoch   2] Train CI 0.7655 | Val CI 0.5921 | Best 0.5622
[Trial 00 | Rung 1/4 | Epoch   3] Train CI 0.8029 | Val CI 0.6285 | Best 0.5921
[Trial 00 | Rung 1/4 | Epoch   4] Train CI 0.8350 | Val CI 0.6225 | Best 0.6285
[Trial 00 | Rung 1/4 | Epoch   5] Train CI 0.8495 | Val CI 0.6243 | Best 0.6285
[Trial 00 | Rung 1/4 | Epoch   6] Train CI 0.8713 | Val CI 0.6256 | Best 0.6285
[Trial 00 | Rung 1/4 | Epoch   7] Train CI 0.8788 | Val CI 0.6494 | Best 0.6285
[Trial 00 | Rung 1/4 | Epoch   8] Train CI 0.8979 | Val CI 0.6210 | Best 0.6494
[Trial 01 | Rung 1/4 | Epoch   1] Train CI 0.6131 | Val CI 0.5905 | Best -inf
[Trial 01 | Rung 1/4 | Epoch   2] Train CI 0.6578 | Val CI 0.6106 | Best 0.5905
[Trial 01 | Rung 1/4 | Epoch   3] Train CI 0.6653 | Val CI 0.6018 | Best 0.6106
[Trial 01 | Rung 1/4 | Epoch   4] Train CI 0.6599 | Val CI 0.6329 | Best 0.6106
[Trial 0

  test_df['Adjuvant Chemo'] = test_df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})


Test CI: 0.6206
Training completed. Check your log file at: /content/drive/MyDrive/deepsurv_8-28-25_training_log.txt


In [5]:
import os
import sys
import math
import datetime
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchsurv.loss.cox import neg_partial_log_likelihood
from sksurv.metrics import concordance_index_censored
import warnings
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import copy
import random

warnings.filterwarnings("ignore", message="Ties in event time detected; using efron's method to handle ties.")
torch.manual_seed(0); np.random.seed(0); random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ---------------- utils ----------------
class Tee:
    def __init__(self, *files): self.files = files
    def write(self, data):
        for f in self.files: f.write(data)
    def flush(self):
        for f in self.files: f.flush()

def loguniform(rng, lo, hi):
    return float(np.exp(rng.uniform(np.log(lo), np.log(hi))))

def sample_hparams(rng):
    # Architectures: keep the same small winners in play
    hidden_options = [[16], [32], [32, 16], [16,8,4], [32,16,8]]
    layers = hidden_options[rng.integers(len(hidden_options))]

    # More dropout (but not extreme)
    dropout = float(rng.uniform(0.45, 0.70))

    # Slightly wider LR band (strong reg can tolerate a touch more LR)
    lr = loguniform(rng, 1e-5, 7e-5)

    # Heavier L2/L1: sample from a "heavy" band 70% of the time,
    # and an "extreme" band 30% of the time to explore the frontier.
    if rng.random() < 0.7:
        # heavy
        wd = loguniform(rng, 5e-2, 3e-1)     # 0.05 – 0.30
        l1 = loguniform(rng, 7e-4, 5e-3)     # 7e-4 – 5e-3
    else:
        # extreme
        wd = loguniform(rng, 3e-1, 6e-1)     # 0.30 – 0.60
        l1 = loguniform(rng, 5e-3, 2e-2)     # 5e-3 – 2e-2

    return {'layers': layers, 'dropout': dropout, 'lr': lr, 'wd': wd, 'l1': l1}

# ---------------- model & data ----------------
class DeepSurvMLP(nn.Module):
    def __init__(self, in_features, hidden_layers, dropout=0.0, activation=nn.ReLU()):
        super().__init__()
        layers, d = [], in_features
        for units in hidden_layers:
            layers += [nn.Linear(d, units), activation]
            if dropout > 0: layers.append(nn.Dropout(dropout))
            d = units
        layers.append(nn.Linear(d, 1))
        self.model = nn.Sequential(*layers)
    def forward(self, x): return self.model(x)

class SurvivalDataset(Dataset):
    def __init__(self, features, time_vals, events):
        self.x = torch.tensor(features, dtype=torch.float32)
        self.time = torch.tensor(time_vals, dtype=torch.float32)
        self.event = torch.tensor(events, dtype=torch.bool)
    def __len__(self): return len(self.x)
    def __getitem__(self, idx): return self.x[idx], self.time[idx], self.event[idx]

# ---- Param groups: L2 only on non-bias, non-final weights ----
def make_optimizer(model, lr, wd):
    # find last Linear
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
    last_linear = linears[-1] if len(linears) > 0 else None

    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # No WD on biases
        if name.endswith('bias'):
            no_decay.append(p)
            continue
        # No WD on final layer weight (keep ranker free)
        if (last_linear is not None) and (p is last_linear.weight):
            no_decay.append(p)
            continue
        # Everything else decays (hidden & input weights)
        decay.append(p)

    param_groups = [
        {'params': decay, 'weight_decay': wd},
        {'params': no_decay, 'weight_decay': 0.0},
    ]
    return optim.AdamW(param_groups, lr=lr)

# L1 ONLY on the first (input) Linear layer
def l1_penalty_first_layer(model):
    for m in model.modules():
        if isinstance(m, nn.Linear):
            return m.weight.abs().sum()
    return torch.tensor(0.0, device=next(model.parameters()).device)

def train_one_epoch(model, optimizer, dataloader, device, l1_lambda=0.0, epoch=0, warmup_epochs=20):
    model.train()
    warm = min(1.0, (epoch + 1) / float(warmup_epochs))  # linear warmup of L1
    loss_sum, n_seen = 0.0, 0
    skipped, total_batches = 0, 0
    for x, t, e in dataloader:
        total_batches += 1
        if e.sum().item() == 0:  # skip non-informative batch for Cox
            skipped += 1
            continue
        x, t, e = x.to(device), t.to(device), e.to(device)
        optimizer.zero_grad()
        out = torch.clamp(model(x), -20, 20)
        loss = neg_partial_log_likelihood(out, e, t, reduction='mean')
        if l1_lambda > 0:
            loss = loss + (l1_lambda * warm) * l1_penalty_first_layer(model)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        loss_sum += loss.item() * x.size(0)
        n_seen += x.size(0)
    avg_loss = loss_sum / max(n_seen, 1)
    skip_frac = skipped / max(total_batches, 1)
    return {'avg_loss': avg_loss, 'skip_frac': skip_frac}

def evaluate_ci(model, dataloader, device):
    model.eval()
    preds, times, events = [], [], []
    with torch.no_grad():
        for x, t, e in dataloader:
            x = x.to(device)
            y = torch.clamp(model(x), -20, 20)
            preds.append(y.cpu().numpy().ravel())
            times.append(t.numpy()); events.append(e.numpy())
    preds = np.concatenate(preds)
    if np.isnan(preds).any():
        print("Warning: NaN predictions detected, returning -inf for concordance index")
        return -np.inf
    times = np.concatenate(times); events = np.concatenate(events)
    return concordance_index_censored(events.astype(bool), times, preds)[0]

def count_params(in_dim, hidden_layers):
    params, d = 0, in_dim
    for h in hidden_layers:
        params += d*h + h  # weights + bias
        d = h
    params += d*1 + 1     # final layer
    return params

# ---------------- random search with successive halving ----------------
def main():
    # logging
    original_stdout = sys.stdout
    log_path = "/content/drive/MyDrive/deepsurv_8-30-25_training_log.txt"
    with open(log_path, "w") as log_file:
        sys.stdout = Tee(original_stdout, log_file)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        current_date = datetime.datetime.now().strftime("%Y%m%d")
        output_dir = "/content/drive/MyDrive/deepsurv_results"
        os.makedirs(output_dir, exist_ok=True)

        # ----- load & prep data -----
        train_df = pd.read_csv("/content/drive/MyDrive/affyTrain.csv")
        valid_df = pd.read_csv("/content/drive/MyDrive/affyValidation.csv")
        for df in (train_df, valid_df):
            if 'Adjuvant Chemo' in df.columns:
                df['Adjuvant Chemo'] = df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})
        binary_columns = ['Adjuvant Chemo','IS_MALE']
        for col in binary_columns:
            if col in train_df.columns: train_df[col] = train_df[col].astype(int)
            if col in valid_df.columns: valid_df[col] = valid_df[col].astype(int)

        survival_cols = ['OS_STATUS','OS_MONTHS']
        feature_cols = [c for c in train_df.columns if c not in survival_cols]

        X_train = train_df[feature_cols].values.astype(np.float32)
        X_valid = valid_df[feature_cols].values.astype(np.float32)
        scaler = StandardScaler().fit(X_train)
        X_train = scaler.transform(X_train).astype(np.float32)
        X_valid = scaler.transform(X_valid).astype(np.float32)

        y_train_time = train_df['OS_MONTHS'].values.astype(np.float32)
        y_train_event = train_df['OS_STATUS'].values.astype(np.float32)
        y_valid_time = valid_df['OS_MONTHS'].values.astype(np.float32)
        y_valid_event = valid_df['OS_STATUS'].values.astype(np.float32)

        train_ds = SurvivalDataset(X_train, y_train_time, y_train_event)
        valid_ds = SurvivalDataset(X_valid, y_valid_time, y_valid_event)

        BATCH_SIZE = 64  # slightly larger to stabilize gradients under strong reg
        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  drop_last=True)
        train_eval_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=False)
        valid_loader      = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)

        # ----- search budget & rungs -----
        rng = np.random.default_rng(42)
        num_trials = 80
        rungs = [16, 64, 160, 320]  # longer rungs for heavy reg
        eta = 3
        min_delta = 1e-4
        PRINT_EVERY = 1
        L1_WARMUP_EPOCHS = 30

        # ----- initialize trials -----
        trials = []
        for tid in range(num_trials):
            hp = sample_hparams(rng)
            model = DeepSurvMLP(X_train.shape[1], hp['layers'], dropout=hp['dropout']).to(device)
            optimizer = make_optimizer(model, lr=hp['lr'], wd=hp['wd'])
            trials.append({
                'id': tid, 'hp': hp, 'model': model, 'opt': optimizer, 'sched': None, 'sched_target': 0,
                'best_ci': -np.inf, 'best_state': copy.deepcopy(model.state_dict()),
                'hist_train_ci': [], 'hist_val_ci': [], 'hist_loss': [], 'hist_skip': [],
                'epochs_done': 0, 'alive': True, 'no_improve': 0
            })

        # ----- successive halving loop -----
        prev_target = 0
        for rung_idx, rung_ep in enumerate(rungs, start=1):
            print(f"\n=== Rung {rung_idx}/{len(rungs)} → target {rung_ep} epochs ===")

            # rung-relative patience (no early stop in rung 1)
            span = rung_ep - prev_target
            es_patience = float('inf') if rung_idx == 1 else max(10, int(0.25 * span))

            # train all alive trials up to rung_ep
            for tr in trials:
                if not tr['alive']: continue

                # (re)create cosine scheduler for this rung
                steps_remaining = rung_ep - tr['epochs_done']
                if steps_remaining <= 0:
                    continue
                tr['sched'] = torch.optim.lr_scheduler.CosineAnnealingLR(tr['opt'], T_max=steps_remaining)
                tr['sched_target'] = rung_ep

                while tr['epochs_done'] < rung_ep:
                    stats = train_one_epoch(
                        tr['model'], tr['opt'], train_loader, device,
                        l1_lambda=tr['hp']['l1'], epoch=tr['epochs_done'], warmup_epochs=L1_WARMUP_EPOCHS
                    )
                    tr['epochs_done'] += 1

                    # diagnostics each epoch
                    tr_ci = evaluate_ci(tr['model'], train_eval_loader, device)
                    va_ci = evaluate_ci(tr['model'], valid_loader, device)
                    tr['hist_train_ci'].append(tr_ci)
                    tr['hist_val_ci'].append(va_ci)
                    tr['hist_loss'].append(stats['avg_loss'])
                    tr['hist_skip'].append(stats['skip_frac'])

                    if tr['sched'] is not None:
                        tr['sched'].step()

                    if tr['epochs_done'] % PRINT_EVERY == 0:
                        print(f"[Trial {tr['id']:02d} | Rung {rung_idx}/{len(rungs)} | "
                              f"Epoch {tr['epochs_done']:3d}] "
                              f"Loss {stats['avg_loss']:.4f} | Skip% {100*stats['skip_frac']:.1f} | "
                              f"Train CI {tr_ci:.4f} | Val CI {va_ci:.4f} | Best {tr['best_ci']:.4f}")

                    # track best
                    if va_ci > tr['best_ci'] + min_delta:
                        tr['best_ci'] = va_ci
                        tr['best_state'] = copy.deepcopy(tr['model'].state_dict())
                        tr['no_improve'] = 0
                    else:
                        tr['no_improve'] += 1
                        if tr['no_improve'] >= es_patience:
                            print(f"Trial {tr['id']} early-stopped at epoch {tr['epochs_done']} "
                                  f"(best Val CI={tr['best_ci']:.4f})")
                            break

            # prune: keep top 1/eta among alive
            alive = [tr for tr in trials if tr['alive']]
            alive.sort(key=lambda z: z['best_ci'], reverse=True)
            keep_n = max(1, math.ceil(len(alive) / eta))
            survivors = set(tr['id'] for tr in alive[:keep_n])

            print(f"Alive before prune: {len(alive)}; keeping top {keep_n}")
            for tr in alive:
                if tr['id'] not in survivors:
                    tr['alive'] = False
                    # free memory
                    del tr['model']; tr['model'] = None
                    tr['opt'] = None; tr['sched'] = None
                    if torch.cuda.is_available(): torch.cuda.empty_cache()

            prev_target = rung_ep

        # ----- BEST model selection (no 1-SE) -----
        best_trial = max(trials, key=lambda z: z['best_ci'])
        best_hp = best_trial['hp']
        print("\nBest Hyperparameters (by Val CI):")
        print(best_hp)
        print("Best Validation CI:", best_trial['best_ci'])

        # ----- save plots per trial -----
        out_date = current_date
        for tr in trials:
            cfg = (f"trial{tr['id']}_layers-{'-'.join(map(str, tr['hp']['layers']))}"
                   f"_drop{tr['hp']['dropout']:.2f}_lr{tr['hp']['lr']:.2e}"
                   f"_wd{tr['hp']['wd']:.2e}_l1{tr['hp']['l1']:.2e}")
            epochs = range(1, len(tr['hist_train_ci'])+1)
            plt.figure()
            plt.plot(epochs, tr['hist_train_ci'], label='Train CI')
            plt.plot(epochs, tr['hist_val_ci'], label='Val CI')
            plt.xlabel('Epoch'); plt.ylabel('Concordance Index'); plt.legend(); plt.grid(True, alpha=0.3)
            plt.title(cfg)
            plt.ylim(0.4, 1.0)  # constant y-scale across plots
            plot_path = os.path.join(output_dir, f"{out_date}_ci_{cfg}.png")
            plt.savefig(plot_path, dpi=150, bbox_inches='tight'); plt.close()
            print(f"Saved CI plot to {plot_path}")

        # ----- save results table -----
        results = []
        for tr in trials:
            row = {
                'trial_id': tr['id'], 'val_ci': tr['best_ci'],
                'epochs_trained': len(tr['hist_val_ci']), 'alive_final': tr['alive'],
                'avg_loss_last': tr['hist_loss'][-1] if tr['hist_loss'] else np.nan,
                'skip_frac_last': tr['hist_skip'][-1] if tr['hist_skip'] else np.nan
            }
            row.update({
                'layers': '-'.join(map(str, tr['hp']['layers'])),
                'dropout': tr['hp']['dropout'], 'lr': tr['hp']['lr'],
                'weight_decay(L2)': tr['hp']['wd'], 'l1_lambda': tr['hp']['l1'],
                'param_count': count_params(X_train.shape[1], tr['hp']['layers'])
            })
            results.append(row)
        df = pd.DataFrame(results).sort_values('val_ci', ascending=False)
        csv_path = os.path.join(output_dir, f"{out_date}_deepsurv_randomSH_results.csv")
        df.to_csv(csv_path, index=False)
        print(f"Hyperparameter search results saved to {csv_path}")

        # ----- save best model -----
        best_model_path = os.path.join(output_dir, f"{out_date}_best_deepsurv_model.pth")
        torch.save(best_trial['best_state'], best_model_path)
        print(f"Best model saved to {best_model_path}")

        # ----- test evaluation -----
        test_df = pd.read_csv("/content/drive/MyDrive/affyTest.csv")
        if 'Adjuvant Chemo' in test_df.columns:
            test_df['Adjuvant Chemo'] = test_df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})
        for col in binary_columns:
            if col in test_df.columns: test_df[col] = test_df[col].astype(int)
        X_test = scaler.transform(test_df[feature_cols].values.astype(np.float32)).astype(np.float32)
        y_test_time = test_df['OS_MONTHS'].values.astype(np.float32)
        y_test_event = test_df['OS_STATUS'].values.astype(np.float32)

        test_loader = DataLoader(SurvivalDataset(X_test, y_test_time, y_test_event),
                                 batch_size=BATCH_SIZE, shuffle=False)

        # rebuild & load best
        final_model = DeepSurvMLP(X_train.shape[1], best_hp['layers'], dropout=best_hp['dropout']).to(device)
        final_model.load_state_dict(torch.load(best_model_path, map_location=device))
        test_ci = evaluate_ci(final_model, test_loader, device)
        print(f"Test CI: {test_ci:.4f}")

        sys.stdout.flush()

    sys.stdout = original_stdout
    print("Training completed. Check your log file at:", log_path)

if __name__ == "__main__":
    main()

  df['Adjuvant Chemo'] = df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})



=== Rung 1/4 → target 16 epochs ===
[Trial 00 | Rung 1/4 | Epoch   1] Loss 3.7183 | Skip% 0.0 | Train CI 0.6179 | Val CI 0.5997 | Best -inf
[Trial 00 | Rung 1/4 | Epoch   2] Loss 3.7331 | Skip% 0.0 | Train CI 0.6699 | Val CI 0.6002 | Best 0.5997
[Trial 00 | Rung 1/4 | Epoch   3] Loss 3.8147 | Skip% 0.0 | Train CI 0.7094 | Val CI 0.6050 | Best 0.6002
[Trial 00 | Rung 1/4 | Epoch   4] Loss 3.9042 | Skip% 0.0 | Train CI 0.7343 | Val CI 0.6191 | Best 0.6050
[Trial 00 | Rung 1/4 | Epoch   5] Loss 3.9888 | Skip% 0.0 | Train CI 0.7465 | Val CI 0.6145 | Best 0.6191
[Trial 00 | Rung 1/4 | Epoch   6] Loss 4.0517 | Skip% 0.0 | Train CI 0.7726 | Val CI 0.6142 | Best 0.6191
[Trial 00 | Rung 1/4 | Epoch   7] Loss 4.0993 | Skip% 0.0 | Train CI 0.7917 | Val CI 0.6136 | Best 0.6191
[Trial 00 | Rung 1/4 | Epoch   8] Loss 4.1668 | Skip% 0.0 | Train CI 0.7968 | Val CI 0.6163 | Best 0.6191
[Trial 00 | Rung 1/4 | Epoch   9] Loss 4.2214 | Skip% 0.0 | Train CI 0.7901 | Val CI 0.6164 | Best 0.6191
[Trial 00 |

  test_df['Adjuvant Chemo'] = test_df['Adjuvant Chemo'].replace({'OBS':0,'ACT':1})


Test CI: 0.6445
Training completed. Check your log file at: /content/drive/MyDrive/deepsurv_8-28-25_training_log.txt
