<a href="https://colab.research.google.com/github/osun24/nsclc-adj-chemo/blob/main/TorchSurv_DeepSurv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install necessary packages.
# `!pip` is IPython/Colab shell syntax, so this cell only runs inside a
# notebook kernel, not as a plain Python script.
# NOTE(review): consider pinning versions (e.g. torchsurv==0.1.4, the version
# resolved when this notebook was last run) so reruns get the same APIs.
!pip install torchsurv scikit-survival

# Import required packages
# NOTE(review): confirm the installed torchsurv really exposes
# `torchsurv.models.MLP` and `torchsurv.losses.CoxPHLoss`; TorchSurv's
# documented Cox loss is `torchsurv.loss.cox.neg_partial_log_likelihood`.
import os
import time
import datetime
import itertools
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchsurv.models import MLP
from torchsurv.losses import CoxPHLoss
from sksurv.metrics import concordance_index_censored

# Mount Google Drive. This is required, not optional, for main(): it reads the
# Affy CSVs from, and writes its results under, /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')


Collecting torchsurv
  Downloading torchsurv-0.1.4-py3-none-any.whl.metadata (14 kB)
Collecting scikit-survival
  Downloading scikit_survival-0.24.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (48 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.9/48.9 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics (from torchsurv)
  Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)
Collecting ecos (from scikit-survival)
  Downloading ecos-2.0.14-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.0 kB)
Collecting osqp<1.0.0,>=0.6.3 (from scikit-survival)
  Downloading osqp-0.6.7.post3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.9 kB)
Collecting qdldl (from osqp<1.0.0,>=0.6.3->scikit-survival)
  Downloading qdldl-0.1.7.post5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cuda-nvrtc-cu12==1

In [None]:
import os
import time
import datetime
import itertools
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchsurv.models import MLP
from torchsurv.losses import CoxPHLoss
from sksurv.metrics import concordance_index_censored

# Define a PyTorch Dataset for survival data
class SurvivalDataset(Dataset):
    """Map-style dataset that yields ``(features, time, event)`` triples.

    All three inputs are converted to ``float32`` tensors once, at
    construction time, so indexing is just a lookup into prebuilt tensors.
    """

    def __init__(self, features, time_vals, events):
        def _as_float32(values):
            return torch.tensor(values, dtype=torch.float32)

        self.x, self.time, self.event = (
            _as_float32(features),
            _as_float32(time_vals),
            _as_float32(events),
        )

    def __len__(self):
        # One sample per row of the feature tensor.
        return len(self.x)

    def __getitem__(self, idx):
        return tuple(column[idx] for column in (self.x, self.time, self.event))

# Training loop for one epoch
def train_model(model, criterion, optimizer, dataloader, device):
    """Train ``model`` for a single pass over ``dataloader``.

    Every batch is moved to ``device`` and scored by the model. The scores go
    to ``criterion(scores, time, event)``, and one optimizer step is taken per
    batch.

    NOTE(review): confirm this ``(scores, time, event)`` argument order
    matches the loss actually in use. TorchSurv's documented Cox loss,
    ``neg_partial_log_likelihood``, takes ``(log_hz, event, time)``.

    Returns:
        The epoch's mean loss: each batch loss is weighted by its batch size,
        and the total is divided by ``len(dataloader.dataset)``.
    """
    model.train()
    weighted_loss_total = 0.0
    for batch in dataloader:
        features, durations, event_flags = (t.to(device) for t in batch)
        optimizer.zero_grad()
        scores = model(features)  # log hazard ratios for a Cox model
        batch_loss = criterion(scores, durations, event_flags)
        batch_loss.backward()
        optimizer.step()
        weighted_loss_total += batch_loss.item() * features.size(0)
    dataset_size = len(dataloader.dataset)
    return weighted_loss_total / dataset_size

