Skip to content

Commit

Permalink
Merge pull request #497 from tidymodels/roc-curve-fix
Browse files Browse the repository at this point in the history
Use weights correctly in roc_curve_survival()
  • Loading branch information
EmilHvitfeldt committed Mar 21, 2024
2 parents 5d2e17f + 074a79d commit 5b01a3a
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 51 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ jobs:
- {os: macos-latest, r: 'release'}

- {os: windows-latest, r: 'release'}
# Use 3.6 to trigger usage of RTools35
- {os: windows-latest, r: '3.6'}
# use 4.1 to check with rtools40's older compiler
- {os: windows-latest, r: '4.1'}

Expand All @@ -35,7 +33,6 @@ jobs:
- {os: ubuntu-latest, r: 'oldrel-1'}
- {os: ubuntu-latest, r: 'oldrel-2'}
- {os: ubuntu-latest, r: 'oldrel-3'}
- {os: ubuntu-latest, r: 'oldrel-4'}

env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# yardstick (development version)

* Bug was fixed in `roc_curve_survival()` where wrong weights were used. (#495, @asb2111).

# yardstick 1.3.0

## New Metrics
Expand Down
28 changes: 14 additions & 14 deletions R/surv-roc_curve_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -179,42 +179,42 @@ roc_curve_survival_impl <- function(truth,
}

roc_curve_survival_impl_one <- function(event_time, delta, data, case_weights) {
res <- dplyr::tibble(.threshold = sort(unique(c(-Inf, data$.pred_survival, Inf))))

res <- dplyr::tibble(.threshold = sort(unique(c(-Inf, data$.pred_survival, Inf)), decreasing = TRUE))
obs_time_le_time <- event_time <= data$.eval_time
obs_time_gt_time <- event_time > data$.eval_time
n <- nrow(data)
multiplier <- delta / (n * data$.weight_censored)

sensitivity_denom <- sum(obs_time_le_time * multiplier * case_weights, na.rm = TRUE)
specificity_denom <- sum(obs_time_gt_time * case_weights, na.rm = TRUE)


sensitivity_denom <- sum(obs_time_le_time * delta * data$.weight_censored * case_weights, na.rm = TRUE)
specificity_denom <- sum(obs_time_gt_time * data$.weight_censored * case_weights, na.rm = TRUE)

data_df <- data.frame(
le_time = obs_time_le_time,
ge_time = obs_time_gt_time,
multiplier = multiplier,
delta = delta,
weight_censored = data$.weight_censored,
case_weights = case_weights
)

data_split <- vec_split(data_df, data$.pred_survival)
data_split <- data_split$val[order(data_split$key)]

sensitivity <- vapply(
data_split,
function(x) sum(x$le_time * x$multiplier * x$case_weights, na.rm = TRUE),
function(x) sum(x$le_time * x$delta * x$weight_censored * x$case_weights, na.rm = TRUE),
FUN.VALUE = numeric(1)
)

sensitivity <- cumsum(sensitivity)
sensitivity <- sensitivity / sensitivity_denom
sensitivity <- dplyr::if_else(sensitivity > 1, 1, sensitivity)
sensitivity <- dplyr::if_else(sensitivity < 0, 0, sensitivity)
sensitivity <- c(0, sensitivity, 1)
res$sensitivity <- sensitivity

specificity <- vapply(
data_split,
function(x) sum(x$ge_time * x$case_weights, na.rm = TRUE),
function(x) sum(x$ge_time * x$weight_censored * x$case_weights, na.rm = TRUE),
FUN.VALUE = numeric(1)
)
specificity <- cumsum(specificity)
Expand Down
Binary file modified tests/testthat/data/ref_roc_auc_survival.rds
Binary file not shown.
Binary file modified tests/testthat/data/ref_roc_curve_survival.rds
Binary file not shown.
98 changes: 64 additions & 34 deletions tests/testthat/test-surv-roc_curve_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ test_that("roc_curve_survival works", {
exp_threshold <- tidyr::unnest(lung_surv, cols = .pred)
exp_threshold <- dplyr::filter(exp_threshold, .eval_time == eval_time)
exp_threshold <- exp_threshold$.pred_survival
exp_threshold <- sort(exp_threshold)
exp_threshold <- sort(exp_threshold, decreasing = TRUE)
exp_threshold <- unique(exp_threshold)
exp_threshold <- c(-Inf, exp_threshold, Inf)
exp_threshold <- c(Inf, exp_threshold, -Inf)
expect_identical(
result_tmp$.threshold,
exp_threshold
Expand Down Expand Up @@ -101,37 +101,51 @@ test_that("hand calculated equivalent", {
dplyr::filter(.eval_time == my_eval_time)

thresholds <- sort(lung_surv0$.pred_survival)
thresholds <- thresholds[c(1, 10, 100, 200, nrow(lung_surv0))]

# Sensitivity
calc_sensitivity <- function(threshold, data, eval_time) {
delta <- .extract_surv_status(data$surv_obj)
event_time <- .extract_surv_time(data$surv_obj)
res <- dplyr::tibble(.threshold = sort(unique(c(-Inf, data$.pred_survival, Inf)), decreasing = TRUE))

obs_time_le_time <- event_time <= data$.eval_time
obs_time_gt_time <- event_time > data$.eval_time
n <- nrow(data)
event_time <- yardstick:::.extract_surv_time(data$surv_obj)
delta <- yardstick:::.extract_surv_status(data$surv_obj)
obs_time_le_time <- event_time <= eval_time
prob_le_thresh <- data$.pred_survival <= threshold

multiplier <- delta / (n * data$.weight_censored)
numer <- sum(obs_time_le_time * prob_le_thresh * multiplier, na.rm = TRUE)
denom <- sum(obs_time_le_time * multiplier, na.rm = TRUE)
numer / denom

sensitivity_denom <- sum(obs_time_le_time * delta * data$.weight_censored, na.rm = TRUE)

data_df <- data.frame(
le_time = obs_time_le_time,
delta = delta,
weight_censored = data$.weight_censored
)

data_split <- vec_split(data_df, data$.pred_survival)
data_split <- data_split$val[order(data_split$key)]

sensitivity <- vapply(
data_split,
function(x) sum(x$le_time * x$delta * x$weight_censored, na.rm = TRUE),
FUN.VALUE = numeric(1)
)

sensitivity <- cumsum(sensitivity)
sensitivity <- sensitivity / sensitivity_denom
sensitivity <- dplyr::if_else(sensitivity > 1, 1, sensitivity)
sensitivity <- dplyr::if_else(sensitivity < 0, 0, sensitivity)
sensitivity <- c(0, sensitivity, 1)
sensitivity
}

exp_sens <- vapply(
thresholds,
calc_sensitivity,
data = lung_surv0,
eval_time = my_eval_time,
FUN.VALUE = numeric(1)
)
exp_sens <- calc_sensitivity(thresholds, lung_surv0, my_eval_time)

yardstick_res <- lung_surv %>%
dplyr::slice(-14) %>%
roc_curve_survival(
truth = surv_obj,
.pred
) %>%
dplyr::filter(.threshold %in% thresholds)
dplyr::filter(.eval_time == my_eval_time)

expect_equal(
yardstick_res$sensitivity,
Expand All @@ -141,29 +155,45 @@ test_that("hand calculated equivalent", {
# specificity
calc_specificity <- function(threshold, data, eval_time) {
event_time <- yardstick:::.extract_surv_time(data$surv_obj)
delta <- yardstick:::.extract_surv_status(data$surv_obj)
obs_time_gt_time <- event_time > eval_time
prob_le_thresh <- data$.pred_survival > threshold
numer <- sum(obs_time_gt_time * prob_le_thresh, na.rm = TRUE)
denom <- sum(obs_time_gt_time, na.rm = TRUE)
numer / denom

res <- dplyr::tibble(.threshold = sort(unique(c(-Inf, data$.pred_survival, Inf)), decreasing = TRUE))

obs_time_gt_time <- event_time > data$.eval_time
n <- nrow(data)

specificity_denom <- sum(obs_time_gt_time * data$.weight_censored, na.rm = TRUE)

data_df <- data.frame(
ge_time = obs_time_gt_time,
weight_censored = data$.weight_censored
)

data_split <- vec_split(data_df, data$.pred_survival)
data_split <- data_split$val[order(data_split$key)]

specificity <- vapply(
data_split,
function(x) sum(x$ge_time * x$weight_censored, na.rm = TRUE),
FUN.VALUE = numeric(1)
)
specificity <- cumsum(specificity)
specificity <- specificity / specificity_denom
specificity <- dplyr::if_else(specificity > 1, 1, specificity)
specificity <- dplyr::if_else(specificity < 0, 0, specificity)
specificity <- c(0, specificity, 1)
specificity <- 1 - specificity
specificity
}

exp_spec <- vapply(
thresholds,
calc_specificity,
data = lung_surv0,
eval_time = my_eval_time,
FUN.VALUE = numeric(1)
)
exp_spec <- calc_specificity(thresholds, lung_surv0, my_eval_time)

yardstick_res <- lung_surv %>%
dplyr::slice(-14) %>%
roc_curve_survival(
truth = surv_obj,
.pred
) %>%
dplyr::filter(.threshold %in% thresholds)
dplyr::filter(.eval_time == my_eval_time)

expect_equal(
yardstick_res$specificity,
Expand Down

0 comments on commit 5b01a3a

Please sign in to comment.