From 1531e035a30968f2a052f55a7ba7fec9aaa38bff Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 30 Jul 2025 09:41:29 -0400 Subject: [PATCH] changes for #182 --- NEWS.md | 4 ++ R/cal-validate.R | 65 +++++++++++++++++-- man/int_conformal_full.Rd | 23 +------ tests/testthat/_snaps/cal-validate.md | 56 ++++++++++++++++ tests/testthat/test-cal-validate-multiclass.R | 4 +- tests/testthat/test-cal-validate.R | 61 +++++++++++------ 6 files changed, 163 insertions(+), 50 deletions(-) diff --git a/NEWS.md b/NEWS.md index b021e0e0..15848925 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,9 @@ # probably (development version) +* Updated unit tests for new ggplot2 release (#180). + +* Better error message when using one of the `cal_validate_*()` functions with a validation set (#182). + # probably 1.1.0 * Significant refactoring of the code underlying the calibration functions. The user-facing APIs have not changed. diff --git a/R/cal-validate.R b/R/cal-validate.R index a7afe774..77fc804b 100644 --- a/R/cal-validate.R +++ b/R/cal-validate.R @@ -68,6 +68,9 @@ cal_validate_logistic.resample_results <- metrics = NULL, save_pred = FALSE, ...) { + cl <- match.call() + validation_check(.data, cl) + if (!is.null(truth)) { cli::cli_warn("{.arg truth} is automatically set when this type of object is used.") } @@ -151,6 +154,9 @@ cal_validate_isotonic.resample_results <- metrics = NULL, save_pred = FALSE, ...) { + cl <- match.call() + validation_check(.data, cl) + if (!is.null(truth)) { cli::cli_warn("{.arg truth} is automatically set when this type of object is used.") } @@ -236,6 +242,9 @@ cal_validate_isotonic_boot.resample_results <- metrics = NULL, save_pred = FALSE, ...) { + cl <- match.call() + validation_check(.data, cl) + if (!is.null(truth)) { cli::cli_warn("{.arg truth} is automatically set when this type of object is used.") } @@ -321,6 +330,12 @@ cal_validate_beta.resample_results <- metrics = NULL, save_pred = FALSE, ...) { + cl <- match.call() + validation_check(.data, cl) + + cl <- match.call() + validation_check(.data, cl) + if (!is.null(truth)) { cli::cli_warn("{.arg truth} is automatically set when this type of object is used.") } @@ -401,6 +416,9 @@ cal_validate_multinomial.resample_results <- metrics = NULL, save_pred = FALSE, ...) { + cl <- match.call() + validation_check(.data, cl) + if (!is.null(truth)) { cli::cli_warn("{.arg truth} is automatically set when this type of object is used.") } @@ -515,6 +533,7 @@ cal_validate <- function(rset, predictions_out <- pull_pred(rset, analysis = FALSE) est_fn_name <- paste0("cal_estimate_", cal_function) + est_cl <- rlang::call2( est_fn_name, @@ -560,18 +579,23 @@ cal_validate <- function(rset, } pull_pred <- function(x, analysis = TRUE) { - has_dot_row <- any(names(x$splits[[1]]$data) == ".row") if (analysis) { what <- "analysis" } else { what <- "assessment" } - preds <- purrr::map(x$splits, as.data.frame, data = what) - if (!has_dot_row) { - rows <- purrr::map(x$splits, ~ dplyr::tibble(.row = as.integer(.x, data = what))) - preds <- purrr::map2(preds, rows, ~ dplyr::bind_cols(.x, .y)) - } + if (inherits(x$splits[[1]], "val_split")) { + preds <- as.data.frame(x$splits[[1]], what) + } else { + has_dot_row <- any(names(x$splits[[1]]$data) == ".row") + + preds <- purrr::map(x$splits, as.data.frame, data = what) + if (!has_dot_row) { + rows <- purrr::map(x$splits, ~ dplyr::tibble(.row = as.integer(.x, data = what))) + preds <- purrr::map2(preds, rows, ~ dplyr::bind_cols(.x, .y)) + } + } preds } @@ -655,6 +679,9 @@ cal_validate_linear.resample_results <- metrics = NULL, save_pred = FALSE, ...) { + cl <- match.call() + validation_check(.data, cl) + if (!is.null(truth)) { cli::cli_warn("{.arg truth} is automatically set when this type of object is used.") } @@ -748,6 +775,9 @@ cal_validate_none.resample_results <- metrics = NULL, save_pred = FALSE, ...) { + cl <- match.call() + validation_check(.data, cl) + if (!is.null(truth)) { cli::cli_warn("{.arg truth} is automatically set when this type of object is used.") } @@ -803,8 +833,14 @@ convert_resamples <- function(x) { predictions <- tune::collect_predictions(x, summarize = TRUE) |> dplyr::arrange(.row) + + # Not all prediction sets, when collected, will match the size of the original + # data so buff out the data set + data_ind <- dplyr::tibble(.row = seq_len(nrow(x$splits[[1]]$data))) + all_data <- dplyr::full_join(data_ind, predictions, by = ".row") + for (i in seq_along(x$splits)) { - x$splits[[i]]$data <- predictions + x$splits[[i]]$data <- all_data } class(x) <- c("rset", "tbl_df", "tbl", "data.frame") x @@ -891,3 +927,18 @@ collect_predictions.cal_rset <- function(x, summarize = TRUE, ...) { } res } + +validation_check <- function(x, cl = NULL, call = rlang::caller_env()) { + fn <- as.character(cl[[1]]) + fn <- strsplit(fn, "\\.")[[1]][1] + + if (inherits(x$splits[[1]], "val_split")) { + cli::cli_abort( + "For validation sets, please make a resampling object from the predictions + prior to calling {.fn {fn}}", + call = call + ) + } + invisible(NULL) +} + diff --git a/man/int_conformal_full.Rd b/man/int_conformal_full.Rd index 31680ec1..e950fe7e 100644 --- a/man/int_conformal_full.Rd +++ b/man/int_conformal_full.Rd @@ -111,28 +111,7 @@ intervals for the five new samples in parallel: predict(lm_conform, new_dat) }\if{html}{\out{}} -\if{html}{\out{
}}\preformatted{## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may -## not be available when loading -## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may -## not be available when loading -## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may -## not be available when loading -## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may -## not be available when loading -## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may -## not be available when loading -## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may -## not be available when loading -## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may -## not be available when loading -## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may -## not be available when loading -## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may -## not be available when loading -## Warning in serialize(data, node$con, xdr = FALSE): 'package:workflowsets' may -## not be available when loading - -## # A tibble: 5 x 2 +\if{html}{\out{
}}\preformatted{## # A tibble: 5 x 2 ## .pred_lower .pred_upper ## ## 1 -17.9 59.6 diff --git a/tests/testthat/_snaps/cal-validate.md b/tests/testthat/_snaps/cal-validate.md index 33306071..6e223def 100644 --- a/tests/testthat/_snaps/cal-validate.md +++ b/tests/testthat/_snaps/cal-validate.md @@ -44,3 +44,59 @@ This function can only be used with an object or the results of `tune::fit_resamples()` with a .predictions column. i Not an object. +# validation sets fail with better message + + Code + cal_validate_beta(mt_res) + Condition + Error in `cal_validate_beta()`: + ! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_beta()` + +--- + + Code + cal_validate_isotonic(mt_res) + Condition + Error in `cal_validate_isotonic()`: + ! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_isotonic()` + +--- + + Code + cal_validate_isotonic_boot(mt_res) + Condition + Error in `cal_validate_isotonic_boot()`: + ! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_isotonic_boot()` + +--- + + Code + cal_validate_linear(mt_res) + Condition + Error in `cal_validate_linear()`: + ! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_linear()` + +--- + + Code + cal_validate_logistic(mt_res) + Condition + Error in `cal_validate_logistic()`: + ! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_logistic()` + +--- + + Code + cal_validate_multinomial(mt_res) + Condition + Error in `cal_validate_multinomial()`: + ! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_multinomial()` + +--- + + Code + cal_validate_none(mt_res) + Condition + Error in `cal_validate_none()`: + ! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_none()` + diff --git a/tests/testthat/test-cal-validate-multiclass.R b/tests/testthat/test-cal-validate-multiclass.R index 9daed243..bc8c81bf 100644 --- a/tests/testthat/test-cal-validate-multiclass.R +++ b/tests/testthat/test-cal-validate-multiclass.R @@ -50,8 +50,8 @@ test_that("Isotonic validation with `fit_resamples` - Multiclass", { ) skip_if_not_installed("tune", "1.2.0") expect_equal( - names(val_with_pred$.predictions_cal[[1]]), - c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class") + sort(names(val_with_pred$.predictions_cal[[1]])), + sort(c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class")) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), diff --git a/tests/testthat/test-cal-validate.R b/tests/testthat/test-cal-validate.R index 005994f7..31ea5a32 100644 --- a/tests/testthat/test-cal-validate.R +++ b/tests/testthat/test-cal-validate.R @@ -348,8 +348,8 @@ test_that("Logistic validation with `fit_resamples`", { skip_if_not_installed("tune", "1.2.0") expect_equal( - names(val_with_pred$.predictions_cal[[1]]), - c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class") + sort(names(val_with_pred$.predictions_cal[[1]])), + sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), @@ -380,8 +380,8 @@ test_that("Isotonic classification validation with `fit_resamples`", { skip_if_not_installed("tune", "1.2.0") expect_equal( - names(val_with_pred$.predictions_cal[[1]]), - c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class") + sort(names(val_with_pred$.predictions_cal[[1]])), + sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), @@ -413,8 +413,8 @@ test_that("Bootstrapped isotonic classification validation with `fit_resamples`" skip_if_not_installed("tune", "1.2.0") expect_equal( - names(val_with_pred$.predictions_cal[[1]]), - c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class") + sort(names(val_with_pred$.predictions_cal[[1]])), + sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), @@ -446,8 +446,8 @@ test_that("Beta calibration validation with `fit_resamples`", { skip_if_not_installed("tune", "1.2.0") expect_equal( - names(val_with_pred$.predictions_cal[[1]]), - c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class") + sort(names(val_with_pred$.predictions_cal[[1]])), + sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), @@ -481,8 +481,8 @@ test_that("Multinomial calibration validation with `fit_resamples`", { skip_if_not_installed("tune", "1.2.0") expect_equal( - names(val_with_pred$.predictions_cal[[1]]), - c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class") + sort(names(val_with_pred$.predictions_cal[[1]])), + sort(c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class")) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), @@ -513,8 +513,8 @@ test_that("Validation without calibration with `fit_resamples`", { skip_if_not_installed("tune", "1.2.0") expect_equal( - names(val_with_pred$.predictions_cal[[1]]), - c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class") + sort(names(val_with_pred$.predictions_cal[[1]])), + sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), @@ -548,8 +548,8 @@ test_that("Linear validation with `fit_resamples`", { skip_if_not_installed("tune", "1.2.0") expect_equal( - names(val_with_pred$.predictions_cal[[1]]), - c(".pred", ".row", "outcome", ".config") + sort(names(val_with_pred$.predictions_cal[[1]])), + sort(c(".pred", ".row", "outcome", ".config")) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), @@ -621,8 +621,8 @@ test_that("Isotonic regression validation with `fit_resamples`", { skip_if_not_installed("tune", "1.2.0") expect_equal( - names(val_with_pred$.predictions_cal[[1]]), - c(".pred", ".row", "outcome", ".config") + sort(names(val_with_pred$.predictions_cal[[1]])), + sort(c(".pred", ".row", "outcome", ".config")) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), @@ -657,8 +657,8 @@ test_that("Isotonic bootstrapped regression validation with `fit_resamples`", { skip_if_not_installed("tune", "1.2.0") expect_equal( - names(val_with_pred$.predictions_cal[[1]]), - c(".pred", ".row", "outcome", ".config") + sort(names(val_with_pred$.predictions_cal[[1]])), + sort(c(".pred", ".row", "outcome", ".config")) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), @@ -670,7 +670,6 @@ test_that("Isotonic bootstrapped regression validation with `fit_resamples`", { # ------------------------------------------------------------------------------ - test_that("validation functions error with tune_results input", { skip_if_not_installed("modeldata") skip_if_not_installed("nnet") @@ -698,3 +697,27 @@ test_that("validation functions error with tune_results input", { cal_validate_none(testthat_cal_binary()) ) }) + +# ------------------------------------------------------------------------------ + +test_that("validation sets fail with better message", { + library(tune) + set.seed(1) + mt_split <- rsample::initial_validation_split(mtcars) + mt_rset <- rsample::validation_set(mt_split) + mt_res <- + parsnip::linear_reg() |> + fit_resamples( + mpg ~ ., + resamples = mt_rset, + control = control_resamples(save_pred = TRUE) + ) + + expect_snapshot(cal_validate_beta(mt_res), error = TRUE) + expect_snapshot(cal_validate_isotonic(mt_res), error = TRUE) + expect_snapshot(cal_validate_isotonic_boot(mt_res), error = TRUE) + expect_snapshot(cal_validate_linear(mt_res), error = TRUE) + expect_snapshot(cal_validate_logistic(mt_res), error = TRUE) + expect_snapshot(cal_validate_multinomial(mt_res), error = TRUE) + expect_snapshot(cal_validate_none(mt_res), error = TRUE) +})