# Evaluation: compute concordance index using sksurv's metric
def evaluate_model(model, dataloader, device):
    """Compute the concordance index (C-index) of ``model`` on ``dataloader``.

    The model's outputs are used directly as risk scores, where higher means
    higher predicted risk. Scores, times and events are gathered over every
    batch, and ``concordance_index_censored`` is evaluated once on the full
    set.

    Args:
        model: Network mapping a feature batch to one score per sample.
        dataloader: Yields ``(x, time, event)`` batches. ``time`` and
            ``event`` are converted with ``.numpy()``, so they are expected
            to be CPU tensors.
        device: Device the model lives on; only ``x`` is moved there.

    Returns:
        The C-index as a float, or ``nan`` if any score is non-finite (e.g. a
        diverged training run). sksurv rejects NaN/inf estimates with a
        ``ValueError``, which would otherwise abort the whole hyperparameter
        search. ``nan`` never compares greater than a previous best, so
        ``val_ci > best`` bookkeeping simply skips such epochs.
    """
    model.eval()
    pred_chunks, time_chunks, event_chunks = [], [], []
    with torch.no_grad():
        for x, time_vals, events in dataloader:
            preds = model(x.to(device))
            pred_chunks.append(preds.cpu().numpy().flatten())
            time_chunks.append(time_vals.numpy())
            event_chunks.append(events.numpy())
    risk_scores = np.concatenate(pred_chunks)
    event_times = np.concatenate(time_chunks)
    event_observed = np.concatenate(event_chunks).astype(bool)

    if not np.all(np.isfinite(risk_scores)):
        return float("nan")

    # Higher risk scores (log hazard ratios) should correspond to worse outcomes.
    ci = concordance_index_censored(event_observed, event_times, risk_scores)[0]
    return ci

def _snapshot_state(model):
    """Return a detached CPU copy of ``model.state_dict()``.

    ``state_dict()`` hands back references to the live parameter tensors.
    Without cloning, later optimizer steps would overwrite the "best" weights
    we meant to keep.
    """
    return {name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()}


def _fit_one_config(in_features, layers, dropout, lr, wd,
                    train_loader, valid_loader, device, num_epochs, seed):
    """Train one hyperparameter configuration with per-epoch validation.

    Returns:
        ``(best_ci, best_epoch, best_state)``:
            - ``best_state`` holds the weights from the (1-based)
              ``best_epoch`` with the highest validation C-index.
            - If no epoch beat ``-inf`` (e.g. every epoch returned NaN),
              ``best_epoch`` and ``best_state`` are ``None`` and ``best_ci``
              is ``-inf``.
    """
    # Re-seed per configuration so weight init, dropout masks and DataLoader
    # shuffling (which draws from torch's global RNG) are reproducible.
    torch.manual_seed(seed)
    # Initialize model as in TorchSurv's DeepSurv example.
    # NOTE(review): confirm MLP/CoxPHLoss match the installed torchsurv API.
    model = MLP(in_features=in_features, layers=layers, dropout=dropout,
                activation=torch.relu).to(device)
    criterion = CoxPHLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    best_ci, best_epoch, best_state = -np.inf, None, None
    for epoch in range(1, num_epochs + 1):
        train_loss = train_model(model, criterion, optimizer, train_loader, device)
        val_ci = evaluate_model(model, valid_loader, device)
        if epoch % 10 == 0:
            print(f"Epoch {epoch}/{num_epochs} - Train Loss: {train_loss:.4f}, Val CI: {val_ci:.4f}")
        if val_ci > best_ci:
            # Snapshot now: the weights that earned this CI are the ones to keep,
            # not whatever the model looks like after the last epoch.
            best_ci, best_epoch = val_ci, epoch
            best_state = _snapshot_state(model)
    return best_ci, best_epoch, best_state


