# DeepHit Tutorial: Deep Learning for Competing Risks Survival Analysis

This tutorial provides a comprehensive guide to implementing DeepHit, a deep learning model for competing risks survival analysis developed by Changhee Lee et al. in 2018.

**Reference:** [DeepHit GitHub Repository](https://github.com/chl8856/DeepHit)

**Paper:** [DeepHit: A Deep Learning Approach to Survival Analysis with Competing Risks](http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit)


## 1. Overview

DeepHit is a deep learning model designed for survival analysis with competing risks. Unlike traditional survival models such as the Cox proportional hazards model, DeepHit offers:

- **Multiple competing events**: Model different types of events that can occur
- **Non-linear relationships**: Capture complex interactions between covariates and survival outcomes
- **Time-dependent effects**: Model how the effect of covariates changes over time
- **No proportional hazards assumption**: More flexible than Cox models

### Key Advantages:
1. **Flexibility**: Can model complex non-linear relationships
2. **Competing Risks**: Handles multiple event types simultaneously
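3. **Fewer assumptions**: Requires neither the proportional hazards assumption nor a parametric form for the hazard
4. **Performance**: Can outperform traditional methods when covariate effects are non-linear; the original paper reports gains on its benchmark datasets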


## 2. How It Works

DeepHit uses a deep neural network to estimate the joint distribution of survival times and event types. The model architecture consists of:

### Architecture:
1. **Shared Sub-network**: Learns common representations from input features
2. **Cause-Specific Sub-networks**: Separate networks for each competing event type
3. **Output Layer**: A single softmax over all event types and discrete time intervals, producing a joint probability mass function (PMF) that sums to 1 for each subject
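To make the output layer concrete, here is a minimal base-R sketch with made-up logits for one subject (two risks, four intervals). It is only an illustration of the normalization, not the tutorial's network: the logits and variable names are assumptions.

```r
# Toy logits for one subject: 2 competing risks x 4 time intervals (made-up numbers)
logits <- matrix(c( 0.2, 1.0, 0.5, -0.3,
                   -0.5, 0.1, 0.8,  0.4),
                 nrow = 2, byrow = TRUE)

# A single softmax over ALL (risk, interval) cells gives a joint PMF
pmf <- exp(logits) / sum(exp(logits))

# Cumulative incidence function (CIF) per risk: running sum along time
cif <- t(apply(pmf, 1, cumsum))

# Overall survival: probability that no event of any type has occurred yet
survival <- 1 - colSums(cif)

print(round(cif, 3))
print(round(survival, 3))
```

Because the PMF sums to 1 over a finite time grid, the survival curve reaches zero at the last interval. This is a property of the discretized formulation and is worth keeping in mind when choosing the time horizon.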

### Loss Function:
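The model is trained on a weighted sum of two terms:
- **Likelihood loss**: Negative log-likelihood of the observed event type and time; for censored samples, of surviving past the censoring time
- **Ranking loss**: A cause-specific penalty applied when a subject who experienced an event earlier is assigned a lower cumulative incidence than a subject who was still at risk at that time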
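For reference, the objective can be written out as follows. Here $s_i$ is the observed (discretized) time of subject $i$, $k_i$ its event type with $k_i = 0$ for censoring, $\hat{y}^{(i)}_{k,s}$ the predicted joint PMF, $\hat{F}_k$ the cumulative incidence function, and $\alpha$, $\sigma$ the hyperparameters that appear as arguments of `deephit_loss` later in this tutorial. Using a single weight $\alpha$ for all causes is a simplifying assumption of this sketch; the original paper allows cause-specific weights.

$$
\mathcal{L} = \mathcal{L}_1 + \alpha\,\mathcal{L}_2
$$

$$
\mathcal{L}_1 = -\sum_{i=1}^{N}\left[\mathbb{1}(k_i \neq 0)\,\log \hat{y}^{(i)}_{k_i,\,s_i} + \mathbb{1}(k_i = 0)\,\log\left(1 - \sum_{k=1}^{K}\hat{F}_k(s_i \mid x_i)\right)\right]
$$

$$
\mathcal{L}_2 = \sum_{k=1}^{K}\sum_{i \neq j}\mathbb{1}(k_i = k,\; s_i < s_j)\,\exp\left(-\frac{\hat{F}_k(s_i \mid x_i) - \hat{F}_k(s_i \mid x_j)}{\sigma}\right),
\qquad
\hat{F}_k(s \mid x) = \sum_{m \le s}\hat{y}_{k,m}
$$

The ranking term compares both subjects at the earlier event time $s_i$, and subject $j$ may be censored as long as $s_j > s_i$.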

### Training Process:
1. Input features are passed through the shared layers
2. The shared representation is fed to each cause-specific branch
3. Branch outputs are combined by a softmax into a joint PMF over event types and time intervals
4. The loss is computed from the observed events and times
5. The model is optimized with backpropagation


## 3. Setup R Libraries

First, let's install and load all necessary libraries.


In [None]:
# Install required packages (run once if not installed)
# install.packages(c("torch", "dplyr", "ggplot2", "tidyr", "caret", "scales", "viridis", "RColorBrewer", "gridExtra"))


In [None]:
# Load required libraries
library(torch)
library(dplyr)
library(ggplot2)
library(tidyr)
library(caret)
library(scales)
library(viridis)
library(RColorBrewer)

# Set random seeds for reproducibility
set.seed(42)
torch_manual_seed(42)

# Check if CUDA is available
cat("CUDA available:", cuda_is_available(), "\n")
cat("Libraries loaded successfully!\n")


## 4. Create Synthetic Data

We'll create synthetic data similar to the DeepHit repository's synthetic dataset. The data will include:
- Multiple features (covariates)
- Event times (time to event or censoring)
- Event indicators (0: censored, 1: event type 1, 2: event type 2)
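Before generating the full dataset, it can help to see how competing exponential latent times determine the mix of event types. The sketch below uses hypothetical constant hazards (`h1`, `h2`, `hc` are illustrative values, not the tutorial's) and checks the simulated proportions against the closed-form value `hazard / sum(hazards)`.

```r
set.seed(1)
n  <- 100000          # Monte Carlo sample size (illustrative)
h1 <- 1.0             # hypothetical hazard of event type 1
h2 <- 0.5             # hypothetical hazard of event type 2
hc <- 0.5             # hypothetical censoring hazard

t1 <- rexp(n, rate = h1)
t2 <- rexp(n, rate = h2)
tc <- rexp(n, rate = hc)

# Observed time is the earliest of the three latent times
observed_time <- pmin(t1, t2, tc)
event <- ifelse(tc < pmin(t1, t2), 0L, ifelse(t1 < t2, 1L, 2L))

# Rows: simulated vs. closed-form shares for censored, event 1, event 2
empirical   <- as.numeric(prop.table(table(event)))
theoretical <- c(hc, h1, h2) / (h1 + h2 + hc)
print(round(rbind(empirical, theoretical), 3))
```

The same relation gives a quick way to pick a censoring hazard for a target censoring rate when the event hazards are roughly known.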


In [None]:
generate_synthetic_competing_risks_data <- function(n_samples = 1000, n_features = 12, n_risks = 2, seed = 42) {
  # Generate synthetic competing risks survival data.
  # 
  # Parameters:
  # n_samples: Number of samples
  # n_features: Number of features
  # n_risks: Number of competing risks (currently unused: two risks are
  #          hard-coded below, and at least 6 features are required)
  # seed: Random seed
  #
  # Returns:
  # data.frame with features, time, and event columns
  set.seed(seed)
  
  # Generate features
  X <- matrix(rnorm(n_samples * n_features), nrow = n_samples, ncol = n_features)
  
  # Create feature names
  feature_names <- paste0("x", 1:n_features)
  colnames(X) <- feature_names
  
  # Generate event times and types based on features
  times <- numeric(n_samples)
  events <- integer(n_samples)
  
  for (i in 1:n_samples) {
    # Create hazard functions that depend on features
    # Risk 1: depends on first few features
    hazard_1 <- exp(0.5 * X[i, 1] + 0.3 * X[i, 2] - 0.2 * X[i, 3])
    
    # Risk 2: depends on different features
    hazard_2 <- exp(0.4 * X[i, 4] + 0.3 * X[i, 5] - 0.1 * X[i, 6])
    
    # Generate time to event for each risk (hazards are always positive here)
    time_1 <- rexp(1, rate = hazard_1)
    time_2 <- rexp(1, rate = hazard_2)
    
    # Determine which event occurs first
    min_time <- min(time_1, time_2)
    
    # Add independent exponential censoring (mean censoring time 15).
    # With these hazards the realized censoring rate is well below 30%;
    # it is printed after the data are generated.
    censor_time <- rexp(1, rate = 1/15)
    
    if (censor_time < min_time) {
      times[i] <- censor_time
      events[i] <- 0  # Censored
    } else {
      times[i] <- min_time
      events[i] <- ifelse(time_1 < time_2, 1, 2)
    }
  }
  
  # Create data frame
  data <- as.data.frame(X)
  data$time <- times
  data$event <- events
  
  return(data)
}

# Generate synthetic data
data <- generate_synthetic_competing_risks_data(n_samples = 2000, n_features = 12, n_risks = 2)

cat("Data shape:", nrow(data), "x", ncol(data), "\n")
cat("\nEvent distribution:\n")
print(table(data$event))
cat("\nCensoring rate:", mean(data$event == 0) * 100, "%\n")
cat("\nFirst few rows:\n")
print(head(data))

# Save to CSV
write.csv(data, 'synthetic_comprisk.csv', row.names = FALSE)
cat("\nData saved to 'synthetic_comprisk.csv'\n")


## 5. Split Data

Split the data into training, validation, and testing sets.
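The idea behind a stratified split is to sample the same fraction within each event class so that rare classes (often the censored ones) appear in every subset. The following base-R sketch illustrates this on made-up event labels; it is an illustration of the principle, not a replacement for `createDataPartition`.

```r
set.seed(42)
# Made-up event labels: 0 = censored, 1 and 2 = event types
event <- sample(c(0, 1, 2), size = 300, replace = TRUE, prob = c(0.1, 0.5, 0.4))

# Sample 80% of the row indices within each event class
train_idx <- unlist(lapply(split(seq_along(event), event), function(idx) {
  idx[sample.int(length(idx), size = floor(0.8 * length(idx)))]
}))
test_idx <- setdiff(seq_along(event), train_idx)

# Class proportions should be similar in both subsets
print(round(prop.table(table(event[train_idx])), 3))
print(round(prop.table(table(event[test_idx])), 3))
```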


In [None]:
# Load data if not already in memory
# data <- read.csv('synthetic_comprisk.csv')

# Separate features and targets
X <- data %>% select(-time, -event)
y <- data %>% select(time, event)

# Split into train and test (80/20)
train_idx <- createDataPartition(factor(y$event), p = 0.8, list = FALSE)  # factor => stratify by event class
X_train <- X[train_idx, ]
X_test <- X[-train_idx, ]
y_train <- y[train_idx, ]
y_test <- y[-train_idx, ]

# Further split training data into train and validation (80/20 of training)
val_idx <- createDataPartition(factor(y_train$event), p = 0.2, list = FALSE)  # factor => stratify by event class
X_val <- X_train[val_idx, ]
X_train <- X_train[-val_idx, ]
y_val <- y_train[val_idx, ]
y_train <- y_train[-val_idx, ]

cat("Training set:", nrow(X_train), "samples\n")
cat("Validation set:", nrow(X_val), "samples\n")
cat("Test set:", nrow(X_test), "samples\n")
cat("\nTraining event distribution:\n")
print(table(y_train$event))
cat("\nTest event distribution:\n")
print(table(y_test$event))


## 6. Data Preprocessing

Preprocess the data by standardizing features and preparing time intervals for DeepHit.
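As a small illustration of the discretization step, the sketch below maps toy continuous times onto four equal-width intervals with `findInterval()`. The times and the number of intervals are made up for the example.

```r
times  <- c(0.0, 0.4, 1.0, 2.5, 3.9, 4.0)   # toy continuous times
breaks <- seq(0, 4, length.out = 4 + 1)     # 4 equal-width intervals on [0, 4]

# findInterval() returns 1-based bins; subtract 1 for 0-based interval indices.
# rightmost.closed = TRUE puts the maximum time into the last interval.
idx <- findInterval(times, breaks, rightmost.closed = TRUE) - 1
print(idx)
# 0 0 1 2 3 3
```

Intervals are closed on the left, so a time exactly on a break (1.0 here) falls into the interval that starts there.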


In [None]:
# Standardize features
preProc <- preProcess(X_train, method = c("center", "scale"))
X_train_scaled <- predict(preProc, X_train)
X_val_scaled <- predict(preProc, X_val)
X_test_scaled <- predict(preProc, X_test)

# Convert to matrices for torch
X_train_scaled <- as.matrix(X_train_scaled)
X_val_scaled <- as.matrix(X_val_scaled)
X_test_scaled <- as.matrix(X_test_scaled)

# Create discrete time intervals for DeepHit
# DeepHit works with discrete time intervals
max_time <- max(max(y_train$time), max(y_val$time), max(y_test$time))
num_intervals <- 50  # Number of discrete time intervals
time_intervals <- seq(0, max_time, length.out = num_intervals + 1)

discretize_time <- function(times, intervals) {
  # Convert continuous times to discrete interval indices
  return(findInterval(times, intervals, rightmost.closed = TRUE) - 1)
}

y_train_discrete <- discretize_time(y_train$time, time_intervals)
y_val_discrete <- discretize_time(y_val$time, time_intervals)
y_test_discrete <- discretize_time(y_test$time, time_intervals)

# Clip to valid range
y_train_discrete <- pmax(0, pmin(y_train_discrete, num_intervals - 1))
y_val_discrete <- pmax(0, pmin(y_val_discrete, num_intervals - 1))
y_test_discrete <- pmax(0, pmin(y_test_discrete, num_intervals - 1))

cat("Time intervals:", num_intervals, "intervals\n")
cat("Time range: 0 to", round(max_time, 2), "\n")
cat("\nTraining discrete time distribution (first 10 intervals):\n")
print(table(y_train_discrete)[1:min(10, length(table(y_train_discrete)))])


## 7. DeepHit Implementation

Now we'll implement the DeepHit model. This is a simplified custom torch implementation based on the original paper; the loss loops over samples in R, which favors readability over speed.


In [None]:
# DeepHit Dataset class
DeepHitDataset <- dataset(
  name = "DeepHitDataset",
  initialize = function(X, times, events) {
    self$X <- torch_tensor(X, dtype = torch_float32())
    self$times <- torch_tensor(times, dtype = torch_long())
    self$events <- torch_tensor(events, dtype = torch_long())
  },
  .getitem = function(i) {
    list(self$X[i, ], self$times[i], self$events[i])
  },
  .length = function() {
    self$X$shape[1]
  }
)

# DeepHit Network architecture
DeepHitNetwork <- nn_module(
  "DeepHitNetwork",
  initialize = function(input_dim, hidden_dims, num_risks, num_intervals, dropout = 0.1) {
    self$num_risks <- num_risks
    self$num_intervals <- num_intervals
    
    # Shared layers
    shared_layers <- list()
    prev_dim <- input_dim
    
    for (hidden_dim in hidden_dims) {
      # Wrap each module in list() so append() adds it as a single element
      shared_layers <- append(shared_layers, list(nn_linear(prev_dim, hidden_dim)))
      shared_layers <- append(shared_layers, list(nn_batch_norm1d(hidden_dim)))
      shared_layers <- append(shared_layers, list(nn_relu()))
      shared_layers <- append(shared_layers, list(nn_dropout(dropout)))
      prev_dim <- hidden_dim
    }
    
    self$shared_layers <- nn_sequential(!!!shared_layers)
    
    # Cause-specific layers
    self$risk_layers <- nn_module_list()
    for (i in 1:num_risks) {
      self$risk_layers$append(
        nn_sequential(
          nn_linear(prev_dim, as.integer(prev_dim / 2)),
          nn_batch_norm1d(as.integer(prev_dim / 2)),
          nn_relu(),
          nn_dropout(dropout),
          nn_linear(as.integer(prev_dim / 2), num_intervals)
        )
      )
    }
  },
  forward = function(x) {
    # Shared representation
    shared <- self$shared_layers(x)
    
    # Cause-specific outputs
    outputs <- list()
    for (i in 1:self$num_risks) {
      output <- self$risk_layers[[i]](shared)
      outputs[[i]] <- output
    }
    
    # Stack outputs: [batch_size, num_risks, num_intervals]
    return(torch_stack(outputs, dim = 2))
  }
)

# DeepHit loss function
deephit_loss <- function(pred, times, events, alpha = 0.5, sigma = 0.1) {
  # DeepHit loss function combining likelihood and ranking losses.
  # 
  # Parameters:
  # pred: Model predictions [batch_size, num_risks, num_intervals]
  # times: Discrete time indices [batch_size]
  # events: Event indicators [batch_size] (0: censored, 1+: event types)
  # alpha: Weight for ranking loss
  # sigma: Parameter for ranking loss
  batch_size <- pred$shape[1]
  num_risks <- pred$shape[2]
  num_intervals <- pred$shape[3]
  
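  # Apply ONE softmax jointly over all risks and intervals so the PMF sums to 1
  # per sample. A per-risk softmax would let the summed CIFs exceed 1 and make
  # the censored survival term negative (log of a negative number is NaN).
  pred_probs <- nnf_softmax(pred$reshape(c(batch_size, -1)), dim = 2)$reshape(c(batch_size, num_risks, num_intervals))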
  
  # Likelihood loss
  likelihood_loss <- 0.0
  
  for (i in 1:batch_size) {
    time_idx <- as.integer(times[i]$item()) + 1  # R is 1-indexed
    event <- as.integer(events[i]$item())
    
    if (event == 0) {  # Censored
      # For censored, sum probabilities of all risks up to censoring time
      surv_prob <- 1.0 - torch_sum(pred_probs[i, , 1:time_idx])
      likelihood_loss <- likelihood_loss - torch_log(surv_prob + 1e-8)
    } else {  # Event occurred
      # Probability of specific event at specific time
      event_prob <- pred_probs[i, event, time_idx]
      likelihood_loss <- likelihood_loss - torch_log(event_prob + 1e-8)
    }
  }
  
  likelihood_loss <- likelihood_loss / batch_size
  
  # Ranking loss: for each subject i with an observed event, compare against
  # every subject j still at risk after i's event time (censored j included)
  ranking_loss <- 0.0
  count <- 0
  
  for (i in 1:batch_size) {
    if (events[i]$item() == 0) {  # Skip censored
      next
    }
    
    time_i <- as.integer(times[i]$item()) + 1
    event_i <- as.integer(events[i]$item())
    
    # Cumulative incidence of event_i for subject i at i's own event time
    cif_i <- torch_sum(pred_probs[i, event_i, 1:time_i])
    
    for (j in 1:batch_size) {
      time_j <- as.integer(times[j]$item()) + 1
      if (i == j || time_j <= time_i) {
        next
      }
      
      # i failed first, so at time_i its CIF should exceed j's CIF;
      # both subjects are evaluated at the same time point (time_i)
      cif_j <- torch_sum(pred_probs[j, event_i, 1:time_i])
      diff <- cif_i - cif_j
      ranking_loss <- ranking_loss + torch_exp(-diff / sigma)
      count <- count + 1
    }
  }
  
  if (count > 0) {
    ranking_loss <- ranking_loss / count
  }
  
  total_loss <- likelihood_loss + alpha * ranking_loss
  
  return(list(total = total_loss, likelihood = likelihood_loss, ranking = ranking_loss))
}

# Model parameters
input_dim <- ncol(X_train_scaled)
hidden_dims <- c(64, 32)
num_risks <- 2  # Two competing risks
# num_intervals was already set during preprocessing (Section 6)

# Create model
model <- DeepHitNetwork(input_dim, hidden_dims, num_risks, num_intervals, dropout = 0.1)

# Move to GPU if available
device <- if (cuda_is_available()) "cuda" else "cpu"
model <- model$to(device = device)

total_params <- sum(sapply(model$parameters, function(p) prod(p$shape)))
cat("Model created with", total_params, "parameters\n")
cat("Using device:", device, "\n")


In [None]:
# Training setup
learning_rate <- 0.001
batch_size <- 64
num_epochs <- 100

optimizer <- optim_adam(model$parameters, lr = learning_rate)

# Create datasets
train_dataset <- DeepHitDataset(X_train_scaled, y_train_discrete, y_train$event)
val_dataset <- DeepHitDataset(X_val_scaled, y_val_discrete, y_val$event)

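# drop_last avoids a final batch of size 1, which batch norm cannot handle in training mode
train_dataloader <- dataloader(train_dataset, batch_size = batch_size, shuffle = TRUE, drop_last = TRUE)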
val_dataloader <- dataloader(val_dataset, batch_size = batch_size, shuffle = FALSE)

# Training loop
train_losses <- numeric(num_epochs)
val_losses <- numeric(num_epochs)
best_val_loss <- Inf
patience <- 10
patience_counter <- 0

for (epoch in 1:num_epochs) {
  # Training
  model$train()
  epoch_train_loss <- 0.0
  batch_count <- 0
  
  coro::loop(for (batch in train_dataloader) {
    batch_X <- batch[[1]]$to(device = device)
    batch_times <- batch[[2]]$to(device = device)
    batch_events <- batch[[3]]$to(device = device)
    
    optimizer$zero_grad()
    pred <- model(batch_X)
    loss_result <- deephit_loss(pred, batch_times, batch_events)
    loss_result$total$backward()
    optimizer$step()
    
    epoch_train_loss <- epoch_train_loss + loss_result$total$item()
    batch_count <- batch_count + 1
  })
  
  train_losses[epoch] <- epoch_train_loss / batch_count
  
  # Validation
  model$eval()
  epoch_val_loss <- 0.0
  batch_count <- 0
  
  with_no_grad({
    coro::loop(for (batch in val_dataloader) {
      batch_X <- batch[[1]]$to(device = device)
      batch_times <- batch[[2]]$to(device = device)
      batch_events <- batch[[3]]$to(device = device)
      
      pred <- model(batch_X)
      loss_result <- deephit_loss(pred, batch_times, batch_events)
      epoch_val_loss <- epoch_val_loss + loss_result$total$item()
      batch_count <- batch_count + 1
    })
  })
  
  val_losses[epoch] <- epoch_val_loss / batch_count
  
  # Early stopping
  if (val_losses[epoch] < best_val_loss) {
    best_val_loss <- val_losses[epoch]
    patience_counter <- 0
    # Save best model
    torch_save(model$state_dict(), 'best_deephit_model.pth')
  } else {
    patience_counter <- patience_counter + 1
  }
  
  if (epoch %% 10 == 0) {
    cat(sprintf("Epoch [%d/%d], Train Loss: %.4f, Val Loss: %.4f\n", 
                epoch, num_epochs, train_losses[epoch], val_losses[epoch]))
  }
  
  if (patience_counter >= patience) {
    cat(sprintf("Early stopping at epoch %d\n", epoch))
    break
  }
}

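# Drop unused slots if early stopping ended training before num_epochs
# (the loop variable `epoch` keeps the last epoch that was run)
train_losses <- train_losses[1:epoch]
val_losses <- val_losses[1:epoch]

# Load best model
model$load_state_dict(torch_load('best_deephit_model.pth'))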

cat("\nTraining completed!\n")


In [None]:
# Plot training curves
loss_df <- data.frame(
  epoch = 1:length(train_losses),
  train_loss = train_losses,
  val_loss = val_losses
) %>%
  pivot_longer(cols = c(train_loss, val_loss), names_to = "type", values_to = "loss")

ggplot(loss_df, aes(x = epoch, y = loss, color = type)) +
  geom_line(linewidth = 1) +
  labs(x = "Epoch", y = "Loss", 
       title = "DeepHit Training Curves",
       color = "Type") +
  theme_minimal() +
  theme(plot.title = element_text(size = 14, face = "bold"),
        legend.position = "bottom")


## 8. Prediction Model Evaluation

Evaluate the model's predictions on the test set.


In [None]:
# Make predictions on test set
model$eval()
test_dataset <- DeepHitDataset(X_test_scaled, y_test_discrete, y_test$event)
test_dataloader <- dataloader(test_dataset, batch_size = batch_size, shuffle = FALSE)

all_preds <- list()
all_times <- list()
all_events <- list()

with_no_grad({
  coro::loop(for (batch in test_dataloader) {
    batch_X <- batch[[1]]$to(device = device)
    pred <- model(batch_X)
    
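    # Convert to a joint PMF: one softmax over all (risk, interval) cells,
    # so the probabilities sum to 1 per sample
    pred_probs <- nnf_softmax(pred$reshape(c(pred$shape[1], -1)), dim = 2)$reshape(pred$shape)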
    
    all_preds <- append(all_preds, list(pred_probs$cpu()))
    all_times <- append(all_times, list(batch[[2]]$cpu()))
    all_events <- append(all_events, list(batch[[3]]$cpu()))
  })
})

# Concatenate all predictions
test_preds <- torch_cat(all_preds, dim = 1)  # [n_samples, num_risks, num_intervals]
test_times <- torch_cat(all_times, dim = 1)
test_events <- torch_cat(all_events, dim = 1)

# Convert to R arrays
test_preds <- as.array(test_preds)
test_times <- as.array(test_times)
test_events <- as.array(test_events)

cat("Test predictions shape:", paste(dim(test_preds), collapse = " x "), "\n")
cat("Test times shape:", length(test_times), "\n")
cat("Test events shape:", length(test_events), "\n")


In [None]:
# Calculate Cumulative Incidence Functions (CIF)
calculate_cif <- function(pred_probs) {
  # Calculate Cumulative Incidence Function from predicted probabilities
  # pred_probs: [n_samples, num_risks, num_intervals]
  # apply() puts the cumsum (time) dimension first: [num_intervals, n_samples, num_risks]
  cif <- apply(pred_probs, c(1, 2), cumsum)
  # Permute back to [n_samples, num_risks, num_intervals]; assigning dim() alone
  # would relabel the axes without moving the data
  cif <- aperm(cif, c(2, 3, 1))
  return(cif)
}

test_cif <- calculate_cif(test_preds)

# Plot CIF for a few samples
plot_data <- data.frame()
for (risk_idx in 1:num_risks) {
  for (i in 1:min(10, dim(test_cif)[1])) {
    plot_data <- rbind(plot_data, data.frame(
      time = time_intervals[2:length(time_intervals)],
      cif = test_cif[i, risk_idx, ],
      sample = paste0("Sample ", i),
      risk = paste0("Risk ", risk_idx)
    ))
  }
}

ggplot(plot_data, aes(x = time, y = cif, color = sample)) +
  geom_line(alpha = 0.6, linewidth = 1) +
  facet_wrap(~risk, scales = "free") +
  labs(x = "Time", y = "Cumulative Incidence",
       title = "Cumulative Incidence Function - First 10 Samples") +
  theme_minimal() +
  theme(plot.title = element_text(size = 12, face = "bold"),
        legend.position = "none")

# Calculate mean CIF for each risk
mean_cif_risk1 <- apply(test_cif[, 1, ], 2, mean)
mean_cif_risk2 <- apply(test_cif[, 2, ], 2, mean)

mean_cif_df <- data.frame(
  time = time_intervals[2:length(time_intervals)],
  risk1 = mean_cif_risk1,
  risk2 = mean_cif_risk2
) %>%
  pivot_longer(cols = c(risk1, risk2), names_to = "risk", values_to = "cif")

ggplot(mean_cif_df, aes(x = time, y = cif, color = risk)) +
  geom_line(linewidth = 1.5) +
  labs(x = "Time", y = "Cumulative Incidence",
       title = "Mean Cumulative Incidence Functions",
       color = "Risk") +
  theme_minimal() +
  theme(plot.title = element_text(size = 14, face = "bold"),
        legend.position = "bottom")


## 9. Performance Metrics

Calculate various performance metrics to evaluate the model.
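To illustrate what a time-dependent concordance index measures, here is a toy base-R calculation for a single risk. The CIF values, interval indices, and event flags are made up; censoring weights are not applied, which is an assumption of this sketch.

```r
# Toy CIF for ONE risk: 4 subjects (rows) x 3 intervals (columns); made-up numbers
cif <- rbind(c(0.30, 0.50, 0.60),
             c(0.10, 0.20, 0.35),
             c(0.05, 0.10, 0.15),
             c(0.20, 0.45, 0.70))
time_idx <- c(1, 2, 3, 3)   # 1-based interval of event or censoring
event    <- c(1, 1, 0, 1)   # 1 = event of this risk, 0 = censored

concordant <- 0
comparable <- 0
for (i in which(event == 1)) {
  # Subject j is comparable if it was still at risk after i's event time
  for (j in which(time_idx > time_idx[i])) {
    comparable <- comparable + 1
    # Compare both subjects at subject i's event time
    if (cif[i, time_idx[i]] > cif[j, time_idx[i]]) concordant <- concordant + 1
  }
}
c_index <- concordant / comparable
print(c_index)
# 0.8
```

Four of the five comparable pairs are ordered correctly; the one discordant pair is subject 2 versus subject 4, where the later-failing subject has the higher CIF at interval 2.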


In [None]:
calculate_concordance_index <- function(pred_cif, times, events, risk_idx) {
  # Time-dependent concordance index (C-index) for a specific risk.
  # Simplified version without censoring weights.
  # A pair (i, j) is comparable when subject i experienced `risk_idx` and
  # subject j was still at risk afterwards (times[j] > times[i]). The pair is
  # concordant when i's CIF at i's event time exceeds j's CIF at that same time.
  n_intervals <- dim(pred_cif)[3]
  concordant <- 0
  total <- 0
  
  for (i in which(events == risk_idx)) {
    time_idx <- min(as.integer(times[i]) + 1, n_intervals)
    later <- which(times > times[i])
    if (length(later) == 0) {
      next
    }
    
    score_i <- pred_cif[i, risk_idx, time_idx]
    scores_j <- pred_cif[later, risk_idx, time_idx]
    
    # Ties in predicted risk count as half-concordant
    concordant <- concordant + sum(score_i > scores_j) + 0.5 * sum(score_i == scores_j)
    total <- total + length(later)
  }
  
  return(ifelse(total > 0, concordant / total, 0.5))
}

calculate_brier_score <- function(pred_cif, times, events, risk_idx, time_points) {
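  # Calculate Brier score at specific time points.
  # Note: no censoring adjustment (such as inverse probability of censoring
  # weighting) is applied, so subjects censored before t count as event-free.
  # The `times` argument is unused; y_test$time and time_intervals are read
  # from the global environment.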
  brier_scores <- numeric()
  
  for (t in time_points) {
    # Find closest interval
    t_idx <- findInterval(t, time_intervals[2:length(time_intervals)], rightmost.closed = TRUE) + 1
    t_idx <- min(t_idx, dim(pred_cif)[3])
    
    # Get predicted CIF at time t
    pred_cif_t <- pred_cif[, risk_idx, t_idx]
    
    # Get true outcomes (1 if event occurred before t, 0 otherwise)
    true_outcomes <- as.numeric((events == risk_idx) & (y_test$time <= time_intervals[t_idx + 1]))
    
    # Calculate Brier score
    brier <- mean((pred_cif_t - true_outcomes)^2)
    brier_scores <- c(brier_scores, brier)
  }
  
  return(brier_scores)
}

# Calculate metrics for each risk
cat("Performance Metrics:\n")
cat(paste(rep("=", 50), collapse = ""), "\n")

for (risk_idx in 1:num_risks) {
  c_index <- calculate_concordance_index(test_cif, test_times, test_events, risk_idx)
  cat(sprintf("\nRisk %d:\n", risk_idx))
  cat(sprintf("  Concordance Index (C-index): %.4f\n", c_index))
  
  # Brier score at different time points
  time_points <- seq(0, max_time, length.out = 10)
  brier_scores <- calculate_brier_score(test_cif, test_times, test_events, risk_idx, time_points)
  cat(sprintf("  Mean Brier Score: %.4f\n", mean(brier_scores)))
  cat(sprintf("  Integrated Brier Score: %.4f\n", 
              integrate(function(x) approx(time_points, brier_scores, x)$y, 
                       0, max_time)$value / max_time))
}

# Overall metrics
cat("\n", paste(rep("=", 50), collapse = ""), "\n")
cat("Overall Model Performance:\n")
cat("Test samples:", length(test_times), "\n")
cat("Event rate:", mean(test_events > 0) * 100, "%\n")
cat("Censoring rate:", mean(test_events == 0) * 100, "%\n")


In [None]:
# Visualize Brier scores over time
time_points <- seq(0, max_time, length.out = 20)

brier_df <- data.frame()
for (risk_idx in 1:num_risks) {
  brier_scores <- calculate_brier_score(test_cif, test_times, test_events, risk_idx, time_points)
  brier_df <- rbind(brier_df, data.frame(
    time = time_points,
    brier_score = brier_scores,
    risk = paste0("Risk ", risk_idx)
  ))
}

ggplot(brier_df, aes(x = time, y = brier_score, color = risk)) +
  geom_line(linewidth = 1.5) +
  geom_point(size = 2) +
  facet_wrap(~risk, scales = "free") +
  labs(x = "Time", y = "Brier Score",
       title = "Brier Score Over Time") +
  theme_minimal() +
  theme(plot.title = element_text(size = 12, face = "bold"),
        legend.position = "none")


## 10. Risk Stratification

Stratify patients into risk groups based on predicted cumulative incidence.
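The tertile split itself is a short `quantile()` plus `cut()` pattern. The sketch below applies it to made-up scores (`risk_score` is a stand-in for a predicted cumulative incidence, not model output).

```r
set.seed(7)
risk_score <- runif(90)   # stand-in for predicted cumulative incidence at a horizon

# Tertile boundaries, then assign each score to a group
breaks <- quantile(risk_score, probs = c(0, 1/3, 2/3, 1))
risk_group <- cut(risk_score, breaks = breaks,
                  labels = c("Low Risk", "Medium Risk", "High Risk"),
                  include.lowest = TRUE)

print(table(risk_group))
print(tapply(risk_score, risk_group, mean))
```

Note that `cut()` stops with a "'breaks' are not unique" error when many scores are tied (for example, when every score is identical), so it is worth checking that the chosen risk score actually varies across patients.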


In [None]:
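# Calculate risk scores for each patient
# The maximum of a CIF is always its value at the last interval, where a
# softmax-normalized PMF has accumulated all of its mass, so the combined
# score there cannot separate patients. Score patients at an earlier horizon
# instead: the interval containing the median observed test time.
horizon_idx <- as.integer(median(y_test_discrete)) + 1
risk_scores_risk1 <- test_cif[, 1, horizon_idx]
risk_scores_risk2 <- test_cif[, 2, horizon_idx]

# Combine risk scores: predicted probability of any event by the horizon
combined_risk_scores <- risk_scores_risk1 + risk_scores_risk2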

# Stratify into risk groups (tertiles)
risk_groups <- cut(combined_risk_scores, 
                   breaks = quantile(combined_risk_scores, probs = c(0, 1/3, 2/3, 1)),
                   labels = c("Low Risk", "Medium Risk", "High Risk"),
                   include.lowest = TRUE)

# Add to test data
y_test_stratified <- y_test
y_test_stratified$risk_group <- risk_groups
y_test_stratified$risk_score <- combined_risk_scores

cat("Risk Stratification:\n")
cat(paste(rep("=", 50), collapse = ""), "\n")
print(table(y_test_stratified$risk_group))
cat("\nRisk Score Statistics by Group:\n")
print(y_test_stratified %>% 
      group_by(risk_group) %>% 
      summarise(
        mean = mean(risk_score),
        sd = sd(risk_score),
        min = min(risk_score),
        max = max(risk_score),
        .groups = "drop"))


In [None]:
# Visualize risk stratification
# Plot 1: Risk score distribution
p1 <- ggplot(y_test_stratified, aes(x = risk_score, fill = risk_group)) +
  geom_histogram(alpha = 0.6, bins = 20, color = "black") +
  labs(x = "Risk Score", y = "Frequency",
       title = "Risk Score Distribution by Group",
       fill = "Risk Group") +
  theme_minimal() +
  theme(plot.title = element_text(size = 12, face = "bold"),
        legend.position = "bottom")

# Plot 2: Mean model-predicted survival curves for each risk group
# (these come from the predicted CIFs, not from a Kaplan-Meier estimate)
survival_data <- data.frame()
for (group in levels(risk_groups)) {
  group_indices <- which(y_test_stratified$risk_group == group)
  
  # Calculate survival probability (1 - CIF)
  group_cif <- apply(test_cif[group_indices, , ], c(2, 3), mean)  # Average over samples
  
  # Overall survival (1 - sum of all CIFs)
  overall_survival <- 1 - colSums(group_cif)
  
  survival_data <- rbind(survival_data, data.frame(
    time = time_intervals[2:length(time_intervals)],
    survival = overall_survival,
    group = group
  ))
}

p2 <- ggplot(survival_data, aes(x = time, y = survival, color = group)) +
  geom_line(linewidth = 1.5) +
  labs(x = "Time", y = "Survival Probability",
       title = "Survival Curves by Risk Group",
       color = "Risk Group") +
  ylim(0, 1) +
  theme_minimal() +
  theme(plot.title = element_text(size = 12, face = "bold"),
        legend.position = "bottom")

# Combine plots
library(gridExtra)
grid.arrange(p1, p2, ncol = 2)

# Event rates by risk group
cat("\nEvent Rates by Risk Group:\n")
cat(paste(rep("=", 50), collapse = ""), "\n")
print(table(y_test_stratified$risk_group, y_test_stratified$event))


## 11. Feature Importance Analysis and Visualization

Analyze which features are most important for predictions.
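Permutation importance is easiest to see on a model whose true structure is known. The sketch below uses a linear model on simulated data as a stand-in for the network; the data, coefficients, and the `mse` helper are assumptions made for illustration.

```r
set.seed(123)
n <- 500
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
y <- 2 * X$x1 + 0.5 * X$x2 + rnorm(n, sd = 0.5)   # x3 is pure noise by construction

fit <- lm(y ~ ., data = cbind(X, y = y))
mse <- function(data) mean((y - predict(fit, newdata = data))^2)
baseline <- mse(X)

# Importance = average increase in error after shuffling one feature
importance <- sapply(names(X), function(feat) {
  mean(replicate(20, {
    X_perm <- X
    X_perm[[feat]] <- sample(X_perm[[feat]])   # break the feature-outcome link
    mse(X_perm)
  })) - baseline
})
print(round(sort(importance, decreasing = TRUE), 3))
```

The ranking should follow the size of the true coefficients, with the noise feature `x3` close to zero. The same logic applies to the DeepHit model below, with the loss in place of the mean squared error.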


In [None]:
# Permutation-based feature importance
calculate_permutation_importance <- function(model, X, times, events, n_repeats = 10) {
  # Calculate feature importance using permutation
  model$eval()
  
  # Baseline prediction
  with_no_grad({
    X_tensor <- torch_tensor(X, dtype = torch_float32())$to(device = device)
    baseline_pred <- model(X_tensor)
    baseline_pred_probs <- nnf_softmax(baseline_pred, dim = 3)
    baseline_cif <- calculate_cif(as.array(baseline_pred_probs$cpu()))
    baseline_loss_result <- deephit_loss(
      baseline_pred,
      torch_tensor(times, dtype = torch_long())$to(device = device),
      torch_tensor(events, dtype = torch_long())$to(device = device)
    )
  })
  
  baseline_loss <- baseline_loss_result$total$item()
  
  # Permute each feature
  n_features <- ncol(X)
  importances <- numeric(n_features)
  
  for (feat_idx in 1:n_features) {
    permuted_losses <- numeric(n_repeats)
    
    for (rep in 1:n_repeats) {
      X_permuted <- X
      X_permuted[, feat_idx] <- sample(X_permuted[, feat_idx])
      
      with_no_grad({
        X_tensor <- torch_tensor(X_permuted, dtype = torch_float32())$to(device = device)
        permuted_pred <- model(X_tensor)
        permuted_loss_result <- deephit_loss(
          permuted_pred,
          torch_tensor(times, dtype = torch_long())$to(device = device),
          torch_tensor(events, dtype = torch_long())$to(device = device)
        )
        permuted_losses[rep] <- permuted_loss_result$total$item()
      })
    }
    
    # Importance is increase in loss
    importances[feat_idx] <- mean(permuted_losses) - baseline_loss
  }
  
  return(importances)
}

# Calculate feature importance (use subset for speed)
cat("Calculating feature importance (this may take a while)...\n")
n_subset <- min(100, nrow(X_test_scaled))
feature_importance <- calculate_permutation_importance(
  model, 
  X_test_scaled[1:n_subset, ],
  y_test_discrete[1:n_subset],
  y_test$event[1:n_subset],
  n_repeats = 5
)

# Create importance data frame
importance_df <- data.frame(
  feature = colnames(X_test),
  importance = feature_importance
) %>%
  arrange(desc(importance))

cat("\nFeature Importance (Top 10):\n")
cat(paste(rep("=", 50), collapse = ""), "\n")
print(head(importance_df, 10))


In [None]:
# Visualize feature importance
top_features <- head(importance_df, 12)

ggplot(top_features, aes(x = reorder(feature, importance), y = importance)) +
  geom_bar(stat = "identity", fill = viridis(nrow(top_features)), color = "black") +
  coord_flip() +
  labs(x = "Feature", y = "Importance (Increase in Loss)",
       title = "Feature Importance Analysis (Permutation-based)") +
  theme_minimal() +
  theme(plot.title = element_text(size = 14, face = "bold"))

# Correlation between features and risk scores
feature_risk_corr <- sapply(colnames(X_test), function(feat) {
  abs(cor(X_test[[feat]], combined_risk_scores))
})

corr_df <- data.frame(
  feature = names(feature_risk_corr),
  correlation = feature_risk_corr
) %>%
  arrange(desc(correlation))

top_corr <- head(corr_df, 12)

ggplot(top_corr, aes(x = reorder(feature, correlation), y = correlation)) +
  geom_bar(stat = "identity", fill = plasma(nrow(top_corr)), color = "black") +
  coord_flip() +
  labs(x = "Feature", y = "Absolute Correlation with Risk Score",
       title = "Feature-Risk Score Correlation") +
  theme_minimal() +
  theme(plot.title = element_text(size = 14, face = "bold"))


## 12. Model Interpretation with SHAP

Use SHAP (SHapley Additive exPlanations) to interpret model predictions.
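A useful sanity check for any SHAP computation is additivity: the contributions for one observation sum to its prediction minus the average prediction. For a linear model this can be verified in closed form, since under the marginal (interventional) value function the Shapley value of feature j is `beta_j * (x_j - mean(x_j))`. The coefficients and data below are hypothetical.

```r
set.seed(1)
X <- matrix(rnorm(200 * 3), ncol = 3, dimnames = list(NULL, c("x1", "x2", "x3")))
beta <- c(x1 = 1.5, x2 = -2.0, x3 = 0.0)    # hypothetical coefficients
f <- function(M) as.numeric(M %*% beta)      # linear stand-in for a model

x_explain <- X[1, , drop = FALSE]

# Closed-form Shapley values for a linear model
phi <- beta * (as.numeric(x_explain) - colMeans(X))

# Additivity: contributions sum to prediction minus average prediction
print(round(phi, 3))
print(all.equal(sum(phi), f(x_explain) - mean(f(X))))
```

Sampling-based estimators such as the one in `fastshap` only approximate this property unless an adjustment is requested, so small deviations are expected with the neural network.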


In [None]:
# Install SHAP for R if not already installed
# install.packages("fastshap")

if (requireNamespace("fastshap", quietly = TRUE)) {
  library(fastshap)
  
  # Prediction wrapper in the form fastshap expects: function(object, newdata)
  # returning one numeric value per row
  model_predict_wrapper <- function(object, newdata) {
    object$eval()
    with_no_grad({
      X_tensor <- torch_tensor(as.matrix(newdata), dtype = torch_float32())$to(device = device)
      pred <- object(X_tensor)
      # Joint softmax over (risk, interval), then cumulative sum over time
      pred_probs <- nnf_softmax(pred$reshape(c(pred$shape[1], -1)), dim = 2)$reshape(pred$shape)
      cif <- torch_cumsum(pred_probs, dim = 3)
      # Risk 1 CIF at the last interval (the CIF is non-decreasing, so this is its maximum)
      scores <- as.numeric(as.array(cif[, 1, pred$shape[3]]$cpu()))
    })
    scores
  }
  
  # Use subset of data for SHAP (computationally expensive)
  X_shap <- X_test_scaled[1:20, ]
  X_shap_background <- X_train_scaled[1:50, ]
  
  # Estimate SHAP values by Monte Carlo sampling against the background data
  cat("Estimating SHAP values (this may take a while)...\n")
  shap_values <- fastshap::explain(
    object = model,
    X = X_shap_background,
    newdata = X_shap,
    nsim = 100,
    pred_wrapper = model_predict_wrapper
  )
  
  cat("\nSHAP analysis completed!\n")
  cat("Note: For detailed SHAP visualizations, consider using the 'shapr' package\n")
  cat("or exporting to Python for full SHAP library support.\n")
  
} else {
  cat("SHAP not installed. Install with: install.packages('fastshap')\n")
  cat("\nFor now, showing alternative interpretation using gradient-based methods...\n")
  
  # Alternative: Gradient-based feature importance
  model$eval()
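  # Create the tensor directly on the target device so it stays a leaf tensor
  # and receives gradients (calling $to() afterwards would produce a non-leaf copy)
  X_sample <- torch_tensor(X_test_scaled[1:10, ], dtype = torch_float32(),
                           device = device, requires_grad = TRUE)
  
  pred <- model(X_sample)
  # Use max CIF as output: joint softmax over (risk, interval), then cumulative sum
  pred_probs <- nnf_softmax(pred$reshape(c(pred$shape[1], -1)), dim = 2)$reshape(pred$shape)
  cif <- torch_cumsum(pred_probs, dim = 3)
  output <- torch_max(cif, dim = 3)[[1]][, 1]  # Risk 1, max CIF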
  
  output$sum()$backward()
  
  gradients <- torch_mean(torch_abs(X_sample$grad), dim = 1)$cpu()
  gradients_array <- as.array(gradients)
  
  grad_df <- data.frame(
    feature = colnames(X_test),
    gradient_importance = gradients_array
  ) %>%
    arrange(desc(gradient_importance))
  
  top_grad <- head(grad_df, 12)
  
  # Plots are not auto-printed inside an if/else block, so print explicitly
  p_grad <- ggplot(top_grad, aes(x = reorder(feature, gradient_importance), y = gradient_importance)) +
    geom_bar(stat = "identity", fill = "steelblue", color = "black") +
    coord_flip() +
    labs(x = "Feature", y = "Gradient-based Importance",
         title = "Gradient-based Feature Importance (Alternative to SHAP)") +
    theme_minimal() +
    theme(plot.title = element_text(size = 14, face = "bold"))
  print(p_grad)
  
  cat("\nGradient-based Feature Importance:\n")
  print(head(grad_df, 10))
}
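
To make the role of `nsim` concrete, here is a minimal base-R sketch of the kind of Monte Carlo Shapley estimate that sampling-based SHAP tools rely on. It is an illustration under stated assumptions, not fastshap's actual implementation: `shapley_mc`, `f_lin`, `X_bg`, and `x_obs` are hypothetical names, and the toy linear function is used only because its exact Shapley values are known in closed form.

```r
# Monte Carlo Shapley estimate for feature j of observation x (base R only).
# f: function taking a numeric vector and returning one number (hypothetical model)
# X: background matrix whose rows supply values for "absent" features
shapley_mc <- function(f, X, x, j, nsim = 1000) {
  p <- ncol(X)
  contrib <- numeric(nsim)
  for (m in seq_len(nsim)) {
    w <- X[sample(nrow(X), 1), ]       # random background row
    perm <- sample(p)                  # random feature ordering
    pos <- which(perm == j)
    from_x <- perm[seq_len(pos)]       # j and the features before it come from x
    with_j <- w
    with_j[from_x] <- x[from_x]
    without_j <- with_j
    without_j[j] <- w[j]               # same vector, but feature j from background
    contrib[m] <- f(with_j) - f(without_j)
  }
  mean(contrib)
}

# Toy check on a linear function, where the exact answer is
# phi_j = beta_j * (x_j - mean(X[, j]))
set.seed(1)
X_bg <- matrix(rnorm(600), ncol = 3)
beta <- c(2, -1, 0)
f_lin <- function(v) sum(beta * v)
x_obs <- c(1, 1, 1)

phi_hat <- sapply(1:3, function(j) shapley_mc(f_lin, X_bg, x_obs, j, nsim = 2000))
phi_exact <- beta * (x_obs - colMeans(X_bg))
print(round(cbind(phi_hat, phi_exact), 3))
```

Larger `nsim` values shrink the gap between the two columns at the cost of more model evaluations, which is why the SHAP cell above restricts itself to a small subset of rows.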


## 13. Summary and Conclusion

### Key Takeaways:

1. **DeepHit Advantages**:
   - Handles competing risks naturally
   - No proportional hazards assumption
   - Can capture non-linear relationships
   - Flexible architecture for complex survival data

2. **Model Performance**:
   - The model learns to predict cumulative incidence functions for each competing risk
   - Performance can be evaluated using C-index, Brier scores, and other metrics
   - Risk stratification helps identify high-risk patients

3. **Interpretability**:
   - Feature importance analysis reveals which covariates drive predictions
   - SHAP values provide local explanations for individual predictions
   - Risk stratification enables clinical decision-making

4. **Best Practices**:
   - Proper data preprocessing (standardization) is crucial
   - Hyperparameter tuning (learning rate, architecture, dropout) improves performance
   - Early stopping prevents overfitting
   - Cross-validation helps assess model generalizability
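
The cross-validation and standardization points can be sketched together in base R. This is a minimal sketch, assuming events are coded `0` for censored and `1..K` for competing events; `make_stratified_folds` and the toy data are hypothetical and not part of the tutorial's pipeline. Folds are balanced by event type, and scaling statistics are computed on the training rows only so that no information leaks from the held-out fold.

```r
# Assign each observation to one of k folds, balancing event types across folds.
# event: vector of event labels (assumed coding: 0 = censored, 1..K = competing events)
make_stratified_folds <- function(event, k = 5, seed = 42) {
  set.seed(seed)
  folds <- integer(length(event))
  for (lvl in unique(event)) {
    idx <- which(event == lvl)
    # Spread this event type evenly over the folds, then shuffle the assignment
    folds[idx] <- sample(rep_len(seq_len(k), length(idx)))
  }
  folds
}

# Toy labels and features for illustration only (not the tutorial's data)
event_toy <- rep(c(0, 1, 2), times = c(50, 30, 20))
X_toy <- matrix(rnorm(100 * 2, mean = 5, sd = 3), ncol = 2)

folds <- make_stratified_folds(event_toy, k = 5)
print(table(fold = folds, event = event_toy))

# Hold out fold 1; fit the scaler on the training rows only
train_idx <- which(folds != 1)
mu <- colMeans(X_toy[train_idx, ])
sdv <- apply(X_toy[train_idx, ], 2, sd)
X_train_s <- scale(X_toy[train_idx, ], center = mu, scale = sdv)
X_test_s <- scale(X_toy[-train_idx, ], center = mu, scale = sdv)
```

Repeating the last step for each fold in turn, and refitting the model each time, gives one held-out performance estimate per fold.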

### Limitations and Future Directions:

- DeepHit requires sufficient data for training deep networks
- Hyperparameter selection can be time-consuming
- Model interpretation, while possible, is more complex than for linear models
- Future work could explore attention mechanisms or transformer architectures

### Conclusion:

DeepHit represents a significant advancement in survival analysis, particularly for competing risks scenarios. Its ability to model complex relationships without restrictive assumptions makes it valuable for modern healthcare applications. However, careful validation and interpretation remain essential for clinical deployment.


## 14. Resources

### Papers and Publications:

1. **DeepHit Paper**:
   - Lee, C., Zame, W. R., Yoon, J., & van der Schaar, M. (2018). "DeepHit: A Deep Learning Approach to Survival Analysis with Competing Risks." AAAI 2018.
   - Link: http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit

### Code Repositories:

1. **Official DeepHit Repository**:
   - GitHub: https://github.com/chl8856/DeepHit
   - Contains original implementation and sample data

2. **PyCox Library (Python)**:
   - PyPI: https://pypi.org/project/pycox/
   - Provides PyTorch implementations of various survival models including DeepHit

3. **torch for R**:
   - CRAN: https://cran.r-project.org/package=torch
   - R package for deep learning built on libtorch (the C++ library behind PyTorch); no Python installation required

### Datasets:

1. **Synthetic Competing Risks Data**:
   - Sample data from DeepHit repository
   - Link: https://raw.githubusercontent.com/chl8856/DeepHit/refs/heads/master/sample%20data/SYNTHETIC/synthetic_comprisk.csv

### Related Tools and Libraries:

1. **survival (R)**: Survival analysis in R
   - https://cran.r-project.org/package=survival

2. **fastshap (R)**: Fast SHAP values for R
   - https://cran.r-project.org/package=fastshap

3. **torch (R)**: Deep learning framework for R
   - https://torch.mlverse.org/

### Additional Reading:

1. Competing Risks Survival Analysis: Theory and Application
2. Deep Learning for Survival Analysis: A Review
3. Neural Networks for Survival Analysis

---

**Tutorial Created**: 2024
**Author**: Survival Analysis Tutorial Series
**License**: Educational Use
