<a href="https://colab.research.google.com/github/osun24/nsclc-adj-chemo/blob/main/TorchSurv_DeepSurv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install necessary packages
!pip install torchsurv scikit-survival

# Import required packages
import os
import time
import datetime
import itertools
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sksurv.metrics import concordance_index_censored

# (Optional) Mount Google Drive if you plan to load/save files there
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import sys
import time
import datetime
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchsurv.loss.cox import neg_partial_log_likelihood
from sksurv.metrics import concordance_index_censored
import warnings

warnings.filterwarnings("ignore", message="Ties in event time detected; using efron's method to handle ties.")

# Define a Tee class for logging output to both console and file
class Tee:
    """Minimal file-like object that duplicates every write to several streams.

    Intended to be installed as ``sys.stdout`` so output reaches both the console and a
    log file. The first stream is treated as the primary one: any attribute not defined
    here (``isatty``, ``encoding``, ``fileno``, ...) is delegated to it, so code that
    inspects ``sys.stdout`` keeps working while the Tee is active.
    """

    def __init__(self, *files):
        self.files = files

    def write(self, data):
        """Write ``data`` to every stream; return its length like ``TextIOBase.write``."""
        for f in self.files:
            f.write(data)
        return len(data)

    def flush(self):
        """Flush every wrapped stream."""
        for f in self.files:
            f.flush()

    def __getattr__(self, name):
        # Only reached for attributes not found normally. Read ``files`` via __dict__ to
        # avoid infinite recursion if it has not been set yet (e.g. during copying).
        files = self.__dict__.get("files")
        if not files:
            raise AttributeError(name)
        return getattr(files[0], name)

# Define a custom MLP model for DeepSurv
class DeepSurvMLP(nn.Module):
    """Feed-forward network producing one risk score per sample (DeepSurv-style).

    Layout of ``self.model`` (an ``nn.Sequential``): ``BatchNorm1d`` on the inputs, then,
    for each width in ``hidden_layers``, ``Linear -> activation [-> Dropout]``, and a
    final ``Linear`` with a single output unit.

    Args:
        in_features: Number of input features.
        hidden_layers: Hidden-layer widths, applied in order.
        dropout: Dropout probability after each activation; 0 disables dropout.
        activation: Activation after each hidden ``Linear``. Either an ``nn.Module``
            instance (copied once per layer, so layers never share parameters), or a
            zero-argument callable/class returning a module. ``None`` (default) means
            ``nn.ReLU``.
    """

    def __init__(self, in_features, hidden_layers, dropout=0.0, activation=None):
        super(DeepSurvMLP, self).__init__()
        layers = [nn.BatchNorm1d(in_features)]
        current_dim = in_features
        for units in hidden_layers:
            layers.append(nn.Linear(current_dim, units))
            layers.append(self._make_activation(activation))
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            current_dim = units
        layers.append(nn.Linear(current_dim, 1))
        # Same Sequential indices as before, so saved state_dict keys stay compatible.
        self.model = nn.Sequential(*layers)

    @staticmethod
    def _make_activation(activation):
        """Return a fresh activation module for one hidden layer."""
        import copy

        if activation is None:
            return nn.ReLU()
        if isinstance(activation, nn.Module):
            # Copy so parameterised activations (e.g. nn.PReLU) are not tied across
            # layers, nor across models built from the same instance.
            return copy.deepcopy(activation)
        return activation()

    def forward(self, x):
        """Return ``self.model(x)``; the last dimension of the output has size 1."""
        return self.model(x)

# PyTorch map-style Dataset holding (features, time, event) survival tensors
class SurvivalDataset(Dataset):
    """Dataset yielding ``(x, time, event)`` triples for survival models.

    Features and times are converted to float32 tensors; event indicators are
    converted to a bool tensor.
    """

    def __init__(self, features, time_vals, events):
        x_tensor = torch.tensor(features, dtype=torch.float32)
        time_tensor = torch.tensor(time_vals, dtype=torch.float32)
        event_tensor = torch.tensor(events, dtype=torch.bool)
        self.x, self.time, self.event = x_tensor, time_tensor, event_tensor

    def __len__(self):
        """Number of samples (length of the feature tensor)."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the features, time and event indicator at position ``idx``."""
        sample = (self.x[idx], self.time[idx], self.event[idx])
        return sample

def train_model(model, optimizer, dataloader, device):
    """Run one training epoch using torchsurv's Cox negative partial log-likelihood.

    Args:
        model: Module mapping a feature batch to risk scores.
        optimizer: Optimizer over ``model``'s parameters.
        dataloader: Yields ``(x, time, event)`` batches.
        device: Device each batch is moved to.

    Returns:
        Batch-size-weighted average of the per-batch losses over the samples actually
        seen this epoch, or ``nan`` if the loader yielded no samples.
    """
    model.train()
    running_loss = 0.0
    n_seen = 0
    for x, time_vals, events in dataloader:
        x = x.to(device)
        time_vals = time_vals.to(device)
        events = events.to(device)
        optimizer.zero_grad()
        outputs = model(x)
        loss = neg_partial_log_likelihood(outputs, events, time_vals, reduction='mean')
        loss.backward()
        optimizer.step()
        batch_n = x.size(0)
        running_loss += loss.item() * batch_n
        n_seen += batch_n
    # Normalise by samples actually seen (a loader with drop_last may skip some) and
    # avoid ZeroDivisionError on an empty loader.
    if n_seen == 0:
        return float("nan")
    return running_loss / n_seen