def main(*, num_epochs=100, batch_size=32, seed=42):
    """Grid-search DeepSurv (MLP + Cox loss) hyperparameters on the Affy data.

    Steps:
        1. Load ``affyTrain.csv`` and ``affyValidation.csv`` from Google Drive.
        2. For every hyperparameter combination, train an MLP and keep the
           weights from the epoch with the best validation C-index.
        3. Write the per-configuration results CSV and the overall best
           model's ``state_dict`` to ``/content/drive/MyDrive/deepsurv_results``.

    Args:
        num_epochs: Training epochs per configuration.
        batch_size: Mini-batch size for the training and validation loaders.
        seed: Passed to ``torch.manual_seed`` before each configuration.

    NOTE: the validation set chooses both the epoch and the configuration, so
    the reported best CI is optimistically biased. Confirm the saved model on
    a separate held-out test set.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the Affy datasets from Google Drive
    train_df = pd.read_csv("/content/drive/MyDrive/affyTrain.csv")
    valid_df = pd.read_csv("/content/drive/MyDrive/affyValidation.csv")

    # Cast binary indicator columns to int, when present.
    binary_columns = ['Adjuvant Chemo', 'IS_MALE']
    for df in (train_df, valid_df):
        for col in binary_columns:
            if col in df.columns:
                df[col] = df[col].astype(int)

    # Define survival columns and feature columns.
    # NOTE(review): every non-survival column becomes a feature; confirm the
    # CSVs hold no ID/label columns that should be excluded.
    survival_cols = ['OS_STATUS', 'OS_MONTHS']
    feature_cols = [col for col in train_df.columns if col not in survival_cols]

    def to_arrays(df):
        """Split ``df`` into float32 (features, OS_MONTHS, OS_STATUS) arrays."""
        return (df[feature_cols].values.astype(np.float32),
                df['OS_MONTHS'].values.astype(np.float32),
                df['OS_STATUS'].values.astype(np.float32))

    X_train, y_train_time, y_train_event = to_arrays(train_df)
    X_valid, y_valid_time, y_valid_event = to_arrays(valid_df)

    # Create PyTorch datasets and loaders
    train_dataset = SurvivalDataset(X_train, y_train_time, y_train_event)
    valid_dataset = SurvivalDataset(X_valid, y_valid_time, y_valid_event)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

    # Hyperparameter grid for DeepSurv (MLP) model
    hidden_layer_configs = [[32], [64], [32, 16], [64, 32]]
    dropout_rates = [0.0, 0.2, 0.5]
    learning_rates = [0.001, 0.01]
    weight_decays = [0.0, 1e-4, 1e-3]

    best_ci = -np.inf
    best_hyperparams = None
    best_model_state = None
    results = []

    # product() iterates in the same order as the equivalent nested loops.
    grid = itertools.product(hidden_layer_configs, dropout_rates,
                             learning_rates, weight_decays)
    for layers, dropout, lr, wd in grid:
        print(f"Training with layers={layers}, dropout={dropout}, lr={lr}, weight_decay={wd}")
        config_ci, config_epoch, config_state = _fit_one_config(
            X_train.shape[1], layers, dropout, lr, wd,
            train_loader, valid_loader, device, num_epochs, seed)
        results.append({
            'layers': layers,
            'dropout': dropout,
            'learning_rate': lr,
            'weight_decay': wd,
            'val_ci': config_ci,
            'best_epoch': config_epoch,
        })
        print(f"Finished config: Val CI = {config_ci:.4f} (best epoch {config_epoch})\n")
        if config_ci > best_ci:
            best_ci = config_ci
            best_hyperparams = {
                'layers': layers,
                'dropout': dropout,
                'learning_rate': lr,
                'weight_decay': wd,
                'best_epoch': config_epoch,
            }
            best_model_state = config_state

    print("Best Hyperparameters:")
    print(best_hyperparams)
    print("Best Validation CI:", best_ci)

    # Save results and best model to Google Drive
    current_date = datetime.datetime.now().strftime("%Y%m%d")
    output_dir = "/content/drive/MyDrive/deepsurv_results"
    os.makedirs(output_dir, exist_ok=True)

    results_csv_path = os.path.join(output_dir, f"{current_date}_deepsurv_hyperparam_search_results.csv")
    pd.DataFrame(results).to_csv(results_csv_path, index=False)
    print(f"Hyperparameter search results saved to {results_csv_path}")

    if best_model_state is None:
        print("No configuration produced a finite validation CI; no model saved.")
        return

    best_model_path = os.path.join(output_dir, f"{current_date}_best_deepsurv_model.pth")
    torch.save(best_model_state, best_model_path)
    print(f"Best model saved to {best_model_path}")

# Entry point. Inside a notebook kernel __name__ is "__main__", so executing
# this cell starts the full hyperparameter search.
if __name__ == "__main__":
    main()