<a href="https://colab.research.google.com/github/osun24/nsclc-adj-chemo/blob/main/TorchSurv_DeepSurv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install necessary packages
!pip install torchsurv scikit-survival

# Import required packages
import os
import time
import datetime
import itertools
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sksurv.metrics import concordance_index_censored

# (Optional) Mount Google Drive if you plan to load/save files there
from google.colab import drive
drive.mount('/content/drive')


Collecting torchsurv
  Downloading torchsurv-0.1.5-py3-none-any.whl.metadata (15 kB)
Collecting scikit-survival
  Downloading scikit_survival-0.25.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.2 kB)
Collecting torchmetrics (from torchsurv)
  Downloading torchmetrics-1.8.1-py3-none-any.whl.metadata (22 kB)
Collecting ecos (from scikit-survival)
  Downloading ecos-2.0.14-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.0 kB)
Collecting osqp<1.0.0,>=0.6.3 (from scikit-survival)
  Downloading osqp-0.6.7.post3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.9 kB)
Collecting qdldl (from osqp<1.0.0,>=0.6.3->scikit-survival)
  Downloading qdldl-0.1.7.post5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.7 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics->torchsurv)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloadin

In [None]:
import os
import sys
import time
import datetime
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchsurv.loss.cox import neg_partial_log_likelihood
from sksurv.metrics import concordance_index_censored
import warnings
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore", message="Ties in event time detected; using efron's method to handle ties.")

# Define a Tee class for logging output to both console and file
class Tee:
    def __init__(self, *files):
        self.files = files
    def write(self, data):
        for f in self.files:
            f.write(data)
    def flush(self):
        for f in self.files:
            f.flush()

