Skip to content

Commit

Permalink
Merge pull request #467 from tidymodels/fix-460
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt committed Dec 13, 2023
2 parents 6c98997 + 16742d7 commit 97adb7a
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ calculated with `roc_auc_survival()`.

* All warnings and errors have been updated to use the cli package for increased clarity and consistency. (#456, #457, #458)

* `brier_survival_integrated()` now throws an error if input data only includes 1 evalution time point. (#460)

# yardstick 1.2.0

## New Metrics
Expand Down
16 changes: 16 additions & 0 deletions R/surv-brier_survival_integrated.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ brier_survival_integrated_vec <- function(truth,
truth, estimate, case_weights
)

num_eval_times <- get_unique_eval_times(estimate)
if (num_eval_times < 2) {
cli::cli_abort(
"At least 2 evaluation time{?s} {?is/are} required. \\
Only {num_eval_times} unique time{?s} {?was/were} given."
)
}

if (na_rm) {
result <- yardstick_remove_missing(
truth, seq_along(estimate), case_weights
Expand All @@ -111,6 +119,14 @@ brier_survival_integrated_vec <- function(truth,
brier_survival_integrated_impl(truth, estimate, case_weights)
}

get_unique_eval_times <- function(x) {
res <- lapply(x, function(x) x$.eval_time)
res <- unlist(res)
res <- unique(res)
res <- length(res)
res
}

brier_survival_integrated_impl <- function(truth,
estimate,
case_weights) {
Expand Down
8 changes: 8 additions & 0 deletions tests/testthat/_snaps/surv-brier_survival_integrated.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# brier_survival_integrated calculations

Code
brier_survival_integrated(data = lung_surv, truth = surv_obj, .pred)
Condition
Error in `brier_survival_integrated()`:
! At least 2 evaluation time is required. Only 1 unique time was given.

15 changes: 15 additions & 0 deletions tests/testthat/test-surv-brier_survival_integrated.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@ test_that("brier_survival_integrated calculations", {
)
})

test_that("brier_survival_integrated calculations", {
lung_surv <- data_lung_surv()

lung_surv$.pred <- lapply(lung_surv$.pred, function(x) x[1, ])

expect_snapshot(
error = TRUE,
brier_survival_integrated(
data = lung_surv,
truth = surv_obj,
.pred
)
)
})

test_that("case weights", {
lung_surv <- data_lung_surv()
lung_surv$case_wts <- seq_len(nrow(lung_surv))
Expand Down

0 comments on commit 97adb7a

Please sign in to comment.