def evaluate_model(model, dataloader, device):
    """Compute the censored concordance index of ``model``'s scores on ``dataloader``.

    Predictions, times and event flags are gathered over all batches (no gradients)
    and passed to sksurv's ``concordance_index_censored``.

    Returns:
        The concordance index. Returns ``-np.inf`` (after printing a warning) when any
        prediction is NaN or ±inf. sksurv validates its inputs and raises on
        non-finite estimates, and a diverged run should rank below every finite score.
    """
    model.eval()
    pred_chunks, time_chunks, event_chunks = [], [], []
    with torch.no_grad():
        for x, time_vals, events in dataloader:
            preds = model(x.to(device))
            pred_chunks.append(preds.cpu().numpy().flatten())
            time_chunks.append(time_vals.cpu().numpy())
            event_chunks.append(events.cpu().numpy())
    all_preds = np.concatenate(pred_chunks)
    # isfinite catches both NaN and ±inf (e.g. from an exploding high-lr run).
    if not np.isfinite(all_preds).all():
        print("Warning: non-finite (NaN/inf) predictions detected, returning -inf for concordance index")
        return -np.inf
    all_time = np.concatenate(time_chunks)
    all_event = np.concatenate(event_chunks).astype(bool)
    ci = concordance_index_censored(all_event, all_time, all_preds)[0]
    return ci