# Define a custom MLP model for DeepSurv (remove BatchNorm)
class DeepSurvMLP(nn.Module):
    def __init__(self, in_features, hidden_layers, dropout=0.0, activation=nn.ReLU()):
        super().__init__()
        layers = []
        current_dim = in_features
        for units in hidden_layers:
            layers += [nn.Linear(current_dim, units), activation]
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            current_dim = units
        layers.append(nn.Linear(current_dim, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Define a PyTorch Dataset for survival data
class SurvivalDataset(Dataset):
    def __init__(self, features, time_vals, events):
        self.x = torch.tensor(features, dtype=torch.float32)
        self.time = torch.tensor(time_vals, dtype=torch.float32)
        self.event = torch.tensor(events, dtype=torch.bool)
    def __len__(self):
        return len(self.x)
    def __getitem__(self, idx):
        return self.x[idx], self.time[idx], self.event[idx]

def train_model(model, optimizer, dataloader, device):
    model.train()
    running_loss, n_seen = 0.0, 0
    for x, time_vals, events in dataloader:
        if events.sum().item() == 0:
            continue  # skip non-informative batches for Cox
        x = x.to(device); time_vals = time_vals.to(device); events = events.to(device)
        optimizer.zero_grad()
        outputs = torch.clamp(model(x), -20, 20)
        loss = neg_partial_log_likelihood(outputs, events, time_vals, reduction='mean')
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        running_loss += loss.item() * x.size(0)
        n_seen += x.size(0)
    return running_loss / max(n_seen, 1)

def evaluate_model(model, dataloader, device):
    model.eval()
    all_preds = []
    all_time = []
    all_event = []
    with torch.no_grad():
        for x, time_vals, events in dataloader:
            x = x.to(device)
            preds = torch.clamp(model(x), -20, 20) # clamp to prevent blowups
            all_preds.append(preds.cpu().numpy().flatten())
            all_time.append(time_vals.cpu().numpy())
            all_event.append(events.cpu().numpy())
    all_preds = np.concatenate(all_preds)
    if np.isnan(all_preds).any():
        print("Warning: NaN predictions detected, returning -inf for concordance index")
        return -np.inf
    all_time = np.concatenate(all_time)
    all_event = np.concatenate(all_event)
    ci = concordance_index_censored(all_event.astype(bool), all_time, all_preds)[0]
    return ci

def main():
    # Capture the original stdout
    original_stdout = sys.stdout
    log_path = "/content/drive/MyDrive/deepsurv_8-27-25_training_log.txt"

    # Open log file with context, and use Tee to write to both original_stdout and file
    with open(log_path, "w") as log_file:
        sys.stdout = Tee(original_stdout, log_file)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        current_date = datetime.datetime.now().strftime("%Y%m%d")
        output_dir = "/content/drive/MyDrive/deepsurv_results"
        os.makedirs(output_dir, exist_ok=True)

        train_df = pd.read_csv("/content/drive/MyDrive/affyTrain.csv")
        valid_df = pd.read_csv("/content/drive/MyDrive/affyValidation.csv")

        train_df['Adjuvant Chemo'] = train_df['Adjuvant Chemo'].replace({'OBS': 0, 'ACT': 1})
        valid_df['Adjuvant Chemo'] = valid_df['Adjuvant Chemo'].replace({'OBS': 0, 'ACT': 1})

        binary_columns = ['Adjuvant Chemo', 'IS_MALE']
        for col in binary_columns:
            if col in train_df.columns:
                train_df[col] = train_df[col].astype(int)
            if col in valid_df.columns:
                valid_df[col] = valid_df[col].astype(int)

        survival_cols = ['OS_STATUS', 'OS_MONTHS']
        feature_cols = [col for col in train_df.columns if col not in survival_cols]

        X_train = train_df[feature_cols].values.astype(np.float32)
        X_valid = valid_df[feature_cols].values.astype(np.float32)

        scaler = StandardScaler().fit(X_train)
        X_train = scaler.transform(X_train).astype(np.float32)
        X_valid = scaler.transform(X_valid).astype(np.float32)

        y_train_time = train_df['OS_MONTHS'].values.astype(np.float32)
        y_train_event = train_df['OS_STATUS'].values.astype(np.float32)

        y_valid_time = valid_df['OS_MONTHS'].values.astype(np.float32)
        y_valid_event = valid_df['OS_STATUS'].values.astype(np.float32)

        train_dataset = SurvivalDataset(X_train, y_train_time, y_train_event)
        valid_dataset = SurvivalDataset(X_valid, y_valid_time, y_valid_event)

        batch_size = 32
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
        valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
        train_eval_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)

        hidden_layer_configs = [
            [32],
            [64],
            [32, 16],
            [64, 32]
        ]
        dropout_rates = [0.0, 0.2, 0.5]
        learning_rates = [1e-4, 3e-4, 1e-3] #[0.001, 0.01]
        weight_decays = [0.0, 1e-4, 1e-3]

        num_epochs = 100
        best_ci = -np.inf
        best_hyperparams = None
        best_model_state = None
        results = []

        for layers in hidden_layer_configs:
            for dropout in dropout_rates:
                for lr in learning_rates:
                    for wd in weight_decays:
                        print(f"Training with layers={layers}, dropout={dropout}, lr={lr}, weight_decay={wd}")
                        model = DeepSurvMLP(in_features=X_train.shape[1],
                                            hidden_layers=layers,
                                            dropout=dropout,
                                            activation=nn.ReLU()).to(device)
                        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

                        best_epoch_ci = -np.inf
                        hist_train_ci, hist_val_ci, hist_train_loss = [], [], []

                        for epoch in range(num_epochs):
                            train_loss = train_model(model, optimizer, train_loader, device)
                            # CI on full train (no drop_last, no shuffle) and valid
                            train_ci = evaluate_model(model, train_eval_loader, device)
                            val_ci = evaluate_model(model, valid_loader, device)

                            hist_train_loss.append(train_loss)
                            hist_train_ci.append(train_ci)
                            hist_val_ci.append(val_ci)

                            if (epoch + 1) % 10 == 0:
                                print(f"Epoch {epoch+1}/{num_epochs} - Loss: {train_loss:.4f}, Train CI: {train_ci:.4f}, Val CI: {val_ci:.4f}")

                            if val_ci > best_epoch_ci:
                                best_epoch_ci = val_ci

                            def preds_on(dl):
                                return np.concatenate([model(x.to(device)).detach().cpu().numpy().ravel()
                                                      for x,_,_ in dl])
                            if epoch == 0: val_prev = preds_on(valid_loader)
                            else:
                                val_now = preds_on(valid_loader)
                                rho = np.corrcoef(np.argsort(val_prev), np.argsort(val_now))[0,1]  # rank corr proxy
                                print(f"Val rank corr vs epoch1: {rho:.4f}")

                            vp = preds_on(valid_loader)
                            print("Frac |pred|>=20:", (np.abs(vp)>=20).mean())

                        # After epochs: save CI plot for this config
                        cfg = f"layers-{'-'.join(map(str, layers))}_drop{dropout}_lr{lr}_wd{wd}"
                        plot_path = os.path.join(output_dir, f"{current_date}_ci_{cfg}.png")
                        plt.figure()
                        plt.plot(range(1, num_epochs+1), hist_train_ci, label='Train CI')
                        plt.plot(range(1, num_epochs+1), hist_val_ci, label='Val CI')
                        plt.xlabel('Epoch'); plt.ylabel('Concordance Index'); plt.legend(); plt.grid(True, alpha=0.3)
                        plt.title(cfg)
                        plt.savefig(plot_path, dpi=150, bbox_inches='tight'); plt.close()
                        print(f"Saved CI plot to {plot_path}")
                        results.append({
                            'layers': layers,
                            'dropout': dropout,
                            'learning_rate': lr,
                            'weight_decay': wd,
                            'val_ci': best_epoch_ci
                        })
                        print(f"Finished config: Val CI = {best_epoch_ci:.4f}\n")
                        if best_epoch_ci > best_ci:
                            best_ci = best_epoch_ci
                            best_hyperparams = {
                                'layers': layers,
                                'dropout': dropout,
                                'learning_rate': lr,
                                'weight_decay': wd
                            }
                            best_model_state = model.state_dict()

        print("Best Hyperparameters:")
        print(best_hyperparams)
        print("Best Validation CI:", best_ci)

        current_date = datetime.datetime.now().strftime("%Y%m%d")
        output_dir = "/content/drive/MyDrive/deepsurv_results"
        os.makedirs(output_dir, exist_ok=True)

        results_df = pd.DataFrame(results)
        results_csv_path = os.path.join(output_dir, f"{current_date}_deepsurv_hyperparam_search_results.csv")
        results_df.to_csv(results_csv_path, index=False)
        print(f"Hyperparameter search results saved to {results_csv_path}")

        best_model_path = os.path.join(output_dir, f"{current_date}_best_deepsurv_model.pth")
        torch.save(best_model_state, best_model_path)
        print(f"Best model saved to {best_model_path}")

        # Perform eval on test set
        test_df = pd.read_csv("/content/drive/MyDrive/affyTest.csv")

        test_df['Adjuvant Chemo'] = test_df['Adjuvant Chemo'].replace({'OBS': 0, 'ACT': 1})
        for col in binary_columns:
            if col in test_df.columns:
                test_df[col] = test_df[col].astype(int)

        X_test = test_df[feature_cols].values.astype(np.float32)
        X_test = scaler.transform(X_test).astype(np.float32)
        y_test_time = test_df['OS_MONTHS'].values.astype(np.float32)
        y_test_event = test_df['OS_STATUS'].values.astype(np.float32)

        test_dataset = SurvivalDataset(X_test, y_test_time, y_test_event)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        # Rebuild best model and load saved weights
        best_model = DeepSurvMLP(
            in_features=X_train.shape[1],
            hidden_layers=best_hyperparams['layers'],
            dropout=best_hyperparams['dropout'],
            activation=nn.ReLU()
        ).to(device)
        best_model.load_state_dict(torch.load(best_model_path, map_location=device))

        test_ci = evaluate_model(best_model, test_loader, device)
        print(f"Test CI: {test_ci:.4f}")

        # Flush before exiting the with block
        sys.stdout.flush()

    # Restore original stdout for subsequent prints to the console
    sys.stdout = original_stdout
    print("Training completed. Check your log file at:", log_path)

if __name__ == "__main__":
    main()

  train_df['Adjuvant Chemo'] = train_df['Adjuvant Chemo'].replace({'OBS': 0, 'ACT': 1})
  valid_df['Adjuvant Chemo'] = valid_df['Adjuvant Chemo'].replace({'OBS': 0, 'ACT': 1})


Training with layers=[32], dropout=0.0, lr=0.0001, weight_decay=0.0
Epoch 10/100 - Loss: 1.3607, Train CI: 0.9598, Val CI: 0.6332
Epoch 20/100 - Loss: 0.8653, Train CI: 0.9816, Val CI: 0.6220
Epoch 30/100 - Loss: 0.7363, Train CI: 0.9859, Val CI: 0.6235
Epoch 40/100 - Loss: 0.6460, Train CI: 0.9859, Val CI: 0.6213
Epoch 50/100 - Loss: 0.5740, Train CI: 0.9857, Val CI: 0.6270
Epoch 60/100 - Loss: 0.5281, Train CI: 0.9882, Val CI: 0.6232
Epoch 70/100 - Loss: 0.4913, Train CI: 0.9901, Val CI: 0.6188
Epoch 80/100 - Loss: 0.4897, Train CI: 0.9889, Val CI: 0.6213
Epoch 90/100 - Loss: 0.4640, Train CI: 0.9903, Val CI: 0.6166
Epoch 100/100 - Loss: 0.4409, Train CI: 0.9896, Val CI: 0.6211
Saved CI plot to /content/drive/MyDrive/deepsurv_results/20250829_ci_layers-32_drop0.0_lr0.0001_wd0.0.png
Finished config: Val CI = 0.6483

Training with layers=[32], dropout=0.0, lr=0.0001, weight_decay=0.0001
Epoch 10/100 - Loss: 1.3618, Train CI: 0.9676, Val CI: 0.6070
Epoch 20/100 - Loss: 0.8715, Train CI: