diff --git a/NEWS.md b/NEWS.md
index b021e0e..1584892 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,5 +1,9 @@
# probably (development version)
+* Updated unit tests for new ggplot2 release (#180).
+
+* Better error message when using one of the `cal_validate_*()` functions with a validation set (#182).
+
# probably 1.1.0
* Significant refactoring of the code underlying the calibration functions. The user-facing APIs have not changed.
diff --git a/R/cal-validate.R b/R/cal-validate.R
index a7afe77..77fc804 100644
--- a/R/cal-validate.R
+++ b/R/cal-validate.R
@@ -68,6 +68,9 @@ cal_validate_logistic.resample_results <-
metrics = NULL,
save_pred = FALSE,
...) {
+ cl <- match.call()
+ validation_check(.data, cl)
+
if (!is.null(truth)) {
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
}
@@ -151,6 +154,9 @@ cal_validate_isotonic.resample_results <-
metrics = NULL,
save_pred = FALSE,
...) {
+ cl <- match.call()
+ validation_check(.data, cl)
+
if (!is.null(truth)) {
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
}
@@ -236,6 +242,9 @@ cal_validate_isotonic_boot.resample_results <-
metrics = NULL,
save_pred = FALSE,
...) {
+ cl <- match.call()
+ validation_check(.data, cl)
+
if (!is.null(truth)) {
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
}
@@ -321,6 +330,10 @@ cal_validate_beta.resample_results <-
metrics = NULL,
save_pred = FALSE,
...) {
+ # Abort early with an informative error when `.data` holds a validation set.
+ cl <- match.call()
+ validation_check(.data, cl)
+
if (!is.null(truth)) {
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
}
@@ -401,6 +416,9 @@ cal_validate_multinomial.resample_results <-
metrics = NULL,
save_pred = FALSE,
...) {
+ cl <- match.call()
+ validation_check(.data, cl)
+
if (!is.null(truth)) {
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
}
@@ -515,6 +533,7 @@ cal_validate <- function(rset,
predictions_out <- pull_pred(rset, analysis = FALSE)
est_fn_name <- paste0("cal_estimate_", cal_function)
+
est_cl <-
rlang::call2(
est_fn_name,
@@ -560,18 +579,31 @@ cal_validate <- function(rset,
}
+# Extract the analysis (`analysis = TRUE`) or assessment (`analysis = FALSE`)
+# rows from the splits of `x` as data frames.
pull_pred <- function(x, analysis = TRUE) {
- has_dot_row <- any(names(x$splits[[1]]$data) == ".row")
if (analysis) {
what <- "analysis"
} else {
what <- "assessment"
}
- preds <- purrr::map(x$splits, as.data.frame, data = what)
- if (!has_dot_row) {
- rows <- purrr::map(x$splits, ~ dplyr::tibble(.row = as.integer(.x, data = what)))
- preds <- purrr::map2(preds, rows, ~ dplyr::bind_cols(.x, .y))
- }
+ if (inherits(x$splits[[1]], "val_split")) {
+ # `data` must be passed by name: the second positional argument of the
+ # `as.data.frame()` generic is `row.names`, so a bare `what` never selects
+ # the requested portion of the split.
+ # NOTE(review): this branch yields one data frame while the branch below
+ # yields a list with one element per split -- confirm callers accept both.
+ preds <- as.data.frame(x$splits[[1]], data = what)
+ } else {
+ has_dot_row <- any(names(x$splits[[1]]$data) == ".row")
+
+ preds <- purrr::map(x$splits, as.data.frame, data = what)
+ # Append the row indices as `.row` when the split data does not carry them.
+ if (!has_dot_row) {
+ rows <- purrr::map(x$splits, ~ dplyr::tibble(.row = as.integer(.x, data = what)))
+ preds <- purrr::map2(preds, rows, ~ dplyr::bind_cols(.x, .y))
+ }
+ }
preds
}
@@ -655,6 +679,9 @@ cal_validate_linear.resample_results <-
metrics = NULL,
save_pred = FALSE,
...) {
+ cl <- match.call()
+ validation_check(.data, cl)
+
if (!is.null(truth)) {
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
}
@@ -748,6 +775,9 @@ cal_validate_none.resample_results <-
metrics = NULL,
save_pred = FALSE,
...) {
+ cl <- match.call()
+ validation_check(.data, cl)
+
if (!is.null(truth)) {
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
}
@@ -803,8 +833,14 @@ convert_resamples <- function(x) {
predictions <-
tune::collect_predictions(x, summarize = TRUE) |>
dplyr::arrange(.row)
+
+ # Collected predictions may not cover every row of the original data, so pad
+ # them out so that each original `.row` index is represented
+ data_ind <- dplyr::tibble(.row = seq_len(nrow(x$splits[[1]]$data)))
+ all_data <- dplyr::full_join(data_ind, predictions, by = ".row")
+
for (i in seq_along(x$splits)) {
- x$splits[[i]]$data <- predictions
+ x$splits[[i]]$data <- all_data
}
class(x) <- c("rset", "tbl_df", "tbl", "data.frame")
x
@@ -891,3 +927,38 @@ collect_predictions.cal_rset <- function(x, summarize = TRUE, ...) {
}
res
}
+
+# Abort with an informative error when `x` holds a validation-set split.
+#
+# `x` is an object with a `splits` list; only the first split's class is
+# inspected. `cl` is the matched call of the calling `cal_validate_*()` method
+# (used only to name that function in the message) and `call` is the
+# environment reported as the source of the error. Returns `NULL` invisibly
+# when `x` passes the check.
+validation_check <- function(x, cl = NULL, call = rlang::caller_env()) {
+ if (!inherits(x$splits[[1]], "val_split")) {
+ return(invisible(NULL))
+ }
+
+ # Only derive the function name once we know we are going to error.
+ # `rlang::call_name()` also resolves namespaced calls such as
+ # `probably::cal_validate_beta()`, where `as.character(cl[[1]])` would yield
+ # "::"; it returns `NULL` for calls whose head is not a plain name.
+ fn <- NULL
+ if (is.call(cl)) {
+ fn <- rlang::call_name(cl)
+ }
+ if (is.null(fn)) {
+ # No usable call was supplied; fall back to a generic name.
+ fn <- "cal_validate_*"
+ }
+ # Drop any S3 class suffix (e.g. ".resample_results") from a method name.
+ fn <- strsplit(fn, ".", fixed = TRUE)[[1]][1]
+
+ cli::cli_abort(
+ "For validation sets, please make a resampling object from the predictions
+ prior to calling {.fn {fn}}",
+ call = call
+ )
+}
+
diff --git a/man/int_conformal_full.Rd b/man/int_conformal_full.Rd
index 31680ec..e950fe7 100644
--- a/man/int_conformal_full.Rd
+++ b/man/int_conformal_full.Rd
@@ -111,28 +111,7 @@ intervals for the five new samples in parallel:
predict(lm_conform, new_dat)
}\if{html}{\out{}}
-\if{html}{\out{
}}\preformatted{## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may
-## not be available when loading
-## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may
-## not be available when loading
-## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may
-## not be available when loading
-## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may
-## not be available when loading
-## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may
-## not be available when loading
-## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may
-## not be available when loading
-## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may
-## not be available when loading
-## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may
-## not be available when loading
-## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may
-## not be available when loading
-## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may
-## not be available when loading
-
-## # A tibble: 5 x 2
+\if{html}{\out{
}}\preformatted{## # A tibble: 5 x 2
## .pred_lower .pred_upper
##
## 1 -17.9 59.6
diff --git a/tests/testthat/_snaps/cal-validate.md b/tests/testthat/_snaps/cal-validate.md
index 3330607..6e223de 100644
--- a/tests/testthat/_snaps/cal-validate.md
+++ b/tests/testthat/_snaps/cal-validate.md
@@ -44,3 +44,59 @@
This function can only be used with an object or the results of `tune::fit_resamples()` with a .predictions column.
i Not an object.
+# validation sets fail with better message
+
+ Code
+ cal_validate_beta(mt_res)
+ Condition
+ Error in `cal_validate_beta()`:
+ ! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_beta()`
+
+---
+
+ Code
+ cal_validate_isotonic(mt_res)
+ Condition
+ Error in `cal_validate_isotonic()`:
+ ! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_isotonic()`
+
+---
+
+ Code
+ cal_validate_isotonic_boot(mt_res)
+ Condition
+ Error in `cal_validate_isotonic_boot()`:
+ ! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_isotonic_boot()`
+
+---
+
+ Code
+ cal_validate_linear(mt_res)
+ Condition
+ Error in `cal_validate_linear()`:
+ ! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_linear()`
+
+---
+
+ Code
+ cal_validate_logistic(mt_res)
+ Condition
+ Error in `cal_validate_logistic()`:
+ ! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_logistic()`
+
+---
+
+ Code
+ cal_validate_multinomial(mt_res)
+ Condition
+ Error in `cal_validate_multinomial()`:
+ ! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_multinomial()`
+
+---
+
+ Code
+ cal_validate_none(mt_res)
+ Condition
+ Error in `cal_validate_none()`:
+ ! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_none()`
+
diff --git a/tests/testthat/test-cal-validate-multiclass.R b/tests/testthat/test-cal-validate-multiclass.R
index 9daed24..bc8c81b 100644
--- a/tests/testthat/test-cal-validate-multiclass.R
+++ b/tests/testthat/test-cal-validate-multiclass.R
@@ -50,8 +50,8 @@ test_that("Isotonic validation with `fit_resamples` - Multiclass", {
)
skip_if_not_installed("tune", "1.2.0")
expect_equal(
- names(val_with_pred$.predictions_cal[[1]]),
- c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class")
+ sort(names(val_with_pred$.predictions_cal[[1]])),
+ sort(c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
diff --git a/tests/testthat/test-cal-validate.R b/tests/testthat/test-cal-validate.R
index 005994f..31ea5a3 100644
--- a/tests/testthat/test-cal-validate.R
+++ b/tests/testthat/test-cal-validate.R
@@ -348,8 +348,8 @@ test_that("Logistic validation with `fit_resamples`", {
skip_if_not_installed("tune", "1.2.0")
expect_equal(
- names(val_with_pred$.predictions_cal[[1]]),
- c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")
+ sort(names(val_with_pred$.predictions_cal[[1]])),
+ sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -380,8 +380,8 @@ test_that("Isotonic classification validation with `fit_resamples`", {
skip_if_not_installed("tune", "1.2.0")
expect_equal(
- names(val_with_pred$.predictions_cal[[1]]),
- c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")
+ sort(names(val_with_pred$.predictions_cal[[1]])),
+ sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -413,8 +413,8 @@ test_that("Bootstrapped isotonic classification validation with `fit_resamples`"
skip_if_not_installed("tune", "1.2.0")
expect_equal(
- names(val_with_pred$.predictions_cal[[1]]),
- c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")
+ sort(names(val_with_pred$.predictions_cal[[1]])),
+ sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -446,8 +446,8 @@ test_that("Beta calibration validation with `fit_resamples`", {
skip_if_not_installed("tune", "1.2.0")
expect_equal(
- names(val_with_pred$.predictions_cal[[1]]),
- c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")
+ sort(names(val_with_pred$.predictions_cal[[1]])),
+ sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -481,8 +481,8 @@ test_that("Multinomial calibration validation with `fit_resamples`", {
skip_if_not_installed("tune", "1.2.0")
expect_equal(
- names(val_with_pred$.predictions_cal[[1]]),
- c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class")
+ sort(names(val_with_pred$.predictions_cal[[1]])),
+ sort(c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -513,8 +513,8 @@ test_that("Validation without calibration with `fit_resamples`", {
skip_if_not_installed("tune", "1.2.0")
expect_equal(
- names(val_with_pred$.predictions_cal[[1]]),
- c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")
+ sort(names(val_with_pred$.predictions_cal[[1]])),
+ sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -548,8 +548,8 @@ test_that("Linear validation with `fit_resamples`", {
skip_if_not_installed("tune", "1.2.0")
expect_equal(
- names(val_with_pred$.predictions_cal[[1]]),
- c(".pred", ".row", "outcome", ".config")
+ sort(names(val_with_pred$.predictions_cal[[1]])),
+ sort(c(".pred", ".row", "outcome", ".config"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -621,8 +621,8 @@ test_that("Isotonic regression validation with `fit_resamples`", {
skip_if_not_installed("tune", "1.2.0")
expect_equal(
- names(val_with_pred$.predictions_cal[[1]]),
- c(".pred", ".row", "outcome", ".config")
+ sort(names(val_with_pred$.predictions_cal[[1]])),
+ sort(c(".pred", ".row", "outcome", ".config"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -657,8 +657,8 @@ test_that("Isotonic bootstrapped regression validation with `fit_resamples`", {
skip_if_not_installed("tune", "1.2.0")
expect_equal(
- names(val_with_pred$.predictions_cal[[1]]),
- c(".pred", ".row", "outcome", ".config")
+ sort(names(val_with_pred$.predictions_cal[[1]])),
+ sort(c(".pred", ".row", "outcome", ".config"))
)
expect_equal(
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -670,7 +670,6 @@ test_that("Isotonic bootstrapped regression validation with `fit_resamples`", {
# ------------------------------------------------------------------------------
-
test_that("validation functions error with tune_results input", {
skip_if_not_installed("modeldata")
skip_if_not_installed("nnet")
@@ -698,3 +697,33 @@ test_that("validation functions error with tune_results input", {
cal_validate_none(testthat_cal_binary())
)
})
+
+# ------------------------------------------------------------------------------
+
+test_that("validation sets fail with better message", {
+ # Namespace-qualify the tune helpers instead of attaching tune inside the test.
+ skip_if_not_installed("tune")
+ skip_if_not_installed("parsnip")
+ # `initial_validation_split()` and `validation_set()` need rsample >= 1.2.0.
+ skip_if_not_installed("rsample", "1.2.0")
+
+ set.seed(1)
+ mt_split <- rsample::initial_validation_split(mtcars)
+ mt_rset <- rsample::validation_set(mt_split)
+ mt_res <-
+ parsnip::linear_reg() |>
+ tune::fit_resamples(
+ mpg ~ .,
+ resamples = mt_rset,
+ control = tune::control_resamples(save_pred = TRUE)
+ )
+
+ # Every `cal_validate_*()` method should reject validation-set results.
+ expect_snapshot(cal_validate_beta(mt_res), error = TRUE)
+ expect_snapshot(cal_validate_isotonic(mt_res), error = TRUE)
+ expect_snapshot(cal_validate_isotonic_boot(mt_res), error = TRUE)
+ expect_snapshot(cal_validate_linear(mt_res), error = TRUE)
+ expect_snapshot(cal_validate_logistic(mt_res), error = TRUE)
+ expect_snapshot(cal_validate_multinomial(mt_res), error = TRUE)
+ expect_snapshot(cal_validate_none(mt_res), error = TRUE)
+})