def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU and, if available, CUDA) RNGs."""
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _load_split(csv_path, binary_columns):
    """Read one CSV split and cast each listed binary column that is present to int."""
    df = pd.read_csv(csv_path)
    for col in binary_columns:
        if col in df.columns:
            df[col] = df[col].astype(int)
    return df


def _split_arrays(df, feature_cols):
    """Return (features, OS_MONTHS, OS_STATUS) from ``df`` as float32 NumPy arrays."""
    features = df[feature_cols].values.astype(np.float32)
    times = df['OS_MONTHS'].values.astype(np.float32)
    events = df['OS_STATUS'].values.astype(np.float32)
    return features, times, events


def _cpu_state_copy(model):
    """Detached CPU copy of ``model.state_dict()`` that later training cannot modify."""
    return {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}


def _fit_config(in_features, layers, dropout, lr, wd, train_loader, valid_loader, device, num_epochs):
    """Train one hyperparameter configuration, tracking its best validation epoch.

    Returns:
        ``(best_ci, best_epoch, best_state)``. These are the highest validation CI seen,
        the 1-based epoch it occurred at, and a CPU copy of the weights at that epoch.
        They are ``(-inf, 0, None)`` if no epoch improved on ``-inf``.
    """
    model = DeepSurvMLP(in_features=in_features,
                        hidden_layers=layers,
                        dropout=dropout,
                        activation=nn.ReLU()).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    best_ci, best_epoch, best_state = -np.inf, 0, None
    for epoch in range(num_epochs):
        train_loss = train_model(model, optimizer, train_loader, device)
        val_ci = evaluate_model(model, valid_loader, device)
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}, Val CI: {val_ci:.4f}")
        if val_ci > best_ci:
            # Snapshot now: later epochs overwrite the weights that earned this CI.
            best_ci, best_epoch, best_state = val_ci, epoch + 1, _cpu_state_copy(model)
    return best_ci, best_epoch, best_state


def main(train_csv="/content/drive/MyDrive/affyTrain.csv",
         valid_csv="/content/drive/MyDrive/affyValidation.csv",
         log_path="/content/drive/MyDrive/deepsurv_training_log.txt",
         output_dir="/content/drive/MyDrive/deepsurv_results",
         num_epochs=100,
         batch_size=32,
         seed=42):
    """Grid-search DeepSurv hyperparameters; save the results CSV and the best weights.

    Every configuration is trained for ``num_epochs`` epochs and scored by its best
    validation concordance index. The saved model holds the weights from exactly that
    best epoch of the best configuration. Stdout is mirrored into ``log_path`` during the
    run and is always restored afterwards, even if training raises.

    Note: both the epoch and the hyperparameters are selected on the validation split,
    so the reported best CI is optimistically biased. Confirm on a held-out test set.
    """
    import itertools

    original_stdout = sys.stdout
    with open(log_path, "w") as log_file:
        tee = Tee(original_stdout, log_file)
        sys.stdout = tee
        try:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            binary_columns = ['Adjuvant Chemo', 'IS_MALE']
            train_df = _load_split(train_csv, binary_columns)
            valid_df = _load_split(valid_csv, binary_columns)

            survival_cols = ['OS_STATUS', 'OS_MONTHS']
            feature_cols = [col for col in train_df.columns if col not in survival_cols]

            X_train, y_train_time, y_train_event = _split_arrays(train_df, feature_cols)
            X_valid, y_valid_time, y_valid_event = _split_arrays(valid_df, feature_cols)

            train_dataset = SurvivalDataset(X_train, y_train_time, y_train_event)
            valid_dataset = SurvivalDataset(X_valid, y_valid_time, y_valid_event)

            # BatchNorm1d raises on a single-sample batch in train mode; drop a trailing
            # batch of size 1 rather than crash partway through the search.
            drop_last = len(train_dataset) % batch_size == 1
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
            valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

            hidden_layer_configs = [[32], [64], [32, 16], [64, 32]]
            dropout_rates = [0.0, 0.2, 0.5]
            learning_rates = [0.001, 0.01]
            weight_decays = [0.0, 1e-4, 1e-3]

            best_ci = -np.inf
            best_hyperparams = None
            best_model_state = None
            results = []

            grid = itertools.product(hidden_layer_configs, dropout_rates, learning_rates, weight_decays)
            for layers, dropout, lr, wd in grid:
                print(f"Training with layers={layers}, dropout={dropout}, lr={lr}, weight_decay={wd}")
                # Reseed per configuration: each run gets the same init/shuffle RNG state,
                # so results are reproducible and configurations are compared fairly.
                _seed_everything(seed)
                config_ci, config_epoch, config_state = _fit_config(
                    X_train.shape[1], layers, dropout, lr, wd,
                    train_loader, valid_loader, device, num_epochs)
                hyperparams = {
                    'layers': layers,
                    'dropout': dropout,
                    'learning_rate': lr,
                    'weight_decay': wd,
                }
                results.append({**hyperparams, 'val_ci': config_ci, 'best_epoch': config_epoch})
                print(f"Finished config: Val CI = {config_ci:.4f}\n")
                if config_ci > best_ci:
                    best_ci = config_ci
                    best_hyperparams = {**hyperparams, 'best_epoch': config_epoch}
                    best_model_state = config_state

            print("Best Hyperparameters:")
            print(best_hyperparams)
            print("Best Validation CI:", best_ci)

            current_date = datetime.datetime.now().strftime("%Y%m%d")
            os.makedirs(output_dir, exist_ok=True)

            results_df = pd.DataFrame(results)
            results_csv_path = os.path.join(output_dir, f"{current_date}_deepsurv_hyperparam_search_results.csv")
            results_df.to_csv(results_csv_path, index=False)
            print(f"Hyperparameter search results saved to {results_csv_path}")

            if best_model_state is None:
                print("No configuration produced a finite validation CI; no model was saved.")
            else:
                best_model_path = os.path.join(output_dir, f"{current_date}_best_deepsurv_model.pth")
                torch.save(best_model_state, best_model_path)
                print(f"Best model saved to {best_model_path}")
        finally:
            # Restore stdout even on error; otherwise every later notebook print would go
            # through a Tee wrapping the (about to be closed) log file.
            sys.stdout = original_stdout
            tee.flush()

    print("Training completed. Check your log file at:", log_path)

# Launch the full hyperparameter search. In a Jupyter/Colab cell ``__name__`` is
# normally "__main__", so executing this cell starts training immediately.
if __name__ == "__main__":
    main()

Training with layers=[32], dropout=0.0, lr=0.001, weight_decay=0.0
Epoch 10/100 - Train Loss: 1.7421, Val CI: 0.5903
Epoch 20/100 - Train Loss: 1.5670, Val CI: 0.6334
Epoch 30/100 - Train Loss: 1.3232, Val CI: 0.6487
Epoch 40/100 - Train Loss: 1.2721, Val CI: 0.6191
Epoch 50/100 - Train Loss: 0.8419, Val CI: 0.6053
Epoch 60/100 - Train Loss: 1.2674, Val CI: 0.6336
Epoch 70/100 - Train Loss: 1.0104, Val CI: 0.6281
Epoch 80/100 - Train Loss: 0.9398, Val CI: 0.6167
Epoch 90/100 - Train Loss: 0.8571, Val CI: 0.6297
Epoch 100/100 - Train Loss: 0.8985, Val CI: 0.6580
Finished config: Val CI = 0.6624

Training with layers=[32], dropout=0.0, lr=0.001, weight_decay=0.0001
Epoch 10/100 - Train Loss: 2.2950, Val CI: 0.6258
Epoch 20/100 - Train Loss: 1.1915, Val CI: 0.6359
Epoch 30/100 - Train Loss: 1.0929, Val CI: 0.6160
Epoch 40/100 - Train Loss: 1.1576, Val CI: 0.6489
Epoch 50/100 - Train Loss: 0.9188, Val CI: 0.6573
Epoch 60/100 - Train Loss: 0.7924, Val CI: 0.6182
Epoch 70/100 - Train Loss: 0