Skip to content

Commit

Permalink
Merge pull request #471 from tidymodels/validate-all-the-same-eval_time
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt committed Dec 16, 2023
2 parents 2285618 + dbf820c commit 2f75b6c
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 19 deletions.
7 changes: 2 additions & 5 deletions R/surv-brier_survival_integrated.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,8 @@ brier_survival_integrated_vec <- function(truth,
}

get_unique_eval_times <- function(x) {
res <- lapply(x, function(x) x$.eval_time)
res <- unlist(res)
res <- unique(res)
res <- length(res)
res
# Since validate_surv_truth_list_estimate() makes sure they are all the same
length(x[[1]]$.eval_time)
}

brier_survival_integrated_impl <- function(truth,
Expand Down
36 changes: 26 additions & 10 deletions R/validation.R
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,29 @@ validate_surv_truth_list_estimate <- function(truth,
)
}

all_eval_times_list <- lapply(estimate, function(x) x$.eval_time)
all_eval_times <- unlist(all_eval_times_list)
eval_time_cols <- lapply(estimate, function(x) x$.eval_time)

if (length(unique(eval_time_cols)) > 1) {
offenders <- vapply(
eval_time_cols,
function(x) !identical(x, eval_time_cols[[1]]),
logical(1)
)
offenders <- which(offenders)

if (any(is.na(all_eval_times))) {
cli::cli_abort(
c(
x = "All the {.field .eval_time} columns of {.arg estimate} must be \\
identical.",
i = "The folllowing index differed from the first: {.val {offenders}}."
),
call = call
)
}

eval_time <- eval_time_cols[[1]]

if (any(is.na(eval_time))) {
cli::cli_abort(
c(
x = "Missing values in {.field .eval_time} are not allowed."
Expand All @@ -216,8 +235,8 @@ validate_surv_truth_list_estimate <- function(truth,
)
}

if (any(all_eval_times < 0)) {
offenders <- unique(all_eval_times[all_eval_times < 0])
if (any(eval_time < 0)) {
offenders <- unique(eval_time[eval_time < 0])

cli::cli_abort(
c(
Expand All @@ -228,7 +247,7 @@ validate_surv_truth_list_estimate <- function(truth,
)
}

if (any(is.infinite(all_eval_times))) {
if (any(is.infinite(eval_time))) {
cli::cli_abort(
c(
x = "Infinite values of {.field .eval_time} are not allowed."
Expand All @@ -237,10 +256,7 @@ validate_surv_truth_list_estimate <- function(truth,
)
}

any_duplicates <- any(
vapply(all_eval_times_list, function(x) any(table(x) > 1), logical(1))
)
if (any_duplicates) {
if (any(duplicated(eval_time))) {
cli::cli_abort(
c(
x = "Duplicate values of {.field .eval_time} are not allowed."
Expand Down
10 changes: 10 additions & 0 deletions tests/testthat/_snaps/validation.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,16 @@
Error:
! `estimate` should be a list, not a a double vector.

---

Code
validate_surv_truth_list_estimate(lung_surv_not_all_same$surv_obj,
lung_surv_not_all_same$.pred)
Condition
Error:
x All the .eval_time columns of `estimate` must be identical.
i The folllowing index differed from the first: 5, 10, and 14.

---

Code
Expand Down
21 changes: 17 additions & 4 deletions tests/testthat/test-validation.R
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,21 @@ test_that("validate_surv_truth_list_estimate errors as expected", {
)
)

lung_surv_neg <- lung_surv
lung_surv_not_all_same <- lung_surv
lung_surv_not_all_same$.pred[[5]]$.eval_time[1] <- 350
lung_surv_not_all_same$.pred[[10]]$.eval_time[1] <- 350
lung_surv_not_all_same$.pred[[14]]$.eval_time[1] <- 350
expect_snapshot(
error = TRUE,
validate_surv_truth_list_estimate(
lung_surv_not_all_same$surv_obj,
lung_surv_not_all_same$.pred
)
)

lung_surv_neg <- lung_surv[1, ]
lung_surv_neg$.pred[[1]]$.eval_time[1] <- -100
rep()
expect_snapshot(
error = TRUE,
validate_surv_truth_list_estimate(
Expand All @@ -398,7 +411,7 @@ test_that("validate_surv_truth_list_estimate errors as expected", {
)
)

lung_surv_na <- lung_surv
lung_surv_na <- lung_surv[1, ]
lung_surv_na$.pred[[1]]$.eval_time[1] <- NA
expect_snapshot(
error = TRUE,
Expand All @@ -408,7 +421,7 @@ test_that("validate_surv_truth_list_estimate errors as expected", {
)
)

lung_surv_inf <- lung_surv
lung_surv_inf <- lung_surv[1, ]
lung_surv_inf$.pred[[1]]$.eval_time[1] <- Inf
expect_snapshot(
error = TRUE,
Expand All @@ -418,7 +431,7 @@ test_that("validate_surv_truth_list_estimate errors as expected", {
)
)

lung_surv_duplicate <- lung_surv
lung_surv_duplicate <- lung_surv[1, ]
lung_surv_duplicate$.pred[[1]]$.eval_time[1] <- 200
expect_snapshot(
error = TRUE,
Expand Down

0 comments on commit 2f75b6c

Please sign in to comment.