4 changes: 4 additions & 0 deletions NEWS.md
@@ -1,5 +1,9 @@
# probably (development version)

* Updated unit tests for new ggplot2 release (#180).

* Better error message when using one of the `cal_validate_*()` functions with a validation set (#182).

# probably 1.1.0

* Significant refactoring of the code underlying the calibration functions. The user-facing APIs have not changed.
65 changes: 58 additions & 7 deletions R/cal-validate.R
@@ -68,6 +68,9 @@ cal_validate_logistic.resample_results <-
metrics = NULL,
save_pred = FALSE,
...) {
cl <- match.call()
validation_check(.data, cl)

if (!is.null(truth)) {
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
}
@@ -151,6 +154,9 @@ cal_validate_isotonic.resample_results <-
metrics = NULL,
save_pred = FALSE,
...) {
cl <- match.call()
validation_check(.data, cl)

if (!is.null(truth)) {
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
}
@@ -236,6 +242,9 @@ cal_validate_isotonic_boot.resample_results <-
metrics = NULL,
save_pred = FALSE,
...) {
cl <- match.call()
validation_check(.data, cl)

if (!is.null(truth)) {
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
}
@@ -321,6 +330,12 @@ cal_validate_beta.resample_results <-
metrics = NULL,
save_pred = FALSE,
...) {
cl <- match.call()
validation_check(.data, cl)

cl <- match.call()
validation_check(.data, cl)

if (!is.null(truth)) {
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
}
@@ -401,6 +416,9 @@ cal_validate_multinomial.resample_results <-
metrics = NULL,
save_pred = FALSE,
...) {
cl <- match.call()
validation_check(.data, cl)

if (!is.null(truth)) {
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
}
@@ -515,6 +533,7 @@ cal_validate <- function(rset,
predictions_out <- pull_pred(rset, analysis = FALSE)

est_fn_name <- paste0("cal_estimate_", cal_function)

est_cl <-
rlang::call2(
est_fn_name,
@@ -560,18 +579,23 @@
}

pull_pred <- function(x, analysis = TRUE) {
has_dot_row <- any(names(x$splits[[1]]$data) == ".row")
if (analysis) {
what <- "analysis"
} else {
what <- "assessment"
}
preds <- purrr::map(x$splits, as.data.frame, data = what)
if (!has_dot_row) {
rows <- purrr::map(x$splits, ~ dplyr::tibble(.row = as.integer(.x, data = what)))
preds <- purrr::map2(preds, rows, ~ dplyr::bind_cols(.x, .y))
}

if (inherits(x$splits[[1]], "val_split")) {
preds <- as.data.frame(x$splits[[1]], what)
} else {
has_dot_row <- any(names(x$splits[[1]]$data) == ".row")

preds <- purrr::map(x$splits, as.data.frame, data = what)
if (!has_dot_row) {
rows <- purrr::map(x$splits, ~ dplyr::tibble(.row = as.integer(.x, data = what)))
preds <- purrr::map2(preds, rows, ~ dplyr::bind_cols(.x, .y))
}
}
preds
}

@@ -655,6 +679,9 @@ cal_validate_linear.resample_results <-
metrics = NULL,
save_pred = FALSE,
...) {
cl <- match.call()
validation_check(.data, cl)

if (!is.null(truth)) {
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
}
@@ -748,6 +775,9 @@ cal_validate_none.resample_results <-
metrics = NULL,
save_pred = FALSE,
...) {
cl <- match.call()
validation_check(.data, cl)

if (!is.null(truth)) {
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
}
@@ -803,8 +833,14 @@ convert_resamples <- function(x) {
predictions <-
tune::collect_predictions(x, summarize = TRUE) |>
dplyr::arrange(.row)

# Not all prediction sets, when collected, will match the size of the original
# data so buff out the data set
data_ind <- dplyr::tibble(.row = seq_len(nrow(x$splits[[1]]$data)))
all_data <- dplyr::full_join(data_ind, predictions, by = ".row")

for (i in seq_along(x$splits)) {
x$splits[[i]]$data <- predictions
x$splits[[i]]$data <- all_data
}
class(x) <- c("rset", "tbl_df", "tbl", "data.frame")
x
@@ -891,3 +927,18 @@ collect_predictions.cal_rset <- function(x, summarize = TRUE, ...) {
}
res
}

validation_check <- function(x, cl = NULL, call = rlang::caller_env()) {
fn <- as.character(cl[[1]])
fn <- strsplit(fn, "\\.")[[1]][1]

if (inherits(x$splits[[1]], "val_split")) {
cli::cli_abort(
"For validation sets, please make a resampling object from the predictions
prior to calling {.fn {fn}}",
call = call
)
}
invisible(NULL)
}

23 changes: 1 addition & 22 deletions man/int_conformal_full.Rd

Some generated files are not rendered by default.

56 changes: 56 additions & 0 deletions tests/testthat/_snaps/cal-validate.md
@@ -44,3 +44,59 @@
This function can only be used with an <rset> object or the results of `tune::fit_resamples()` with a .predictions column.
i Not an <tune_results> object.

# validation sets fail with better message

Code
cal_validate_beta(mt_res)
Condition
Error in `cal_validate_beta()`:
! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_beta()`

---

Code
cal_validate_isotonic(mt_res)
Condition
Error in `cal_validate_isotonic()`:
! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_isotonic()`

---

Code
cal_validate_isotonic_boot(mt_res)
Condition
Error in `cal_validate_isotonic_boot()`:
! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_isotonic_boot()`

---

Code
cal_validate_linear(mt_res)
Condition
Error in `cal_validate_linear()`:
! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_linear()`

---

Code
cal_validate_logistic(mt_res)
Condition
Error in `cal_validate_logistic()`:
! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_logistic()`

---

Code
cal_validate_multinomial(mt_res)
Condition
Error in `cal_validate_multinomial()`:
! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_multinomial()`

---

Code
cal_validate_none(mt_res)
Condition
Error in `cal_validate_none()`:
! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_none()`

4 changes: 2 additions & 2 deletions tests/testthat/test-cal-validate-multiclass.R
@@ -50,8 +50,8 @@ test_that("Isotonic validation with `fit_resamples` - Multiclass", {
)
skip_if_not_installed("tune", "1.2.0")
expect_equal(
names(val_with_pred$.predictions_cal[[1]]),
c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class")
sort(names(val_with_pred$.predictions_cal[[1]])),
sort(c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
61 changes: 42 additions & 19 deletions tests/testthat/test-cal-validate.R
@@ -348,8 +348,8 @@ test_that("Logistic validation with `fit_resamples`", {

skip_if_not_installed("tune", "1.2.0")
expect_equal(
names(val_with_pred$.predictions_cal[[1]]),
c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")
sort(names(val_with_pred$.predictions_cal[[1]])),
sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -380,8 +380,8 @@ test_that("Isotonic classification validation with `fit_resamples`", {

skip_if_not_installed("tune", "1.2.0")
expect_equal(
names(val_with_pred$.predictions_cal[[1]]),
c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")
sort(names(val_with_pred$.predictions_cal[[1]])),
sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -413,8 +413,8 @@ test_that("Bootstrapped isotonic classification validation with `fit_resamples`"

skip_if_not_installed("tune", "1.2.0")
expect_equal(
names(val_with_pred$.predictions_cal[[1]]),
c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")
sort(names(val_with_pred$.predictions_cal[[1]])),
sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -446,8 +446,8 @@ test_that("Beta calibration validation with `fit_resamples`", {

skip_if_not_installed("tune", "1.2.0")
expect_equal(
names(val_with_pred$.predictions_cal[[1]]),
c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")
sort(names(val_with_pred$.predictions_cal[[1]])),
sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -481,8 +481,8 @@ test_that("Multinomial calibration validation with `fit_resamples`", {

skip_if_not_installed("tune", "1.2.0")
expect_equal(
names(val_with_pred$.predictions_cal[[1]]),
c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class")
sort(names(val_with_pred$.predictions_cal[[1]])),
sort(c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -513,8 +513,8 @@ test_that("Validation without calibration with `fit_resamples`", {

skip_if_not_installed("tune", "1.2.0")
expect_equal(
names(val_with_pred$.predictions_cal[[1]]),
c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")
sort(names(val_with_pred$.predictions_cal[[1]])),
sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -548,8 +548,8 @@ test_that("Linear validation with `fit_resamples`", {

skip_if_not_installed("tune", "1.2.0")
expect_equal(
names(val_with_pred$.predictions_cal[[1]]),
c(".pred", ".row", "outcome", ".config")
sort(names(val_with_pred$.predictions_cal[[1]])),
sort(c(".pred", ".row", "outcome", ".config"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -621,8 +621,8 @@ test_that("Isotonic regression validation with `fit_resamples`", {

skip_if_not_installed("tune", "1.2.0")
expect_equal(
names(val_with_pred$.predictions_cal[[1]]),
c(".pred", ".row", "outcome", ".config")
sort(names(val_with_pred$.predictions_cal[[1]])),
sort(c(".pred", ".row", "outcome", ".config"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -657,8 +657,8 @@ test_that("Isotonic bootstrapped regression validation with `fit_resamples`", {

skip_if_not_installed("tune", "1.2.0")
expect_equal(
names(val_with_pred$.predictions_cal[[1]]),
c(".pred", ".row", "outcome", ".config")
sort(names(val_with_pred$.predictions_cal[[1]])),
sort(c(".pred", ".row", "outcome", ".config"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -670,7 +670,6 @@ test_that("Isotonic bootstrapped regression validation with `fit_resamples`", {

# ------------------------------------------------------------------------------


test_that("validation functions error with tune_results input", {
skip_if_not_installed("modeldata")
skip_if_not_installed("nnet")
@@ -698,3 +697,27 @@ test_that("validation functions error with tune_results input", {
cal_validate_none(testthat_cal_binary())
)
})

# ------------------------------------------------------------------------------

test_that("validation sets fail with better message", {
library(tune)
set.seed(1)
mt_split <- rsample::initial_validation_split(mtcars)
mt_rset <- rsample::validation_set(mt_split)
mt_res <-
parsnip::linear_reg() |>
fit_resamples(
mpg ~ .,
resamples = mt_rset,
control = control_resamples(save_pred = TRUE)
)

expect_snapshot(cal_validate_beta(mt_res), error = TRUE)
expect_snapshot(cal_validate_isotonic(mt_res), error = TRUE)
expect_snapshot(cal_validate_isotonic_boot(mt_res), error = TRUE)
expect_snapshot(cal_validate_linear(mt_res), error = TRUE)
expect_snapshot(cal_validate_logistic(mt_res), error = TRUE)
expect_snapshot(cal_validate_multinomial(mt_res), error = TRUE)
expect_snapshot(cal_validate_none(mt_res), error = TRUE)
})