From b537070ac404d2e7cfb14a0af74f4a9b59e47175 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 8 Oct 2025 11:00:54 -0700 Subject: [PATCH 1/6] pull out cal apply tests to their own files --- tests/testthat/test-cal-apply-binary.R | 41 +++++++++++ tests/testthat/test-cal-apply-multi.R | 43 +++++++++++ tests/testthat/test-cal-apply-regression.R | 37 ++++++++++ tests/testthat/test-cal-apply.R | 84 ---------------------- 4 files changed, 121 insertions(+), 84 deletions(-) create mode 100644 tests/testthat/test-cal-apply-binary.R create mode 100644 tests/testthat/test-cal-apply-multi.R create mode 100644 tests/testthat/test-cal-apply-regression.R diff --git a/tests/testthat/test-cal-apply-binary.R b/tests/testthat/test-cal-apply-binary.R new file mode 100644 index 0000000..5dcfdd6 --- /dev/null +++ b/tests/testthat/test-cal-apply-binary.R @@ -0,0 +1,41 @@ +test_that("Logistic apply works - data.frame", { + sl_logistic <- cal_estimate_logistic(segment_logistic, Class, smooth = FALSE) + ap_logistic <- cal_apply(segment_logistic, sl_logistic) + + pred_good <- ap_logistic$.pred_good + expect_equal(mean(pred_good), 0.3425743, tolerance = 0.000001) + expect_equal(sd(pred_good), 0.2993934, tolerance = 0.000001) +}) + +test_that("Logistic apply works - tune_results", { + skip_if_not_installed("modeldata") + + tct <- testthat_cal_binary() + tl_logistic <- cal_estimate_logistic(tct, smooth = FALSE) + tap_logistic <- cal_apply(tct, tl_logistic) + expect_equal( + testthat_cal_binary_count(), + nrow(tap_logistic) + ) +}) + +test_that("Logistic spline apply works", { + sl_gam <- cal_estimate_logistic(segment_logistic, Class) + ap_gam <- cal_apply(segment_logistic, sl_gam) + + pred_good <- ap_gam$.pred_good + expect_equal(mean(pred_good), 0.3425743, tolerance = 0.000001) + expect_equal(sd(pred_good), 0.2987027, tolerance = 0.000001) +}) + +test_that("Logistic spline apply works - tune_results", { + skip_if_not_installed("modeldata") + + tct <- testthat_cal_binary() + tl_gam <- cal_estimate_logistic(tct) + tap_gam <- cal_apply(tct, tl_gam) + expect_equal( + testthat_cal_binary_count(), + nrow(tap_gam) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-cal-apply-multi.R b/tests/testthat/test-cal-apply-multi.R new file mode 100644 index 0000000..4a4d208 --- /dev/null +++ b/tests/testthat/test-cal-apply-multi.R @@ -0,0 +1,43 @@ +test_that("Multinomial apply works - data.frame", { + sl_multinomial <- cal_estimate_multinomial(species_probs, Species, smooth = FALSE) + ap_multinomial <- cal_apply(species_probs, sl_multinomial) + + pred_bobcat <- ap_multinomial$.pred_bobcat + expect_equal(mean(pred_bobcat), 0.5181842, tolerance = 0.000001) + expect_equal(sd(pred_bobcat), 0.3264982, tolerance = 0.000001) +}) + +test_that("Logistic apply works - tune_results", { + skip_if_not_installed("modeldata") + + tct <- testthat_cal_multiclass() + tl_multinomial <- cal_estimate_multinomial(tct, smooth = FALSE) + tap_multinomial <- cal_apply(tct, tl_multinomial) + + expect_equal( + testthat_cal_multiclass_count(), + nrow(tap_multinomial) + ) +}) + +test_that("Multinomial spline apply works", { + sl_gam <- cal_estimate_multinomial(species_probs, Species) + ap_gam <- cal_apply(species_probs, sl_gam) + + pred_bobcat <- ap_gam$.pred_bobcat + expect_equal(mean(pred_bobcat), 0.5181818, tolerance = 0.000001) + expect_equal(sd(pred_bobcat), 0.3274646, tolerance = 0.000001) +}) + +test_that("Multinomial spline apply works - tune_results", { + skip_if_not_installed("modeldata") + + tct <- testthat_cal_multiclass() + tl_gam <- cal_estimate_multinomial(tct) + tap_gam <- cal_apply(tct, tl_gam) + + expect_equal( + testthat_cal_multiclass_count(), + nrow(tap_gam) + ) +}) diff --git a/tests/testthat/test-cal-apply-regression.R b/tests/testthat/test-cal-apply-regression.R new file mode 100644 index 0000000..132099f --- /dev/null +++ b/tests/testthat/test-cal-apply-regression.R @@ -0,0 +1,37 @@ +test_that("Linear apply works - data.frame", { + sl_linear <- cal_estimate_linear(boosting_predictions_oob, outcome, smooth = FALSE) + ap_linear <- cal_apply(boosting_predictions_oob, sl_linear) + + pred <- ap_linear$.pred + expect_equal(mean(pred), 14.87123, tolerance = 0.000001) + expect_equal(sd(pred), 14.94483, tolerance = 0.000001) +}) + +test_that("Linear apply works - tune_results", { + tct <- testthat_cal_reg() + tl_linear <- cal_estimate_linear(tct, smooth = FALSE) + tap_linear <- cal_apply(tct, tl_linear) + expect_equal( + testthat_cal_reg_count(), + nrow(tap_linear) + ) +}) + +test_that("Linear spline apply works", { + sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome) + ap_gam <- cal_apply(boosting_predictions_oob, sl_gam) + + pred <- ap_gam$.pred + expect_equal(mean(pred), 14.87123, tolerance = 0.000001) + expect_equal(sd(pred), 15.00711, tolerance = 0.000001) +}) + +test_that("Linear spline apply works - tune_results", { + tct <- testthat_cal_reg() + tl_gam <- cal_estimate_linear(tct) + tap_gam <- cal_apply(tct, tl_gam) + expect_equal( + testthat_cal_reg_count(), + nrow(tap_gam) + ) +}) \ No newline at end of file diff --git a/tests/testthat/test-cal-apply.R b/tests/testthat/test-cal-apply.R index 790c529..501bdfb 100644 --- a/tests/testthat/test-cal-apply.R +++ b/tests/testthat/test-cal-apply.R @@ -1,87 +1,3 @@ -test_that("Logistic apply works - data.frame", { - sl_logistic <- cal_estimate_logistic(segment_logistic, Class, smooth = FALSE) - ap_logistic <- cal_apply(segment_logistic, sl_logistic) - - pred_good <- ap_logistic$.pred_good - expect_equal(mean(pred_good), 0.3425743, tolerance = 0.000001) - expect_equal(sd(pred_good), 0.2993934, tolerance = 0.000001) -}) - -test_that("Logistic apply works - tune_results", { - skip_if_not_installed("modeldata") - - tct <- testthat_cal_binary() - tl_logistic <- cal_estimate_logistic(tct, smooth = FALSE) - tap_logistic <- cal_apply(tct, tl_logistic) - expect_equal( - testthat_cal_binary_count(), - nrow(tap_logistic) - ) -}) - -test_that("Logistic spline apply works", { - sl_gam <- cal_estimate_logistic(segment_logistic, Class) - ap_gam <- cal_apply(segment_logistic, sl_gam) - - pred_good <- ap_gam$.pred_good - expect_equal(mean(pred_good), 0.3425743, tolerance = 0.000001) - expect_equal(sd(pred_good), 0.2987027, tolerance = 0.000001) -}) - -test_that("Logistic spline apply works - tune_results", { - skip_if_not_installed("modeldata") - - tct <- testthat_cal_binary() - tl_gam <- cal_estimate_logistic(tct) - tap_gam <- cal_apply(tct, tl_gam) - expect_equal( - testthat_cal_binary_count(), - nrow(tap_gam) - ) -}) - -# ------------------------------------------------------------------------------ - -test_that("Linear apply works - data.frame", { - sl_linear <- cal_estimate_linear(boosting_predictions_oob, outcome, smooth = FALSE) - ap_linear <- cal_apply(boosting_predictions_oob, sl_linear) - - pred <- ap_linear$.pred - expect_equal(mean(pred), 14.87123, tolerance = 0.000001) - expect_equal(sd(pred), 14.94483, tolerance = 0.000001) -}) - -test_that("Linear apply works - tune_results", { - tct <- testthat_cal_reg() - tl_linear <- cal_estimate_linear(tct, smooth = FALSE) - tap_linear <- cal_apply(tct, tl_linear) - expect_equal( - testthat_cal_reg_count(), - nrow(tap_linear) - ) -}) - -test_that("Linear spline apply works", { - sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome) - ap_gam <- cal_apply(boosting_predictions_oob, sl_gam) - - pred <- ap_gam$.pred - expect_equal(mean(pred), 14.87123, tolerance = 0.000001) - expect_equal(sd(pred), 15.00711, tolerance = 0.000001) -}) - -test_that("Linear spline apply works - tune_results", { - tct <- testthat_cal_reg() - tl_gam <- cal_estimate_linear(tct) - tap_gam <- cal_apply(tct, tl_gam) - expect_equal( - testthat_cal_reg_count(), - nrow(tap_gam) - ) -}) - -# ------------------------------------------------------------------------------ - test_that("Isotonic apply works - data.frame", { set.seed(100) From 5b54948e5328cc3083b02e9e2f52ab193f14d87d Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 8 Oct 2025 11:11:50 -0700 Subject: [PATCH 2/6] renaming files to match test files --- DESCRIPTION | 4 ++-- R/{cal-estimate-multinom.R => cal-estimate-multinomial.R} | 0 man/cal_estimate_multinomial.Rd | 4 ++-- man/int_conformal_cv.Rd | 2 +- man/int_conformal_quantile.Rd | 2 +- man/int_conformal_split.Rd | 2 +- man/probably-package.Rd | 2 +- man/required_pkgs.cal_object.Rd | 4 ++-- .../_snaps/{bound-prediction.md => bound_prediction.md} | 0 .../{test-bound-prediction.R => test-bound_prediction.R} | 0 10 files changed, 10 insertions(+), 10 deletions(-) rename R/{cal-estimate-multinom.R => cal-estimate-multinomial.R} (100%) rename tests/testthat/_snaps/{bound-prediction.md => bound_prediction.md} (100%) rename tests/testthat/{test-bound-prediction.R => test-bound_prediction.R} (100%) diff --git a/DESCRIPTION b/DESCRIPTION index 08959d6..0709eaa 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -60,7 +60,7 @@ Config/testthat/edition: 3 Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.2 +RoxygenNote: 7.3.3 Collate: 'bound_prediction.R' 'cal-apply-binary.R' @@ -72,7 +72,7 @@ Collate: 'cal-estimate-isotonic.R' 'cal-estimate-linear.R' 'cal-estimate-logistic.R' - 'cal-estimate-multinom.R' + 'cal-estimate-multinomial.R' 'cal-estimate-utils.R' 'cal-estimate-none.R' 'cal-pkg-check.R' diff --git a/R/cal-estimate-multinom.R b/R/cal-estimate-multinomial.R similarity index 100% rename from R/cal-estimate-multinom.R rename to R/cal-estimate-multinomial.R diff --git a/man/cal_estimate_multinomial.Rd b/man/cal_estimate_multinomial.Rd index 4277046..c3f9d35 100644 --- a/man/cal_estimate_multinomial.Rd +++ b/man/cal_estimate_multinomial.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/cal-estimate-multinom.R +% Please edit documentation in R/cal-estimate-multinomial.R \name{cal_estimate_multinomial} \alias{cal_estimate_multinomial} \alias{cal_estimate_multinomial.data.frame} @@ -79,7 +79,7 @@ When \code{smooth = FALSE}, \code{\link[nnet:multinom]{nnet::multinom()}} functi model, otherwise \code{\link[mgcv:gam]{mgcv::gam()}} is used. } \examples{ -\dontshow{if (!probably:::is_cran_check() & rlang::is_installed(c("modeldata", "parsnip", "randomForest"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!probably:::is_cran_check() & rlang::is_installed(c("modeldata", "parsnip", "randomForest"))) withAutoprint(\{ # examplesIf} library(modeldata) library(parsnip) library(dplyr) diff --git a/man/int_conformal_cv.Rd b/man/int_conformal_cv.Rd index dc5ff5c..1448e68 100644 --- a/man/int_conformal_cv.Rd +++ b/man/int_conformal_cv.Rd @@ -51,7 +51,7 @@ stop the computations for other types of resamples, but we have no way of knowing whether the results are appropriate. } \examples{ -\dontshow{if (!probably:::is_cran_check() & rlang::is_installed(c("modeldata", "parsnip"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!probably:::is_cran_check() & rlang::is_installed(c("modeldata", "parsnip"))) withAutoprint(\{ # examplesIf} library(workflows) library(dplyr) library(parsnip) diff --git a/man/int_conformal_quantile.Rd b/man/int_conformal_quantile.Rd index 1d311c9..1d47177 100644 --- a/man/int_conformal_quantile.Rd +++ b/man/int_conformal_quantile.Rd @@ -46,7 +46,7 @@ Note that the because of the method used to construct the interval, it is possible that the prediction intervals will not include the predicted value. } \examples{ -\dontshow{if (!probably:::is_cran_check() & rlang::is_installed(c("modeldata", "parsnip", "quantregForest"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!probably:::is_cran_check() & rlang::is_installed(c("modeldata", "parsnip", "quantregForest"))) withAutoprint(\{ # examplesIf} library(workflows) library(dplyr) library(parsnip) diff --git a/man/int_conformal_split.Rd b/man/int_conformal_split.Rd index c65ee67..6a94e8f 100644 --- a/man/int_conformal_split.Rd +++ b/man/int_conformal_split.Rd @@ -45,7 +45,7 @@ quantile (e.g., the 95th for 95\% interval) and should not include rows that were in the original training set. } \examples{ -\dontshow{if (!probably:::is_cran_check() & rlang::is_installed(c("modeldata", "parsnip", "nnet"))) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!probably:::is_cran_check() & rlang::is_installed(c("modeldata", "parsnip", "nnet"))) withAutoprint(\{ # examplesIf} library(workflows) library(dplyr) library(parsnip) diff --git a/man/probably-package.Rd b/man/probably-package.Rd index 94442be..59b829b 100644 --- a/man/probably-package.Rd +++ b/man/probably-package.Rd @@ -30,7 +30,7 @@ Authors: Other contributors: \itemize{ - \item Posit Software, PBC (03wc8by49) [copyright holder, funder] + \item Posit Software, PBC (\href{https://ror.org/03wc8by49}{ROR}) [copyright holder, funder] } } diff --git a/man/required_pkgs.cal_object.Rd b/man/required_pkgs.cal_object.Rd index af3e2bc..d0b919a 100644 --- a/man/required_pkgs.cal_object.Rd +++ b/man/required_pkgs.cal_object.Rd @@ -1,7 +1,7 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/cal-estimate-beta.R, R/cal-estimate-linear.R, -% R/cal-estimate-logistic.R, R/cal-estimate-multinom.R, R/cal-estimate-none.R, -% R/cal-pkg-check.R +% R/cal-estimate-logistic.R, R/cal-estimate-multinomial.R, +% R/cal-estimate-none.R, R/cal-pkg-check.R \name{required_pkgs.cal_estimate_beta} \alias{required_pkgs.cal_estimate_beta} \alias{required_pkgs.cal_estimate_linear_spline} diff --git a/tests/testthat/_snaps/bound-prediction.md b/tests/testthat/_snaps/bound_prediction.md similarity index 100% rename from tests/testthat/_snaps/bound-prediction.md rename to tests/testthat/_snaps/bound_prediction.md diff --git a/tests/testthat/test-bound-prediction.R b/tests/testthat/test-bound_prediction.R similarity index 100% rename from tests/testthat/test-bound-prediction.R rename to tests/testthat/test-bound_prediction.R From 39c6bf355a4436180c764fb052e6b73283c95561 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 8 Oct 2025 11:22:49 -0700 Subject: [PATCH 3/6] delete duplicate tests --- .../testthat/_snaps/cal-estimate-isotonic.md | 17 + tests/testthat/_snaps/cal-estimate-linear.md | 35 - tests/testthat/_snaps/cal-estimate.md | 709 ------------------ tests/testthat/test-cal-estimate-beta.R | 4 + tests/testthat/test-cal-estimate-isotonic.R | 34 + tests/testthat/test-cal-estimate-linear.R | 37 +- tests/testthat/test-cal-estimate-logistic.R | 14 +- .../testthat/test-cal-estimate-multinomial.R | 8 + tests/testthat/test-cal-estimate.R | 681 ----------------- 9 files changed, 84 insertions(+), 1455 deletions(-) delete mode 100644 tests/testthat/_snaps/cal-estimate.md delete mode 100644 tests/testthat/test-cal-estimate.R diff --git a/tests/testthat/_snaps/cal-estimate-isotonic.md b/tests/testthat/_snaps/cal-estimate-isotonic.md index 66ff4cc..9caf49f 100644 --- a/tests/testthat/_snaps/cal-estimate-isotonic.md +++ b/tests/testthat/_snaps/cal-estimate-isotonic.md @@ -191,3 +191,20 @@ x This function does not work with grouped data frames. i Apply `dplyr::ungroup()` and use the `.by` argument. +# Non-default names used for estimate columns + + Code + cal_estimate_isotonic(new_segment, Class, c(good, poor)) + Message + + -- Probability Calibration + Method: Isotonic regression calibration + Type: Binary + Source class: Data Frame + Data points: 1,010 + Unique Predicted Values: 78 + Truth variable: `Class` + Estimate variables: + `good` ==> good + `poor` ==> poor + diff --git a/tests/testthat/_snaps/cal-estimate-linear.md b/tests/testthat/_snaps/cal-estimate-linear.md index 74104ad..48d99f9 100644 --- a/tests/testthat/_snaps/cal-estimate-linear.md +++ b/tests/testthat/_snaps/cal-estimate-linear.md @@ -128,38 +128,3 @@ Warning: Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. ---- - - Code - sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome, smooth = TRUE) - Condition - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - ---- - - Code - sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome, .by = id, - smooth = TRUE) - Condition - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - diff --git a/tests/testthat/_snaps/cal-estimate.md b/tests/testthat/_snaps/cal-estimate.md deleted file mode 100644 index 396d78b..0000000 --- a/tests/testthat/_snaps/cal-estimate.md +++ /dev/null @@ -1,709 +0,0 @@ -# Logistic estimates work - data.frame - - Code - print(sl_logistic) - Message - - -- Probability Calibration - Method: Logistic regression calibration - Type: Binary - Source class: Data Frame - Data points: 1,010 - Truth variable: `Class` - Estimate variables: - `.pred_good` ==> good - `.pred_poor` ==> poor - ---- - - The selectors in `estimate` resolves to 1 values (".pred_poor") but there are 2 class levels ("good" and "poor"). - ---- - - The `truth` column has 4 levels ("VF", "F", "M", and "L"), but only two-class factors are allowed for this calibration method. - ---- - - Code - print(sl_logistic_group) - Message - - -- Probability Calibration - Method: Logistic regression calibration - Type: Binary - Source class: Data Frame - Data points: 1,010, split in 2 groups - Truth variable: `Class` - Estimate variables: - `.pred_good` ==> good - `.pred_poor` ==> poor - ---- - - x `.by` cannot select more than one column. - i The following columns were selected: - i group1 and group2 - -# Logistic estimates work - tune_results - - Code - print(tl_logistic) - Message - - -- Probability Calibration - Method: Logistic regression calibration - Type: Binary - Source class: Tune Results - Data points: 4,000, split in 8 groups - Truth variable: `class` - Estimate variables: - `.pred_class_1` ==> class_1 - `.pred_class_2` ==> class_2 - ---- - - The `truth` column has 3 levels ("one", "two", and "three"), but only two-class factors are allowed for this calibration method. - -# Logistic estimates errors - grouped_df - - x This function does not work with grouped data frames. - i Apply `dplyr::ungroup()` and use the `.by` argument. - -# Logistic spline estimates work - data.frame - - Code - print(sl_gam) - Message - - -- Probability Calibration - Method: Generalized additive model calibration - Type: Binary - Source class: Data Frame - Data points: 1,010 - Truth variable: `Class` - Estimate variables: - `.pred_good` ==> good - `.pred_poor` ==> poor - ---- - - Code - print(sl_gam_group) - Message - - -- Probability Calibration - Method: Generalized additive model calibration - Type: Binary - Source class: Data Frame - Data points: 1,010, split in 2 groups - Truth variable: `Class` - Estimate variables: - `.pred_good` ==> good - `.pred_poor` ==> poor - ---- - - x `.by` cannot select more than one column. - i The following columns were selected: - i group1 and group2 - -# Logistic spline estimates work - tune_results - - Code - print(tl_gam) - Message - - -- Probability Calibration - Method: Generalized additive model calibration - Type: Binary - Source class: Tune Results - Data points: 4,000, split in 8 groups - Truth variable: `class` - Estimate variables: - `.pred_class_1` ==> class_1 - `.pred_class_2` ==> class_2 - -# Logistic spline switches to linear if too few unique - - Code - sl_gam <- cal_estimate_logistic(segment_logistic, Class, smooth = TRUE) - Condition - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - ---- - - Code - sl_gam <- cal_estimate_logistic(segment_logistic, Class, .by = id, smooth = TRUE) - Condition - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - -# Isotonic estimates work - data.frame - - Code - print(sl_isotonic) - Message - - -- Probability Calibration - Method: Isotonic regression calibration - Type: Binary - Source class: Data Frame - Data points: 1,010 - Unique Predicted Values: 78 - Truth variable: `Class` - Estimate variables: - `.pred_good` ==> good - `.pred_poor` ==> poor - ---- - - Code - print(sl_isotonic_group) - Message - - -- Probability Calibration - Method: Isotonic regression calibration - Type: Binary - Source class: Data Frame - Data points: 1,010, split in 2 groups - Unique Predicted Values: 77 - Truth variable: `Class` - Estimate variables: - `.pred_good` ==> good - `.pred_poor` ==> poor - ---- - - x `.by` cannot select more than one column. - i The following columns were selected: - i group1 and group2 - -# Isotonic estimates work - tune_results - - Code - print(tl_isotonic) - Message - - -- Probability Calibration - Method: Isotonic regression calibration - Type: Binary - Source class: Tune Results - Data points: 4,000, split in 8 groups - Unique Predicted Values: 92 - Truth variable: `class` - Estimate variables: - `.pred_class_1` ==> class_1 - `.pred_class_2` ==> class_2 - ---- - - Code - print(mtnl_isotonic) - Message - - -- Probability Calibration - Method: Isotonic regression calibration - Type: Multiclass (1 v All) - Source class: Tune Results - Data points: 5,000, split in 10 groups - Truth variable: `class` - Estimate variables: - `.pred_one` ==> one - `.pred_two` ==> two - `.pred_three` ==> three - -# Isotonic estimates errors - grouped_df - - x This function does not work with grouped data frames. - i Apply `dplyr::ungroup()` and use the `.by` argument. - -# Isotonic linear estimates work - data.frame - - Code - print(sl_logistic) - Message - - -- Probability Calibration - Method: Isotonic regression calibration - Type: Regression - Source class: Data Frame - Data points: 2,000 - Unique Predicted Values: 47 - Truth variable: `outcome` - Estimate variables: - `.pred` ==> predictions - ---- - - Code - print(sl_logistic_group) - Message - - -- Probability Calibration - Method: Isotonic regression calibration - Type: Regression - Source class: Data Frame - Data points: 2,000, split in 10 groups - Unique Predicted Values: 18 - Truth variable: `outcome` - Estimate variables: - `.pred` ==> predictions - ---- - - x `.by` cannot select more than one column. - i The following columns were selected: - i group1 and group2 - -# Isotonic Bootstrapped estimates work - data.frame - - Code - print(sl_boot) - Message - - -- Probability Calibration - Method: Bootstrapped isotonic regression calibration - Type: Binary - Source class: Data Frame - Data points: 1,010 - Truth variable: `Class` - Estimate variables: - `.pred_good` ==> good - `.pred_poor` ==> poor - ---- - - Code - print(sl_boot_group) - Message - - -- Probability Calibration - Method: Bootstrapped isotonic regression calibration - Type: Binary - Source class: Data Frame - Data points: 1,010, split in 2 groups - Truth variable: `Class` - Estimate variables: - `.pred_good` ==> good - `.pred_poor` ==> poor - ---- - - x `.by` cannot select more than one column. - i The following columns were selected: - i group1 and group2 - -# Isotonic Bootstrapped estimates work - tune_results - - Code - print(tl_isotonic) - Message - - -- Probability Calibration - Method: Bootstrapped isotonic regression calibration - Type: Binary - Source class: Tune Results - Data points: 4,000, split in 8 groups - Truth variable: `class` - Estimate variables: - `.pred_class_1` ==> class_1 - `.pred_class_2` ==> class_2 - ---- - - Code - print(mtnl_isotonic) - Message - - -- Probability Calibration - Method: Bootstrapped isotonic regression calibration - Type: Multiclass (1 v All) - Source class: Tune Results - Data points: 5,000, split in 10 groups - Truth variable: `class` - Estimate variables: - `.pred_one` ==> one - `.pred_two` ==> two - `.pred_three` ==> three - -# Isotonic Bootstrapped estimates errors - grouped_df - - x This function does not work with grouped data frames. - i Apply `dplyr::ungroup()` and use the `.by` argument. - -# Beta estimates work - data.frame - - Code - print(sl_beta) - Message - - -- Probability Calibration - Method: Beta calibration - Type: Binary - Source class: Data Frame - Data points: 1,010 - Truth variable: `Class` - Estimate variables: - `.pred_good` ==> good - `.pred_poor` ==> poor - ---- - - Code - print(sl_beta_group) - Message - - -- Probability Calibration - Method: Beta calibration - Type: Binary - Source class: Data Frame - Data points: 1,010, split in 2 groups - Truth variable: `Class` - Estimate variables: - `.pred_good` ==> good - `.pred_poor` ==> poor - ---- - - x `.by` cannot select more than one column. - i The following columns were selected: - i group1 and group2 - -# Beta estimates work - tune_results - - Code - print(tl_beta) - Message - - -- Probability Calibration - Method: Beta calibration - Type: Binary - Source class: Tune Results - Data points: 4,000, split in 8 groups - Truth variable: `class` - Estimate variables: - `.pred_class_1` ==> class_1 - `.pred_class_2` ==> class_2 - ---- - - Code - print(mtnl_beta) - Message - - -- Probability Calibration - Method: Beta calibration - Type: Multiclass (1 v All) - Source class: Tune Results - Data points: 5,000, split in 10 groups - Truth variable: `class` - Estimate variables: - `.pred_one` ==> one - `.pred_two` ==> two - `.pred_three` ==> three - -# Beta estimates errors - grouped_df - - x This function does not work with grouped data frames. - i Apply `dplyr::ungroup()` and use the `.by` argument. - -# Multinomial estimates work - data.frame - - Code - print(sp_multi) - Message - - -- Probability Calibration - Method: Multinomial regression calibration - Type: Multiclass - Source class: Data Frame - Data points: 110 - Truth variable: `Species` - Estimate variables: - `.pred_bobcat` ==> bobcat - `.pred_coyote` ==> coyote - `.pred_gray_fox` ==> gray_fox - ---- - - Code - print(sp_smth_multi) - Message - - -- Probability Calibration - Method: Generalized additive model calibration - Type: Multiclass - Source class: Data Frame - Data points: 110 - Truth variable: `Species` - Estimate variables: - `.pred_bobcat` ==> bobcat - `.pred_coyote` ==> coyote - `.pred_gray_fox` ==> gray_fox - ---- - - Code - print(sl_multi_group) - Message - - -- Probability Calibration - Method: Multinomial regression calibration - Type: Multiclass - Source class: Data Frame - Data points: 110, split in 2 groups - Truth variable: `Species` - Estimate variables: - `.pred_bobcat` ==> bobcat - `.pred_coyote` ==> coyote - `.pred_gray_fox` ==> gray_fox - ---- - - x `.by` cannot select more than one column. - i The following columns were selected: - i group1 and group2 - -# Multinomial estimates work - tune_results - - Code - print(tl_multi) - Message - - -- Probability Calibration - Method: Multinomial regression calibration - Type: Multiclass - Source class: Tune Results - Data points: 5,000, split in 10 groups - Truth variable: `class` - Estimate variables: - `.pred_one` ==> one - `.pred_two` ==> two - `.pred_three` ==> three - ---- - - Code - print(tl_smth_multi) - Message - - -- Probability Calibration - Method: Generalized additive model calibration - Type: Multiclass - Source class: Tune Results - Data points: 5,000, split in 10 groups - Truth variable: `class` - Estimate variables: - `.pred_one` ==> one - `.pred_two` ==> two - `.pred_three` ==> three - -# Multinomial estimates errors - grouped_df - - x This function does not work with grouped data frames. - i Apply `dplyr::ungroup()` and use the `.by` argument. - -# Linear spline switches to linear if too few unique - - Code - sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome, smooth = TRUE) - Condition - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - ---- - - Code - sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome, .by = id, - smooth = TRUE) - Condition - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - ---- - - Code - sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome, smooth = TRUE) - Condition - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - ---- - - Code - sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome, .by = id, - smooth = TRUE) - Condition - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - -# Multinomial spline switches to linear if too few unique - - Code - sl_gam <- cal_estimate_multinomial(smol_species_probs, Species, smooth = TRUE) - Condition - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - ---- - - Code - sl_gam <- cal_estimate_multinomial(smol_by_species_probs, Species, .by = id, - smooth = TRUE) - Condition - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - Warning: - Too few unique observations for spline-based calibrator. Setting `smooth = FALSE`. - -# Linear estimates work - data.frame - - Code - print(sl_linear) - Message - - -- Regression Calibration - Method: Linear calibration - Source class: Data Frame - Data points: 2,000 - Truth variable: `outcome` - Estimate variable: `.pred` - ---- - - Code - print(sl_linear_group) - Message - - -- Regression Calibration - Method: Linear calibration - Source class: Data Frame - Data points: 2,000, split in 2 groups - Truth variable: `outcome` - Estimate variable: `.pred` - ---- - - x `.by` cannot select more than one column. - i The following columns were selected: - i group1 and group2 - -# Linear estimates work - tune_results - - Code - print(tl_linear) - Message - - -- Regression Calibration - Method: Linear calibration - Source class: Tune Results - Data points: 750, split in 10 groups - Truth variable: `outcome` - Estimate variable: `.pred` - -# Linear estimates errors - grouped_df - - x This function does not work with grouped data frames. - i Apply `dplyr::ungroup()` and use the `.by` argument. - -# Linear spline estimates work - data.frame - - Code - print(sl_gam) - Message - - -- Regression Calibration - Method: Generalized additive model calibration - Source class: Data Frame - Data points: 2,000 - Truth variable: `outcome` - Estimate variable: `.pred` - ---- - - Code - print(sl_gam_group) - Message - - -- Regression Calibration - Method: Generalized additive model calibration - Source class: Data Frame - Data points: 2,000, split in 2 groups - Truth variable: `outcome` - Estimate variable: `.pred` - ---- - - x `.by` cannot select more than one column. - i The following columns were selected: - i group1 and group2 - -# Linear spline estimates work - tune_results - - Code - print(tl_gam) - Message - - -- Regression Calibration - Method: Generalized additive model calibration - Source class: Tune Results - Data points: 750, split in 10 groups - Truth variable: `outcome` - Estimate variable: `.pred` - -# Non-default names used for estimate columns - - Code - cal_estimate_isotonic(new_segment, Class, c(good, poor)) - Message - - -- Probability Calibration - Method: Isotonic regression calibration - Type: Binary - Source class: Data Frame - Data points: 1,010 - Unique Predicted Values: 78 - Truth variable: `Class` - Estimate variables: - `good` ==> good - `poor` ==> poor - diff --git a/tests/testthat/test-cal-estimate-beta.R b/tests/testthat/test-cal-estimate-beta.R index 19a384b..411819f 100644 --- a/tests/testthat/test-cal-estimate-beta.R +++ b/tests/testthat/test-cal-estimate-beta.R @@ -5,6 +5,10 @@ test_that("Beta estimates work - data.frame", { expect_cal_method(sl_beta, "Beta calibration") expect_cal_rows(sl_beta) expect_snapshot(print(sl_beta)) + expect_equal( + required_pkgs(sl_beta), + c("betacal", "probably") + ) sl_beta_group <- segment_logistic |> dplyr::mutate(group = .pred_poor > 0.5) |> diff --git a/tests/testthat/test-cal-estimate-isotonic.R b/tests/testthat/test-cal-estimate-isotonic.R index 443c42b..f70b489 100644 --- a/tests/testthat/test-cal-estimate-isotonic.R +++ b/tests/testthat/test-cal-estimate-isotonic.R @@ -143,8 +143,42 @@ test_that("Isotonic Bootstrapped estimates work - tune_results", { ) }) +# ----------------------------------- Other ------------------------------------ test_that("Isotonic Bootstrapped estimates errors - grouped_df", { expect_snapshot_error( cal_estimate_isotonic_boot(dplyr::group_by(mtcars, vs)) ) }) + +test_that("Non-default names used for estimate columns", { + skip_if_not_installed("modeldata") + + new_segment <- segment_logistic + colnames(new_segment) <- c("poor", "good", "Class") + + set.seed(100) + expect_snapshot( + cal_estimate_isotonic(new_segment, Class, c(good, poor)) + ) +}) + +test_that("Test exceptions", { + expect_error( + cal_estimate_isotonic(segment_logistic, Class, dplyr::starts_with("bad")) + ) +}) + +test_that("non-standard column names", { + library(dplyr) + # issue 145 + seg <- segment_logistic |> + rename_with(~ paste0(.x, "-1"), matches(".pred")) |> + mutate( + Class = paste0(Class,"-1"), + Class = factor(Class), + .pred_class = ifelse(`.pred_poor-1` >= 0.5, "poor-1", "good-1") + ) + calib <- cal_estimate_isotonic(seg, Class) + new_pred <- cal_apply(seg, calib, pred_class = .pred_class) + expect_named(new_pred, c(".pred_poor-1", ".pred_good-1", "Class", ".pred_class")) +}) diff --git a/tests/testthat/test-cal-estimate-linear.R b/tests/testthat/test-cal-estimate-linear.R index 3425477..dd2a6e2 100644 --- a/tests/testthat/test-cal-estimate-linear.R +++ b/tests/testthat/test-cal-estimate-linear.R @@ -7,6 +7,10 @@ test_that("Linear estimates work - data.frame", { expect_cal_estimate(sl_linear, "butchered_glm") expect_cal_rows(sl_linear, 2000) expect_snapshot(print(sl_linear)) + expect_equal( + required_pkgs(sl_linear), + c("probably") + ) sl_linear_group <- boosting_predictions_oob |> dplyr::mutate(group = .pred > 0.5) |> @@ -51,6 +55,10 @@ test_that("Linear spline estimates work - data.frame", { expect_cal_estimate(sl_gam, "butchered_gam") expect_cal_rows(sl_gam, 2000) expect_snapshot(print(sl_gam)) + expect_equal( + required_pkgs(sl_gam), + c("mgcv", "probably") + ) sl_gam_group <- boosting_predictions_oob |> dplyr::mutate(group = .pred > 0.5) |> @@ -111,32 +119,3 @@ test_that("Linear spline switches to linear if too few unique", { ) }) -test_that("Linear spline switches to linear if too few unique", { - skip_if_not_installed("modeldata") - - boosting_predictions_oob$.pred <- rep( - x = 1:5, - length.out = nrow(boosting_predictions_oob) - ) - - expect_snapshot( - sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome, smooth = TRUE) - ) - sl_lm <- cal_estimate_linear(boosting_predictions_oob, outcome, smooth = FALSE) - - expect_identical( - sl_gam$estimate, - sl_lm$estimate - ) - - expect_snapshot( - sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome, .by = id, smooth = TRUE) - ) - sl_lm <- cal_estimate_linear(boosting_predictions_oob, outcome, .by = id, smooth = FALSE) - - expect_identical( - sl_gam$estimate, - sl_lm$estimate - ) -}) - diff --git a/tests/testthat/test-cal-estimate-logistic.R b/tests/testthat/test-cal-estimate-logistic.R index bfa9ac6..39e0bb0 100644 --- a/tests/testthat/test-cal-estimate-logistic.R +++ b/tests/testthat/test-cal-estimate-logistic.R @@ -26,6 +26,11 @@ test_that("Logistic estimates work - data.frame", { expect_cal_estimate(sl_logistic_group, "butchered_glm") expect_cal_rows(sl_logistic_group) expect_snapshot(print(sl_logistic_group)) + expect_equal( + required_pkgs(sl_logistic_group), + "probably" + ) + expect_snapshot_error( segment_logistic |> @@ -42,7 +47,10 @@ test_that("Logistic estimates work - data.frame", { two_cls_res <- cal_apply(two_class_example, two_cls_mod, pred_class = predicted) expect_equal(two_cls_res[0,], two_cls_plist) - + expect_equal( + required_pkgs(two_cls_mod), + c("mgcv", "probably") + ) }) test_that("Logistic estimates work - tune_results", { @@ -99,6 +107,10 @@ test_that("Logistic spline estimates work - tune_results", { expect_cal_method(tl_gam, "Generalized additive model calibration") expect_cal_estimate(tl_gam, "butchered_gam") expect_snapshot(print(tl_gam)) + expect_equal( + required_pkgs(tl_gam), + c("mgcv", "probably") + ) expect_equal( testthat_cal_binary_count(), diff --git a/tests/testthat/test-cal-estimate-multinomial.R b/tests/testthat/test-cal-estimate-multinomial.R index d6ae214..5228f22 100644 --- a/tests/testthat/test-cal-estimate-multinomial.R +++ b/tests/testthat/test-cal-estimate-multinomial.R @@ -7,12 +7,20 @@ test_that("Multinomial estimates work - data.frame", { expect_cal_method(sp_multi, "Multinomial regression calibration") expect_cal_rows(sp_multi, n = 110) expect_snapshot(print(sp_multi)) + expect_equal( + required_pkgs(sp_multi), + c("nnet", "probably") + ) sp_smth_multi <- cal_estimate_multinomial(species_probs, Species, smooth = TRUE) expect_cal_type(sp_smth_multi, "multiclass") expect_cal_method(sp_smth_multi, "Generalized additive model calibration") expect_cal_rows(sp_smth_multi, n = 110) expect_snapshot(print(sp_smth_multi)) + expect_equal( + required_pkgs(sp_smth_multi), + c("mgcv", "probably") + ) sl_multi_group <- species_probs |> dplyr::mutate(group = .pred_bobcat > 0.5) |> diff --git a/tests/testthat/test-cal-estimate.R b/tests/testthat/test-cal-estimate.R deleted file mode 100644 index 67f65fd..0000000 --- a/tests/testthat/test-cal-estimate.R +++ /dev/null @@ -1,681 +0,0 @@ -# --------------------------------- Logistic ----------------------------------- -test_that("Logistic estimates work - data.frame", { - skip_if_not_installed("modeldata") - - sl_logistic <- cal_estimate_logistic(segment_logistic, Class, smooth = FALSE) - expect_cal_type(sl_logistic, "binary") - expect_cal_method(sl_logistic, "Logistic regression calibration") - expect_cal_estimate(sl_logistic, "butchered_glm") - expect_cal_rows(sl_logistic) - expect_snapshot(print(sl_logistic)) - - expect_snapshot_error( - segment_logistic |> cal_estimate_logistic(truth = Class, estimate = .pred_poor) - ) - - data(hpc_cv, package = "yardstick") - expect_snapshot_error( - modeldata::hpc_cv |> cal_estimate_logistic(truth = obs, estimate = c(VF:L)) - ) - - sl_logistic_group <- segment_logistic |> - dplyr::mutate(group = .pred_poor > 0.5) |> - cal_estimate_logistic(Class, .by = group, smooth = FALSE) - - expect_cal_type(sl_logistic_group, "binary") - expect_cal_method(sl_logistic_group, "Logistic regression calibration") - expect_cal_estimate(sl_logistic_group, "butchered_glm") - expect_cal_rows(sl_logistic_group) - expect_snapshot(print(sl_logistic_group)) - expect_equal( - required_pkgs(sl_logistic_group), - "probably" - ) - - expect_snapshot_error( - segment_logistic |> - dplyr::mutate(group1 = 1, group2 = 2) |> - cal_estimate_logistic(Class, .by = c(group1, group2), smooth = FALSE) - ) - - # ------------------------------------------------------------------------------ - - data(two_class_example, package = "modeldata") - two_cls_plist <- two_class_example[0,] - two_cls_mod <- - cal_estimate_logistic(two_class_example, truth = truth, estimate = c(Class1, Class2)) - - two_cls_res <- cal_apply(two_class_example, two_cls_mod, pred_class = predicted) - expect_equal(two_cls_res[0,], two_cls_plist) - expect_equal( - required_pkgs(two_cls_mod), - c("mgcv", "probably") - ) - -}) - -test_that("Logistic estimates work - tune_results", { - skip_if_not_installed("modeldata") - - tl_logistic <- cal_estimate_logistic(testthat_cal_binary(), smooth = FALSE) - expect_cal_type(tl_logistic, "binary") - expect_cal_method(tl_logistic, "Logistic regression calibration") - expect_cal_estimate(tl_logistic, "butchered_glm") - expect_snapshot(print(tl_logistic)) - - expect_snapshot_error( - cal_estimate_logistic(testthat_cal_multiclass(), smooth = FALSE) - ) -}) - -test_that("Logistic estimates errors - grouped_df", { - expect_snapshot_error( - cal_estimate_logistic(dplyr::group_by(mtcars, vs), smooth = FALSE) - ) -}) - -# ----------------------------- Logistic Spline -------------------------------- -test_that("Logistic spline estimates work - data.frame", { - sl_gam <- cal_estimate_logistic(segment_logistic, Class) - expect_cal_type(sl_gam, "binary") - expect_cal_method(sl_gam, "Generalized additive model calibration") - expect_cal_estimate(sl_gam, "butchered_gam") - expect_cal_rows(sl_gam) - expect_snapshot(print(sl_gam)) - - sl_gam_group <- segment_logistic |> - dplyr::mutate(group = .pred_poor > 0.5) |> - cal_estimate_logistic(Class, .by = group) - - expect_cal_type(sl_gam_group, "binary") - expect_cal_method(sl_gam_group, "Generalized additive model calibration") - expect_cal_estimate(sl_gam_group, "butchered_gam") - expect_cal_rows(sl_gam_group) - expect_snapshot(print(sl_gam_group)) - - expect_snapshot_error( - segment_logistic |> - dplyr::mutate(group1 = 1, group2 = 2) |> - cal_estimate_logistic(Class, .by = c(group1, group2)) - ) -}) - -test_that("Logistic spline estimates work - tune_results", { - skip_if_not_installed("modeldata") - - tl_gam <- cal_estimate_logistic(testthat_cal_binary()) - expect_cal_type(tl_gam, "binary") - expect_cal_method(tl_gam, "Generalized additive model calibration") - expect_cal_estimate(tl_gam, "butchered_gam") - expect_snapshot(print(tl_gam)) - expect_equal( - required_pkgs(tl_gam), - c("mgcv", "probably") - ) - - expect_equal( - testthat_cal_binary_count(), - nrow(cal_apply(testthat_cal_binary(), tl_gam)) - ) -}) - -test_that("Logistic spline switches to linear if too few unique", { - skip_if_not_installed("modeldata") - - segment_logistic$.pred_good <- rep( - x = 1, - length.out = nrow(segment_logistic) - ) - - expect_snapshot( - sl_gam <- cal_estimate_logistic(segment_logistic, Class, smooth = TRUE) - ) - sl_lm <- cal_estimate_logistic(segment_logistic, Class, smooth = FALSE) - - expect_identical( - sl_gam$estimates[[1]]$estimate[[1]], - sl_lm$estimates[[1]]$estimate[[1]] - ) - - segment_logistic$id <- rep( - x = 1:2, - length.out = nrow(segment_logistic) - ) - expect_snapshot( - sl_gam <- cal_estimate_logistic(segment_logistic, Class, .by = id, smooth = TRUE) - ) - sl_lm <- cal_estimate_logistic(segment_logistic, Class, .by = id, smooth = FALSE) - - expect_identical( - sl_gam$estimates[[1]]$estimate[[1]], - sl_lm$estimates[[1]]$estimate[[1]] - ) -}) - -# --------------------------------- Isotonic ----------------------------------- -test_that("Isotonic estimates work - data.frame", { - skip_if_not_installed("modeldata") - - set.seed(100) - sl_isotonic <- cal_estimate_isotonic(segment_logistic, Class) - expect_cal_type(sl_isotonic, "binary") - expect_cal_method(sl_isotonic, "Isotonic regression calibration") - expect_cal_rows(sl_isotonic) - expect_snapshot(print(sl_isotonic)) - - set.seed(100) - sl_isotonic_group <- segment_logistic |> - dplyr::mutate(group = .pred_poor > 0.5) |> - cal_estimate_isotonic(Class, .by = group) - - expect_cal_type(sl_isotonic_group, "binary") - expect_cal_method(sl_isotonic_group, "Isotonic regression calibration") - expect_cal_rows(sl_isotonic_group) - expect_snapshot(print(sl_isotonic_group)) - - set.seed(100) - expect_snapshot_error( - segment_logistic |> - dplyr::mutate(group1 = 1, group2 = 2) |> - cal_estimate_isotonic(Class, .by = c(group1, group2)) - ) - -}) - -test_that("Isotonic estimates work - tune_results", { - skip_if_not_installed("modeldata") - - set.seed(100) - tl_isotonic <- cal_estimate_isotonic(testthat_cal_binary()) - expect_cal_type(tl_isotonic, "binary") - expect_cal_method(tl_isotonic, "Isotonic regression calibration") - expect_snapshot(print(tl_isotonic)) - - expect_equal( - testthat_cal_binary_count(), - nrow(cal_apply(testthat_cal_binary(), tl_isotonic)) - ) - - # ------------------------------------------------------------------------------ - # multinomial outcomes - - set.seed(100) - mtnl_isotonic <- cal_estimate_isotonic(testthat_cal_multiclass()) - expect_cal_type(mtnl_isotonic, "one_vs_all") - expect_cal_method(mtnl_isotonic, "Isotonic regression calibration") - expect_snapshot(print(mtnl_isotonic)) - - expect_equal( - testthat_cal_multiclass_count(), - nrow(cal_apply(testthat_cal_multiclass(), mtnl_isotonic)) - ) -}) - -test_that("Isotonic estimates errors - grouped_df", { - expect_snapshot_error( - cal_estimate_isotonic(dplyr::group_by(mtcars, vs)) - ) -}) - -test_that("Isotonic linear estimates work - data.frame", { - skip_if_not_installed("modeldata") - - set.seed(2983) - sl_logistic <- cal_estimate_isotonic(boosting_predictions_oob, outcome, estimate = .pred) - expect_cal_type(sl_logistic, "regression") - expect_cal_method(sl_logistic, "Isotonic regression calibration") - expect_cal_rows(sl_logistic, 2000) - expect_snapshot(print(sl_logistic)) - - set.seed(38) - sl_logistic_group <- boosting_predictions_oob |> - cal_estimate_isotonic(outcome, estimate = .pred, .by = id) - - expect_cal_type(sl_logistic_group, "regression") - expect_cal_method(sl_logistic_group, "Isotonic regression calibration") - expect_cal_rows(sl_logistic_group, 2000) - expect_snapshot(print(sl_logistic_group)) - - expect_snapshot_error( - boosting_predictions_oob |> - dplyr::mutate(group1 = 1, group2 = 2) |> - cal_estimate_isotonic(outcome, estimate = .pred, .by = c(group1, group2)) - ) -}) - -# -------------------------- Isotonic Bootstrapped ----------------------------- -test_that("Isotonic Bootstrapped estimates work - data.frame", { - skip_if_not_installed("modeldata") - - set.seed(1) - sl_boot <- cal_estimate_isotonic_boot(segment_logistic, Class) - expect_cal_type(sl_boot, "binary") - expect_cal_method(sl_boot, "Bootstrapped isotonic regression calibration") - expect_snapshot(print(sl_boot)) - - sl_boot_group <- segment_logistic |> - dplyr::mutate(group = .pred_poor > 0.5) |> - cal_estimate_isotonic_boot(Class, .by = group) - - expect_cal_type(sl_boot_group, "binary") - expect_cal_method(sl_boot_group, "Bootstrapped isotonic regression calibration") - expect_snapshot(print(sl_boot_group)) - - expect_snapshot_error( - segment_logistic |> - dplyr::mutate(group1 = 1, group2 = 2) |> - cal_estimate_isotonic_boot(Class, .by = c(group1, group2)) - ) - -}) - -test_that("Isotonic Bootstrapped estimates work - tune_results", { - skip_if_not_installed("modeldata") - - set.seed(100) - tl_isotonic <- cal_estimate_isotonic_boot(testthat_cal_binary()) - expect_cal_type(tl_isotonic, "binary") - expect_cal_method(tl_isotonic, "Bootstrapped isotonic regression calibration") - expect_snapshot(print(tl_isotonic)) - - expect_equal( - testthat_cal_binary_count(), - nrow(cal_apply(testthat_cal_binary(), tl_isotonic)) - ) - - # ------------------------------------------------------------------------------ - # multinomial outcomes - - set.seed(100) - mtnl_isotonic <- cal_estimate_isotonic_boot(testthat_cal_multiclass()) - expect_cal_type(mtnl_isotonic, "one_vs_all") - expect_cal_method(mtnl_isotonic, "Bootstrapped isotonic regression calibration") - expect_snapshot(print(mtnl_isotonic)) - - expect_equal( - testthat_cal_multiclass_count(), - nrow(cal_apply(testthat_cal_multiclass(), mtnl_isotonic)) - ) -}) - -test_that("Isotonic Bootstrapped estimates errors - grouped_df", { - expect_snapshot_error( - cal_estimate_isotonic_boot(dplyr::group_by(mtcars, vs)) - ) -}) - -# ----------------------------------- Beta ------------------------------------- -test_that("Beta estimates work - data.frame", { - skip_if_not_installed("betacal") - sl_beta <- cal_estimate_beta(segment_logistic, Class, smooth = FALSE) - expect_cal_type(sl_beta, "binary") - expect_cal_method(sl_beta, "Beta calibration") - expect_cal_rows(sl_beta) - expect_snapshot(print(sl_beta)) - - sl_beta_group <- segment_logistic |> - dplyr::mutate(group = .pred_poor > 0.5) |> - cal_estimate_beta(Class, smooth = FALSE, .by = group) - - expect_cal_type(sl_beta_group, "binary") - expect_cal_method(sl_beta_group, "Beta calibration") - expect_cal_rows(sl_beta_group) - expect_snapshot(print(sl_beta_group)) - - expect_snapshot_error( - segment_logistic |> - dplyr::mutate(group1 = 1, group2 = 2) |> - cal_estimate_beta(Class, smooth = FALSE, .by = c(group1, group2)) - ) - -}) - -test_that("Beta estimates work - tune_results", { - skip_if_not_installed("betacal") - skip_if_not_installed("modeldata") - - tl_beta <- cal_estimate_beta(testthat_cal_binary()) - expect_cal_type(tl_beta, "binary") - expect_cal_method(tl_beta, "Beta calibration") - expect_snapshot(print(tl_beta)) - - expect_equal( - testthat_cal_binary_count(), - nrow(cal_apply(testthat_cal_binary(), tl_beta)) - ) - - # ------------------------------------------------------------------------------ - # multinomial outcomes - - set.seed(100) - suppressWarnings( - mtnl_beta <- cal_estimate_beta(testthat_cal_multiclass()) - ) - expect_cal_type(mtnl_beta, "one_vs_all") - expect_cal_method(mtnl_beta, "Beta calibration") - expect_snapshot(print(mtnl_beta)) - - expect_equal( - testthat_cal_multiclass_count(), - nrow(cal_apply(testthat_cal_multiclass(), mtnl_beta)) - ) -}) - -test_that("Beta estimates errors - grouped_df", { - skip_if_not_installed("betacal") - expect_snapshot_error( - cal_estimate_beta(dplyr::group_by(mtcars, vs)) - ) -}) - -# ------------------------------ Multinomial ----------------------------------- -test_that("Multinomial estimates work - data.frame", { - skip_if_not_installed("modeldata") - skip_if_not_installed("nnet") - - sp_multi <- cal_estimate_multinomial(species_probs, Species, smooth = FALSE) - expect_cal_type(sp_multi, "multiclass") - expect_cal_method(sp_multi, "Multinomial regression calibration") - expect_cal_rows(sp_multi, n = 110) - expect_snapshot(print(sp_multi)) - expect_equal( - required_pkgs(sp_multi), - c("nnet", "probably") - ) - - sp_smth_multi <- cal_estimate_multinomial(species_probs, Species, smooth = TRUE) - expect_cal_type(sp_smth_multi, "multiclass") - expect_cal_method(sp_smth_multi, "Generalized additive model calibration") - expect_cal_rows(sp_smth_multi, n = 110) - expect_snapshot(print(sp_smth_multi)) - expect_equal( - required_pkgs(sp_smth_multi), - c("mgcv", "probably") - ) - - sl_multi_group <- species_probs |> - dplyr::mutate(group = .pred_bobcat > 0.5) |> - cal_estimate_multinomial(Species, smooth = FALSE, .by = group) - - expect_cal_type(sl_multi_group, "multiclass") - expect_cal_method(sl_multi_group, "Multinomial regression calibration") - expect_cal_rows(sl_multi_group, n = 110) - expect_snapshot(print(sl_multi_group)) - - expect_snapshot_error( - species_probs |> - dplyr::mutate(group1 = 1, group2 = 2) |> - cal_estimate_multinomial(Species, smooth = FALSE, .by = c(group1, group2)) - ) - - mltm_configs <- - mnl_with_configs() |> - cal_estimate_multinomial(truth = obs, estimate = c(VF:L), smooth = FALSE) -}) - -test_that("Multinomial estimates work - tune_results", { - skip_if_not_installed("modeldata") - skip_if_not_installed("nnet") - - tl_multi <- cal_estimate_multinomial(testthat_cal_multiclass(), smooth = FALSE) - expect_cal_type(tl_multi, "multiclass") - expect_cal_method(tl_multi, "Multinomial regression calibration") - expect_snapshot(print(tl_multi)) - - expect_equal( - testthat_cal_multiclass() |> - tune::collect_predictions(summarize = TRUE) |> - nrow(), - testthat_cal_multiclass() |> - cal_apply(tl_multi) |> - nrow() - ) - - tl_smth_multi <- cal_estimate_multinomial(testthat_cal_multiclass(), smooth = TRUE) - expect_cal_type(tl_smth_multi, "multiclass") - expect_cal_method(tl_smth_multi, "Generalized additive model calibration") - expect_snapshot(print(tl_smth_multi)) - - expect_equal( - testthat_cal_multiclass() |> - tune::collect_predictions(summarize = TRUE) |> - nrow(), - testthat_cal_multiclass() |> - cal_apply(tl_smth_multi) |> - nrow() - ) -}) - -test_that("Multinomial estimates errors - grouped_df", { - skip_if_not_installed("modeldata") - skip_if_not_installed("nnet") - - expect_snapshot_error( - cal_estimate_multinomial(dplyr::group_by(mtcars, vs)) - ) -}) - -test_that("Passing a binary outcome causes error", { - expect_error( - cal_estimate_multinomial(segment_logistic, Class) - ) -}) - -test_that("Linear spline switches to linear if too few unique", { - skip_if_not_installed("modeldata") - - boosting_predictions_oob$.pred <- rep( - x = 1:5, - length.out = nrow(boosting_predictions_oob) - ) - - expect_snapshot( - sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome, smooth = TRUE) - ) - sl_lm <- cal_estimate_linear(boosting_predictions_oob, outcome, smooth = FALSE) - - expect_identical( - sl_gam$estimate, - sl_lm$estimate - ) - - expect_snapshot( - sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome, .by = id, smooth = TRUE) - ) - sl_lm <- cal_estimate_linear(boosting_predictions_oob, outcome, .by = id, smooth = FALSE) - - expect_identical( - sl_gam$estimate, - sl_lm$estimate - ) -}) - -test_that("Multinomial spline switches to linear if too few unique", { - skip_if_not_installed("modeldata") - - smol_species_probs <- - species_probs |> - dplyr::slice_head(n = 2, by = Species) - - expect_snapshot( - sl_gam <- cal_estimate_multinomial(smol_species_probs, Species, smooth = TRUE) - ) - sl_glm <- cal_estimate_multinomial(smol_species_probs, Species, smooth = FALSE) - - expect_identical( - sl_gam$estimates, - sl_glm$estimates - ) - - smol_by_species_probs <- - species_probs |> - dplyr::slice_head(n = 4, by = Species) |> - dplyr::mutate(id = rep(1:2, 6)) - - expect_snapshot( - sl_gam <- cal_estimate_multinomial(smol_by_species_probs, Species, .by = id, smooth = TRUE) - ) - sl_glm <- cal_estimate_multinomial(smol_by_species_probs, Species, .by = id, smooth = FALSE) - - expect_identical( - sl_gam$estimates, - sl_glm$estimates - ) -}) - -# --------------------------------- Linear ----------------------------------- -test_that("Linear estimates work - data.frame", { - skip_if_not_installed("modeldata") - - sl_linear <- cal_estimate_linear(boosting_predictions_oob, outcome, smooth = FALSE) - expect_cal_type(sl_linear, "regression") - expect_cal_method(sl_linear, "Linear calibration") - expect_cal_estimate(sl_linear, "butchered_glm") - expect_cal_rows(sl_linear, 2000) - expect_snapshot(print(sl_linear)) - expect_equal( - required_pkgs(sl_linear), - c("probably") - ) - - sl_linear_group <- boosting_predictions_oob |> - dplyr::mutate(group = .pred > 0.5) |> - cal_estimate_linear(outcome, smooth = FALSE, .by = group) - - expect_cal_type(sl_linear_group, "regression") - expect_cal_method(sl_linear_group, "Linear calibration") - expect_cal_estimate(sl_linear_group, "butchered_glm") - expect_cal_rows(sl_linear_group, 2000) - expect_snapshot(print(sl_linear_group)) - - expect_snapshot_error( - boosting_predictions_oob |> - dplyr::mutate(group1 = 1, group2 = 2) |> - cal_estimate_linear(outcome, smooth = FALSE, .by = c(group1, group2)) - ) - -}) - -test_that("Linear estimates work - tune_results", { - tl_linear <- cal_estimate_linear(testthat_cal_reg(), outcome, smooth = FALSE) - expect_cal_type(tl_linear, "regression") - expect_cal_method(tl_linear, "Linear calibration") - expect_cal_estimate(tl_linear, "butchered_glm") - expect_snapshot(print(tl_linear)) - -}) - -test_that("Linear estimates errors - grouped_df", { - expect_snapshot_error( - cal_estimate_linear(dplyr::group_by(mtcars, vs)) - ) -}) - -# ----------------------------- Linear Spline -------------------------------- -test_that("Linear spline estimates work - data.frame", { - skip_if_not_installed("modeldata") - - sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome) - expect_cal_type(sl_gam, "regression") - expect_cal_method(sl_gam, "Generalized additive model calibration") - expect_cal_estimate(sl_gam, "butchered_gam") - expect_cal_rows(sl_gam, 2000) - expect_snapshot(print(sl_gam)) - expect_equal( - required_pkgs(sl_gam), - c("mgcv", "probably") - ) - - sl_gam_group <- boosting_predictions_oob |> - dplyr::mutate(group = .pred > 0.5) |> - cal_estimate_linear(outcome, .by = group) - - expect_cal_type(sl_gam_group, "regression") - expect_cal_method(sl_gam_group, "Generalized additive model calibration") - expect_cal_estimate(sl_gam_group, "butchered_gam") - expect_cal_rows(sl_gam_group, 2000) - expect_snapshot(print(sl_gam_group)) - - expect_snapshot_error( - boosting_predictions_oob |> - dplyr::mutate(group1 = 1, group2 = 2) |> - cal_estimate_linear(outcome, .by = c(group1, group2)) - ) -}) - -test_that("Linear spline estimates work - tune_results", { - tl_gam <- cal_estimate_linear(testthat_cal_reg(), outcome) - expect_cal_type(tl_gam, "regression") - expect_cal_method(tl_gam, "Generalized additive model calibration") - expect_cal_estimate(tl_gam, "butchered_gam") - expect_snapshot(print(tl_gam)) - - expect_equal( - testthat_cal_reg_count(), - nrow(cal_apply(testthat_cal_reg(), tl_gam)) - ) -}) - -test_that("Linear spline switches to linear if too few unique", { - skip_if_not_installed("modeldata") - - boosting_predictions_oob$.pred <- rep( - x = 1:5, - length.out = nrow(boosting_predictions_oob) - ) - - expect_snapshot( - sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome, smooth = TRUE) - ) - sl_lm <- cal_estimate_linear(boosting_predictions_oob, outcome, smooth = FALSE) - - expect_identical( - sl_gam$estimates[[1]]$estimate[[1]], - sl_lm$estimates[[1]]$estimate[[1]] - ) - - expect_snapshot( - sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome, .by = id, smooth = TRUE) - ) - sl_lm <- cal_estimate_linear(boosting_predictions_oob, outcome, .by = id, smooth = FALSE) - - expect_identical( - sl_gam$estimates[[1]]$estimate[[1]], - sl_lm$estimates[[1]]$estimate[[1]] - ) -}) - -# ----------------------------------- Other ------------------------------------ -test_that("Non-default names used for estimate columns", { - skip_if_not_installed("modeldata") - - new_segment <- segment_logistic - colnames(new_segment) <- c("poor", "good", "Class") - - set.seed(100) - expect_snapshot( - cal_estimate_isotonic(new_segment, Class, c(good, poor)) - ) -}) - -test_that("Test exceptions", { - expect_error( - cal_estimate_isotonic(segment_logistic, Class, dplyr::starts_with("bad")) - ) -}) - -test_that("non-standard column names", { - library(dplyr) - # issue 145 - seg <- segment_logistic |> - rename_with(~ paste0(.x, "-1"), matches(".pred")) |> - mutate( - Class = paste0(Class,"-1"), - Class = factor(Class), - .pred_class = ifelse(`.pred_poor-1` >= 0.5, "poor-1", "good-1") - ) - calib <- cal_estimate_isotonic(seg, Class) - new_pred <- cal_apply(seg, calib, pred_class = .pred_class) - expect_named(new_pred, c(".pred_poor-1", ".pred_good-1", "Class", ".pred_class")) - -}) From 36b84a2bc034de56c895c4c9d5c960fb3492d368 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 8 Oct 2025 11:32:48 -0700 Subject: [PATCH 4/6] update gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 5e66fed..3afeb66 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ inst/doc .RData .DS_Store docs +tests/testthat/Rplots.pdf From 192ac07803b471e620b77a0afd2c70924234f3ca Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 8 Oct 2025 11:33:01 -0700 Subject: [PATCH 5/6] split up test-cal-plot.R --- tests/testthat/_snaps/cal-plot-breaks.md | 58 ++ tests/testthat/_snaps/cal-plot-logistic.md | 50 ++ .../{cal-plot.md => cal-plot-regression.md} | 113 --- tests/testthat/_snaps/cal-plot-windowed.md | 5 + tests/testthat/test-cal-plot-breaks.R | 212 ++++++ tests/testthat/test-cal-plot-logistic.R | 180 +++++ tests/testthat/test-cal-plot-regression.R | 148 ++++ tests/testthat/test-cal-plot-windowed.R | 168 +++++ tests/testthat/test-cal-plot.R | 664 ------------------ 9 files changed, 821 insertions(+), 777 deletions(-) create mode 100644 tests/testthat/_snaps/cal-plot-breaks.md create mode 100644 tests/testthat/_snaps/cal-plot-logistic.md rename tests/testthat/_snaps/{cal-plot.md => cal-plot-regression.md} (51%) create mode 100644 tests/testthat/_snaps/cal-plot-windowed.md create mode 100644 tests/testthat/test-cal-plot-breaks.R create mode 100644 tests/testthat/test-cal-plot-logistic.R create mode 100644 tests/testthat/test-cal-plot-regression.R create mode 100644 tests/testthat/test-cal-plot-windowed.R delete mode 100644 tests/testthat/test-cal-plot.R diff --git a/tests/testthat/_snaps/cal-plot-breaks.md b/tests/testthat/_snaps/cal-plot-breaks.md new file mode 100644 index 0000000..9ebc883 --- /dev/null +++ b/tests/testthat/_snaps/cal-plot-breaks.md @@ -0,0 +1,58 @@ +# Binary breaks functions work with group argument + + Code + get_labs(res) + Output + $colour + [1] "id" + + $fill + [1] "id" + + $x.sec + NULL + + $x + [1] "Bin Midpoint" + + $y + [1] "Event Rate" + + $y.sec + NULL + + $intercept + [1] "intercept" + + $slope + [1] "slope" + + $ymin + [1] "lower" + + $ymax + [1] "upper" + + $alt + [1] "" + + +--- + + x `.by` cannot select more than one column. + i The following 2 columns were selected: + i group1 and group2 + +# breaks plot function errors - grouped_df + + x This function does not work with grouped data frames. + i Apply `dplyr::ungroup()` and use the `.by` argument. + +# Event level handling works + + i In argument: `res = map(...)`. + Caused by error in `map()`: + i In index: 1. + Caused by error: + ! Invalid `event_level` entry: invalid. Valid entries are "first", "second", or "auto". + diff --git a/tests/testthat/_snaps/cal-plot-logistic.md b/tests/testthat/_snaps/cal-plot-logistic.md new file mode 100644 index 0000000..f36bca6 --- /dev/null +++ b/tests/testthat/_snaps/cal-plot-logistic.md @@ -0,0 +1,50 @@ +# Binary logistic functions work with group argument + + Code + get_labs(res) + Output + $colour + [1] "id" + + $fill + [1] "id" + + $x.sec + NULL + + $x + [1] "Probability" + + $y + [1] "Predicted Event Rate" + + $y.sec + NULL + + $intercept + [1] "intercept" + + $slope + [1] "slope" + + $ymin + [1] "lower" + + $ymax + [1] "upper" + + $alt + [1] "" + + +--- + + x `.by` cannot select more than one column. + i The following 2 columns were selected: + i group1 and group2 + +# logistic plot function errors - grouped_df + + x This function does not work with grouped data frames. + i Apply `dplyr::ungroup()` and use the `.by` argument. + diff --git a/tests/testthat/_snaps/cal-plot.md b/tests/testthat/_snaps/cal-plot-regression.md similarity index 51% rename from tests/testthat/_snaps/cal-plot.md rename to tests/testthat/_snaps/cal-plot-regression.md index c0fba43..76cfad8 100644 --- a/tests/testthat/_snaps/cal-plot.md +++ b/tests/testthat/_snaps/cal-plot-regression.md @@ -1,116 +1,3 @@ -# Binary breaks functions work with group argument - - Code - get_labs(res) - Output - $colour - [1] "id" - - $fill - [1] "id" - - $x.sec - NULL - - $x - [1] "Bin Midpoint" - - $y - [1] "Event Rate" - - $y.sec - NULL - - $intercept - [1] "intercept" - - $slope - [1] "slope" - - $ymin - [1] "lower" - - $ymax - [1] "upper" - - $alt - [1] "" - - ---- - - x `.by` cannot select more than one column. - i The following 2 columns were selected: - i group1 and group2 - -# breaks plot function errors - grouped_df - - x This function does not work with grouped data frames. - i Apply `dplyr::ungroup()` and use the `.by` argument. - -# Binary logistic functions work with group argument - - Code - get_labs(res) - Output - $colour - [1] "id" - - $fill - [1] "id" - - $x.sec - NULL - - $x - [1] "Probability" - - $y - [1] "Predicted Event Rate" - - $y.sec - NULL - - $intercept - [1] "intercept" - - $slope - [1] "slope" - - $ymin - [1] "lower" - - $ymax - [1] "upper" - - $alt - [1] "" - - ---- - - x `.by` cannot select more than one column. - i The following 2 columns were selected: - i group1 and group2 - -# logistic plot function errors - grouped_df - - x This function does not work with grouped data frames. - i Apply `dplyr::ungroup()` and use the `.by` argument. - -# windowed plot function errors - grouped_df - - x This function does not work with grouped data frames. - i Apply `dplyr::ungroup()` and use the `.by` argument. - -# Event level handling works - - i In argument: `res = map(...)`. - Caused by error in `map()`: - i In index: 1. - Caused by error: - ! Invalid `event_level` entry: invalid. Valid entries are "first", "second", or "auto". - # regression functions work Code diff --git a/tests/testthat/_snaps/cal-plot-windowed.md b/tests/testthat/_snaps/cal-plot-windowed.md new file mode 100644 index 0000000..743fcf4 --- /dev/null +++ b/tests/testthat/_snaps/cal-plot-windowed.md @@ -0,0 +1,5 @@ +# windowed plot function errors - grouped_df + + x This function does not work with grouped data frames. + i Apply `dplyr::ungroup()` and use the `.by` argument. + diff --git a/tests/testthat/test-cal-plot-breaks.R b/tests/testthat/test-cal-plot-breaks.R new file mode 100644 index 0000000..d2ed6e0 --- /dev/null +++ b/tests/testthat/test-cal-plot-breaks.R @@ -0,0 +1,212 @@ +test_that("Binary breaks functions work", { + x10 <- .cal_table_breaks(segment_logistic, Class, .pred_good, event_level = "first") + + expect_equal( + x10$predicted_midpoint, + seq(0.05, 0.95, by = 0.10) + ) + + expect_s3_class( + cal_plot_breaks(segment_logistic, Class, .pred_good), + "ggplot" + ) + + x11 <- .cal_table_breaks(testthat_cal_binary()) + + expect_equal( + x11$predicted_midpoint, + rep(seq(0.05, 0.95, by = 0.10), times = 8) + ) + + expect_s3_class( + cal_plot_breaks(testthat_cal_binary()), + "ggplot" + ) + + brks_configs <- + bin_with_configs() |> cal_plot_breaks(truth = Class, estimate = .pred_good) + expect_true(has_facet(brks_configs)) +}) + + +test_that("Binary breaks functions work with group argument", { + skip_if_not_installed("ggplot2", minimum_version = "3.5.2.9000") + res <- segment_logistic |> + dplyr::mutate(id = dplyr::row_number() %% 2) |> + cal_plot_breaks(Class, .pred_good, .by = id) + + expect_s3_class(res, "ggplot") + + expect_equal( + res$data[0,], + dplyr::tibble( + id = factor(0, levels = paste(0:1)), + predicted_midpoint = double(), event_rate = double(), events = double(), + total = integer(), lower = double(), upper = double() + ) + ) + + expect_equal( + rlang::expr_text(res$mapping$x), + "~predicted_midpoint" + ) + expect_equal( + rlang::expr_text(res$mapping$colour), + "~id" + ) + expect_equal( + rlang::expr_text(res$mapping$fill), + "~id" + ) + + expect_snapshot(get_labs(res)) + + expect_equal(length(res$layers), 4) + + expect_snapshot_error( + segment_logistic |> + dplyr::mutate(group1 = 1, group2 = 2) |> + cal_plot_breaks(Class, .pred_good, .by = c(group1, group2)) + ) +}) + +test_that("Multi-class breaks functions work", { + skip_if_not_installed("modeldata") + + x10 <- .cal_table_breaks(species_probs, Species, dplyr::starts_with(".pred")) + + expect_equal( + x10$predicted_midpoint, + rep(seq(0.05, 0.95, by = 0.10), times = 3) + ) + + expect_s3_class( + cal_plot_breaks(species_probs, Species), + "ggplot" + ) + + x11 <- .cal_table_breaks(testthat_cal_multiclass()) + + expect_equal( + sort(unique(x11$predicted_midpoint)), + seq(0.05, 0.95, by = 0.10) + ) + + multi_configs <- cal_plot_breaks(testthat_cal_multiclass()) + # should be faceted by .config and class + expect_s3_class(multi_configs, "ggplot") + expect_true(inherits(multi_configs$facet, "FacetGrid")) + + expect_error( + cal_plot_breaks(species_probs, Species, event_level = "second") + ) + + # ------------------------------------------------------------------------------ + # multinomial outcome, binary logistic plots + + multi_configs_from_tune <- + testthat_cal_multiclass() |> cal_plot_breaks() + expect_s3_class(multi_configs_from_tune, "ggplot") + # should be faceted by .config and class + expect_true(inherits(multi_configs_from_tune$facet, "FacetGrid")) + + multi_configs_from_df <- + mnl_with_configs() |> cal_plot_breaks(truth = obs, estimate = c(VF:L)) + expect_s3_class(multi_configs_from_df, "ggplot") + # should be faceted by .config and class + expect_true(inherits(multi_configs_from_df$facet, "FacetGrid")) +}) + + +test_that("breaks plot function errors - grouped_df", { + expect_snapshot_error( + cal_plot_breaks(dplyr::group_by(mtcars, vs)) + ) +}) + +test_that("Numeric groups are supported", { + grp_df <- segment_logistic + grp_df$num_group <- rep(c(1, 2), times = 505) + + p <- grp_df |> + cal_plot_breaks(Class, .pred_good, .by = num_group) + + expect_s3_class(p, "ggplot") +}) + +test_that("Some general exceptions", { + expect_error( + .cal_table_breaks(tune::ames_grid_search), + "The `tune_results` object does not contain columns with predictions" + ) + expect_warning( + cal_plot_breaks(segment_logistic, Class), + ) +}) + +test_that("don't facet if there is only one .config", { + class_data <- testthat_cal_binary() + + class_data$.predictions <- lapply( + class_data$.predictions, + function(x) dplyr::filter(x, .config == "Preprocessor1_Model1") + ) + + res_breaks <- cal_plot_breaks(class_data) + + expect_null(res_breaks$data[[".config"]]) + expect_s3_class(res_breaks, "ggplot") +}) + +test_that("custom names for cal_plot_breaks()", { + data(segment_logistic) + segment_logistic_1 <- dplyr::rename(segment_logistic, good_prob = .pred_good) + p <- cal_plot_breaks(segment_logistic_1, Class, good_prob) + expect_s3_class(p, "ggplot") +}) + +test_that("Event level handling works", { + x7 <- .cal_table_breaks(segment_logistic, Class, .pred_good, event_level = "second") + expect_equal( + which(x7$predicted_midpoint == min(x7$predicted_midpoint)), + which(x7$event_rate == max(x7$event_rate)) + ) + + expect_snapshot_error( + .cal_table_breaks(segment_logistic, Class, .pred_good, event_level = "invalid") + ) +}) + +test_that("Groupings that may not match work", { + model <- glm(Class ~ .pred_good, segment_logistic, family = "binomial") + + preds <- 1 - predict(model, segment_logistic, type = "response") + + combined <- dplyr::bind_rows( + dplyr::mutate(segment_logistic, source = "original"), + dplyr::mutate(segment_logistic, .pred_good = preds, source = "glm") + ) + + x50 <- combined |> + dplyr::group_by(source) |> + .cal_table_breaks(Class, .pred_good) + + expect_equal( + unique(x50$predicted_midpoint), + seq(0.05, 0.95, by = 0.10) + ) +}) + +test_that("Groups are respected", { + preds <- segment_logistic |> + dplyr::mutate(source = "logistic") |> + dplyr::bind_rows(segment_naive_bayes) |> + dplyr::mutate(source = ifelse(is.na(source), "nb", source)) |> + dplyr::group_by(source) + + x40 <- .cal_table_breaks(preds, Class, .pred_good) + + expect_equal(as.integer(table(x40$source)), c(10, 10)) + + expect_equal(unique(x40$source), c("logistic", "nb")) +}) diff --git a/tests/testthat/test-cal-plot-logistic.R b/tests/testthat/test-cal-plot-logistic.R new file mode 100644 index 0000000..394cb53 --- /dev/null +++ b/tests/testthat/test-cal-plot-logistic.R @@ -0,0 +1,180 @@ +test_that("Binary logistic functions work", { + skip_if_not_installed("modeldata") + + x20 <- .cal_table_logistic(segment_logistic, Class, .pred_good) + + model20 <- mgcv::gam(Class ~ s(.pred_good, k = 10), + data = segment_logistic, + family = binomial() + ) + + preds20 <- predict(model20, + data.frame(.pred_good = seq(0, 1, by = .01)), + type = "response" + ) + + expect_equal(sd(x20$prob), sd(preds20), tolerance = 0.000001) + expect_equal(mean(x20$prob), mean(1 - preds20), tolerance = 0.000001) + + x21 <- cal_plot_logistic(segment_logistic, Class, .pred_good) + + expect_s3_class(x21, "ggplot") + expect_false(has_facet(x21)) + + x22 <- .cal_table_logistic(testthat_cal_binary()) + + + x22_1 <- testthat_cal_binary() |> + tune::collect_predictions(summarize = TRUE) |> + dplyr::group_by(.config) |> + dplyr::group_map(~ { + model <- mgcv::gam( + class ~ s(.pred_class_1, k = 10), + data = .x, + family = binomial() + ) + preds <- predict(model, + data.frame(.pred_class_1 = seq(0, 1, by = .01)), + type = "response" + ) + 1 - preds + }) |> + purrr::reduce(c) + + expect_equal(sd(x22$prob), sd(x22_1), tolerance = 0.000001) + expect_equal(mean(x22$prob), mean(x22_1), tolerance = 0.000001) + + x23 <- cal_plot_logistic(testthat_cal_binary()) + + expect_s3_class(x23, "ggplot") + expect_true(has_facet(x23)) + + x24 <- .cal_table_logistic(segment_logistic, Class, .pred_good, smooth = FALSE) + + model24 <- stats::glm(Class ~ .pred_good, data = segment_logistic, family = binomial()) + + preds24 <- predict(model24, + data.frame(.pred_good = seq(0, 1, by = .01)), + type = "response" + ) + + expect_equal(sd(x24$prob), sd(preds24), tolerance = 0.000001) + expect_equal(mean(x24$prob), mean(1 - preds24), tolerance = 0.000001) + + x25 <- .cal_table_logistic( + segment_logistic, + Class, + .pred_poor, + event_level = "second" + ) + + expect_equal( + which(x25$prob == max(x25$prob)), + nrow(x25) + ) + + lgst_configs <- + bin_with_configs() |> cal_plot_logistic(truth = Class, estimate = .pred_good) + expect_true(has_facet(lgst_configs)) + + # ------------------------------------------------------------------------------ + # multinomial outcome, binary logistic plots + + multi_configs_from_tune <- + testthat_cal_multiclass() |> cal_plot_logistic(smooth = FALSE) + expect_s3_class(multi_configs_from_tune, "ggplot") + # should be faceted by .config and class + expect_true(inherits(multi_configs_from_tune$facet, "FacetGrid")) + + + multi_configs_from_df <- + mnl_with_configs() |> cal_plot_logistic(truth = obs, estimate = c(VF:L)) + expect_s3_class(multi_configs_from_df, "ggplot") + # should be faceted by .config and class + expect_true(inherits(multi_configs_from_df$facet, "FacetGrid")) +}) + +test_that("Binary logistic functions work with group argument", { + skip_if_not_installed("ggplot2", minimum_version = "3.5.2.9000") + res <- segment_logistic |> + dplyr::mutate(id = dplyr::row_number() %% 2) |> + cal_plot_logistic(Class, .pred_good, .by = id) + + expect_s3_class( + res, + "ggplot" + ) + expect_true(has_facet(res)) + + expect_s3_class(res, "ggplot") + + expect_equal( + res$data[0,], + dplyr::tibble( + id = factor(0, levels = paste(0:1)), + estimate = double(), prob = double(), lower = double(), upper = double() + ) + ) + + expect_equal( + rlang::expr_text(res$mapping$x), + "~estimate" + ) + expect_equal( + rlang::expr_text(res$mapping$colour), + "~id" + ) + expect_equal( + rlang::expr_text(res$mapping$fill), + "~id" + ) + + expect_snapshot(get_labs(res)) + + expect_equal(length(res$layers), 3) + + expect_snapshot_error( + segment_logistic |> + dplyr::mutate(group1 = 1, group2 = 2) |> + cal_plot_logistic(Class, .pred_good, .by = c(group1, group2)) + ) + + lgst_configs <- + bin_with_configs() |> cal_plot_logistic(truth = Class, estimate = .pred_good) + expect_true(has_facet(lgst_configs)) +}) + +test_that("logistic plot function errors - grouped_df", { + expect_snapshot_error( + cal_plot_logistic(dplyr::group_by(mtcars, vs)) + ) +}) + +test_that("don't facet if there is only one .config", { + class_data <- testthat_cal_binary() + + class_data$.predictions <- lapply( + class_data$.predictions, + function(x) dplyr::filter(x, .config == "Preprocessor1_Model1") + ) + + res_logistic <- cal_plot_logistic(class_data) + + expect_null(res_logistic$data[[".config"]]) + expect_s3_class(res_logistic, "ggplot") +}) + + +test_that("Groups are respected", { + preds <- segment_logistic |> + dplyr::mutate(source = "logistic") |> + dplyr::bind_rows(segment_naive_bayes) |> + dplyr::mutate(source = ifelse(is.na(source), "nb", source)) |> + dplyr::group_by(source) + + x41 <- .cal_table_logistic(preds, Class, .pred_good) + + expect_equal(as.integer(table(x41$source)), c(101, 101)) + + expect_equal(unique(x41$source), c("logistic", "nb")) +}) diff --git a/tests/testthat/test-cal-plot-regression.R b/tests/testthat/test-cal-plot-regression.R new file mode 100644 index 0000000..cc49aa6 --- /dev/null +++ b/tests/testthat/test-cal-plot-regression.R @@ -0,0 +1,148 @@ +test_that("regression functions work", { + skip_if_not_installed("ggplot2", minimum_version = "3.5.2.9000") + skip_if(R.version[["arch"]] != "aarch64") # see note below + + obj <- testthat_cal_reg() + + res <- cal_plot_regression(boosting_predictions_oob, outcome, .pred) + expect_s3_class(res, "ggplot") + + expect_equal( + res$data[0,], + dplyr::tibble(outcome = numeric(0), .pred = numeric(0), id = character(0)) + ) + + expect_equal( + rlang::expr_text(res$mapping$x), + "~outcome" + ) + expect_equal( + rlang::expr_text(res$mapping$y), + "~.pred" + ) + expect_null(res$mapping$colour) + expect_null(res$mapping$fill) + + expect_snapshot(get_labs(res)) + + expect_equal(length(res$layers), 3) + + res <- cal_plot_regression(boosting_predictions_oob, outcome, .pred, .by = id) + expect_s3_class(res, "ggplot") + + expect_equal( + res$data[0,], + dplyr::tibble(outcome = numeric(0), .pred = numeric(0), id = character(0)) + ) + + expect_equal( + rlang::expr_text(res$mapping$x), + "~outcome" + ) + expect_equal( + rlang::expr_text(res$mapping$y), + "~.pred" + ) + expect_null(res$mapping$colour) + expect_null(res$mapping$fill) + + expect_snapshot(get_labs(res)) + + expect_equal(length(res$layers), 3) + + res <- cal_plot_regression(obj) + expect_s3_class(res, "ggplot") + + skip_if_not_installed("tune", "1.2.0") + expect_equal( + res$data[0,], + dplyr::tibble(.pred = numeric(0), .row = numeric(0), + predictor_01 = integer(0), outcome = numeric(0), + .config = character()) + ) + + expect_equal( + rlang::expr_text(res$mapping$x), + "~outcome" + ) + expect_equal( + rlang::expr_text(res$mapping$y), + "~.pred" + ) + expect_null(res$mapping$colour) + expect_null(res$mapping$fill) + + expect_snapshot(get_labs(res)) + + expect_equal(length(res$layers), 3) + + res <- print(cal_plot_regression(obj), alpha = 1 / 5, smooth = FALSE) + expect_s3_class(res, "ggplot") + + skip_if_not_installed("tune", "1.2.0") + expect_equal( + res$data[0,], + dplyr::tibble(.pred = numeric(0), .row = numeric(0), + predictor_01 = integer(0), outcome = numeric(0), + .config = character()) + ) + + expect_equal( + rlang::expr_text(res$mapping$x), + "~outcome" + ) + expect_equal( + rlang::expr_text(res$mapping$y), + "~.pred" + ) + expect_null(res$mapping$colour) + expect_null(res$mapping$fill) + + expect_snapshot(get_labs(res)) + + expect_equal(length(res$layers), 3) + + res <- cal_plot_regression(boosting_predictions_oob, outcome, .pred, smooth = FALSE) + expect_s3_class(res, "ggplot") + + expect_equal( + res$data[0,], + dplyr::tibble(outcome = numeric(0), .pred = numeric(0), + id = character()) + ) + + expect_equal( + rlang::expr_text(res$mapping$x), + "~outcome" + ) + expect_equal( + rlang::expr_text(res$mapping$y), + "~.pred" + ) + expect_null(res$mapping$colour) + expect_null(res$mapping$fill) + + expect_snapshot(get_labs(res)) + + expect_equal(length(res$layers), 3) +}) + +test_that("regression plot function errors - grouped_df", { + expect_snapshot_error( + cal_plot_regression(dplyr::group_by(mtcars, vs)) + ) +}) + +test_that("don't facet if there is only one .config", { + reg_data <- testthat_cal_reg() + + reg_data$.predictions <- lapply( + reg_data$.predictions, + function(x) dplyr::filter(x, .config == "Preprocessor01_Model1") + ) + + res_regression <- cal_plot_regression(reg_data) + + expect_null(res_regression$data[[".config"]]) + expect_s3_class(res_regression, "ggplot") +}) diff --git a/tests/testthat/test-cal-plot-windowed.R b/tests/testthat/test-cal-plot-windowed.R new file mode 100644 index 0000000..ce77ac1 --- /dev/null +++ b/tests/testthat/test-cal-plot-windowed.R @@ -0,0 +1,168 @@ + +test_that("Binary windowed functions work", { + skip_if_not_installed("modeldata") + + x30 <- .cal_table_windowed( + segment_logistic, + truth = Class, + estimate = .pred_good, + step_size = 0.11, + window_size = 0.10 + ) + + x30_1 <- segment_logistic |> + dplyr::mutate(x = dplyr::case_when( + .pred_good <= 0.05 ~ 1, + .pred_good >= 0.06 & .pred_good <= 0.16 ~ 2, + .pred_good >= 0.17 & .pred_good <= 0.27 ~ 3, + .pred_good >= 0.28 & .pred_good <= 0.38 ~ 4, + .pred_good >= 0.39 & .pred_good <= 0.49 ~ 5, + .pred_good >= 0.50 & .pred_good <= 0.60 ~ 6, + .pred_good >= 0.61 & .pred_good <= 0.71 ~ 7, + .pred_good >= 0.72 & .pred_good <= 0.82 ~ 8, + .pred_good >= 0.83 & .pred_good <= 0.93 ~ 9, + .pred_good >= 0.94 & .pred_good <= 1 ~ 10, + )) |> + dplyr::filter(!is.na(x)) |> + dplyr::count(x) + + expect_equal( + x30$total, + x30_1$n + ) + + x31 <- cal_plot_windowed(segment_logistic, Class, .pred_good) + + expect_s3_class(x31, "ggplot") + expect_false(has_facet(x31)) + + x32 <- .cal_table_windowed( + testthat_cal_binary(), + step_size = 0.11, + window_size = 0.10 + ) + + x32_1 <- testthat_cal_binary() |> + tune::collect_predictions(summarize = TRUE) |> + dplyr::mutate(x = dplyr::case_when( + .pred_class_1 <= 0.05 ~ 1, + .pred_class_1 >= 0.06 & .pred_class_1 <= 0.16 ~ 2, + .pred_class_1 >= 0.17 & .pred_class_1 <= 0.27 ~ 3, + .pred_class_1 >= 0.28 & .pred_class_1 <= 0.38 ~ 4, + .pred_class_1 >= 0.39 & .pred_class_1 <= 0.49 ~ 5, + .pred_class_1 >= 0.50 & .pred_class_1 <= 0.60 ~ 6, + .pred_class_1 >= 0.61 & .pred_class_1 <= 0.71 ~ 7, + .pred_class_1 >= 0.72 & .pred_class_1 <= 0.82 ~ 8, + .pred_class_1 >= 0.83 & .pred_class_1 <= 0.93 ~ 9, + .pred_class_1 >= 0.94 & .pred_class_1 <= 1 ~ 10, + )) |> + dplyr::filter(!is.na(x)) |> + dplyr::count(.config, x) + + expect_equal( + x32$total, + x32_1$n + ) + + x33 <- cal_plot_windowed(testthat_cal_binary()) + + expect_s3_class(x33, "ggplot") + expect_true(has_facet(x33)) + + win_configs <- + bin_with_configs() |> cal_plot_windowed(truth = Class, estimate = .pred_good) + expect_true(has_facet(win_configs)) + + + # ------------------------------------------------------------------------------ + # multinomial outcome, binary windowed plots + + multi_configs_from_tune <- + testthat_cal_multiclass() |> cal_plot_windowed() + expect_s3_class(multi_configs_from_tune, "ggplot") + # should be faceted by .config and class + expect_true(inherits(multi_configs_from_tune$facet, "FacetGrid")) + + + multi_configs_from_df <- + mnl_with_configs() |> cal_plot_windowed(truth = obs, estimate = c(VF:L)) + expect_s3_class(multi_configs_from_df, "ggplot") + # should be faceted by .config and class + expect_true(inherits(multi_configs_from_df$facet, "FacetGrid")) +}) + +test_that("windowed plot function errors - grouped_df", { + expect_snapshot_error( + cal_plot_windowed(dplyr::group_by(mtcars, vs)) + ) +}) + +test_that("don't facet if there is only one .config", { + class_data <- testthat_cal_binary() + + class_data$.predictions <- lapply( + class_data$.predictions, + function(x) dplyr::filter(x, .config == "Preprocessor1_Model1") + ) + + res_windowed <- cal_plot_windowed(class_data) + + expect_null(res_windowed$data[[".config"]]) + expect_s3_class(res_windowed, "ggplot") +}) + + +test_that("Groupings that may not match work", { + model <- glm(Class ~ .pred_good, segment_logistic, family = "binomial") + + preds <- 1 - predict(model, segment_logistic, type = "response") + + combined <- dplyr::bind_rows( + dplyr::mutate(segment_logistic, source = "original"), + dplyr::mutate(segment_logistic, .pred_good = preds, source = "glm") + ) + + x51 <- combined |> + dplyr::group_by(source) |> + .cal_table_windowed( + truth = Class, + estimate = .pred_good, + step_size = 0.11, + window_size = 0.10 + ) + + x51_1 <- combined |> + dplyr::mutate(x = dplyr::case_when( + .pred_good <= 0.05 ~ 1, + .pred_good >= 0.06 & .pred_good <= 0.16 ~ 2, + .pred_good >= 0.17 & .pred_good <= 0.27 ~ 3, + .pred_good >= 0.28 & .pred_good <= 0.38 ~ 4, + .pred_good >= 0.39 & .pred_good <= 0.49 ~ 5, + .pred_good >= 0.50 & .pred_good <= 0.60 ~ 6, + .pred_good >= 0.61 & .pred_good <= 0.71 ~ 7, + .pred_good >= 0.72 & .pred_good <= 0.82 ~ 8, + .pred_good >= 0.83 & .pred_good <= 0.93 ~ 9, + .pred_good >= 0.94 & .pred_good <= 1 ~ 10, + )) |> + dplyr::filter(!is.na(x)) |> + dplyr::count(source, x) + + expect_equal( + x51$total, + x51_1$n + ) +}) + +test_that("Groups are respected", { + preds <- segment_logistic |> + dplyr::mutate(source = "logistic") |> + dplyr::bind_rows(segment_naive_bayes) |> + dplyr::mutate(source = ifelse(is.na(source), "nb", source)) |> + dplyr::group_by(source) + + x42 <- .cal_table_windowed(preds, Class, .pred_good) + + expect_equal(as.integer(table(x42$source)), c(21, 21)) + + expect_equal(unique(x42$source), c("logistic", "nb")) +}) diff --git a/tests/testthat/test-cal-plot.R b/tests/testthat/test-cal-plot.R deleted file mode 100644 index 0f0c82d..0000000 --- a/tests/testthat/test-cal-plot.R +++ /dev/null @@ -1,664 +0,0 @@ -test_that("Binary breaks functions work", { - x10 <- .cal_table_breaks(segment_logistic, Class, .pred_good, event_level = "first") - - expect_equal( - x10$predicted_midpoint, - seq(0.05, 0.95, by = 0.10) - ) - - expect_s3_class( - cal_plot_breaks(segment_logistic, Class, .pred_good), - "ggplot" - ) - - x11 <- .cal_table_breaks(testthat_cal_binary()) - - expect_equal( - x11$predicted_midpoint, - rep(seq(0.05, 0.95, by = 0.10), times = 8) - ) - - expect_s3_class( - cal_plot_breaks(testthat_cal_binary()), - "ggplot" - ) - - brks_configs <- - bin_with_configs() |> cal_plot_breaks(truth = Class, estimate = .pred_good) - expect_true(has_facet(brks_configs)) -}) - -test_that("Binary breaks functions work with group argument", { - skip_if_not_installed("ggplot2", minimum_version = "3.5.2.9000") - res <- segment_logistic |> - dplyr::mutate(id = dplyr::row_number() %% 2) |> - cal_plot_breaks(Class, .pred_good, .by = id) - - expect_s3_class(res, "ggplot") - - expect_equal( - res$data[0,], - dplyr::tibble( - id = factor(0, levels = paste(0:1)), - predicted_midpoint = double(), event_rate = double(), events = double(), - total = integer(), lower = double(), upper = double() - ) - ) - - expect_equal( - rlang::expr_text(res$mapping$x), - "~predicted_midpoint" - ) - expect_equal( - rlang::expr_text(res$mapping$colour), - "~id" - ) - expect_equal( - rlang::expr_text(res$mapping$fill), - "~id" - ) - - expect_snapshot(get_labs(res)) - - expect_equal(length(res$layers), 4) - - expect_snapshot_error( - segment_logistic |> - dplyr::mutate(group1 = 1, group2 = 2) |> - cal_plot_breaks(Class, .pred_good, .by = c(group1, group2)) - ) -}) - -test_that("Multi-class breaks functions work", { - skip_if_not_installed("modeldata") - - x10 <- .cal_table_breaks(species_probs, Species, dplyr::starts_with(".pred")) - - expect_equal( - x10$predicted_midpoint, - rep(seq(0.05, 0.95, by = 0.10), times = 3) - ) - - expect_s3_class( - cal_plot_breaks(species_probs, Species), - "ggplot" - ) - - x11 <- .cal_table_breaks(testthat_cal_multiclass()) - - expect_equal( - sort(unique(x11$predicted_midpoint)), - seq(0.05, 0.95, by = 0.10) - ) - - multi_configs <- cal_plot_breaks(testthat_cal_multiclass()) - # should be faceted by .config and class - expect_s3_class(multi_configs, "ggplot") - expect_true(inherits(multi_configs$facet, "FacetGrid")) - - expect_error( - cal_plot_breaks(species_probs, Species, event_level = "second") - ) - - # ------------------------------------------------------------------------------ - # multinomial outcome, binary logistic plots - - multi_configs_from_tune <- - testthat_cal_multiclass() |> cal_plot_breaks() - expect_s3_class(multi_configs_from_tune, "ggplot") - # should be faceted by .config and class - expect_true(inherits(multi_configs_from_tune$facet, "FacetGrid")) - - multi_configs_from_df <- - mnl_with_configs() |> cal_plot_breaks(truth = obs, estimate = c(VF:L)) - expect_s3_class(multi_configs_from_df, "ggplot") - # should be faceted by .config and class - expect_true(inherits(multi_configs_from_df$facet, "FacetGrid")) -}) - -test_that("breaks plot function errors - grouped_df", { - expect_snapshot_error( - cal_plot_breaks(dplyr::group_by(mtcars, vs)) - ) -}) - -test_that("Binary logistic functions work", { - skip_if_not_installed("modeldata") - - x20 <- .cal_table_logistic(segment_logistic, Class, .pred_good) - - model20 <- mgcv::gam(Class ~ s(.pred_good, k = 10), - data = segment_logistic, - family = binomial() - ) - - preds20 <- predict(model20, - data.frame(.pred_good = seq(0, 1, by = .01)), - type = "response" - ) - - expect_equal(sd(x20$prob), sd(preds20), tolerance = 0.000001) - expect_equal(mean(x20$prob), mean(1 - preds20), tolerance = 0.000001) - - x21 <- cal_plot_logistic(segment_logistic, Class, .pred_good) - - expect_s3_class(x21, "ggplot") - expect_false(has_facet(x21)) - - x22 <- .cal_table_logistic(testthat_cal_binary()) - - - x22_1 <- testthat_cal_binary() |> - tune::collect_predictions(summarize = TRUE) |> - dplyr::group_by(.config) |> - dplyr::group_map(~ { - model <- mgcv::gam( - class ~ s(.pred_class_1, k = 10), - data = .x, - family = binomial() - ) - preds <- predict(model, - data.frame(.pred_class_1 = seq(0, 1, by = .01)), - type = "response" - ) - 1 - preds - }) |> - purrr::reduce(c) - - expect_equal(sd(x22$prob), sd(x22_1), tolerance = 0.000001) - expect_equal(mean(x22$prob), mean(x22_1), tolerance = 0.000001) - - x23 <- cal_plot_logistic(testthat_cal_binary()) - - expect_s3_class(x23, "ggplot") - expect_true(has_facet(x23)) - - x24 <- .cal_table_logistic(segment_logistic, Class, .pred_good, smooth = FALSE) - - model24 <- stats::glm(Class ~ .pred_good, data = segment_logistic, family = binomial()) - - preds24 <- predict(model24, - data.frame(.pred_good = seq(0, 1, by = .01)), - type = "response" - ) - - expect_equal(sd(x24$prob), sd(preds24), tolerance = 0.000001) - expect_equal(mean(x24$prob), mean(1 - preds24), tolerance = 0.000001) - - x25 <- .cal_table_logistic( - segment_logistic, - Class, - .pred_poor, - event_level = "second" - ) - - expect_equal( - which(x25$prob == max(x25$prob)), - nrow(x25) - ) - - lgst_configs <- - bin_with_configs() |> cal_plot_logistic(truth = Class, estimate = .pred_good) - expect_true(has_facet(lgst_configs)) - - # ------------------------------------------------------------------------------ - # multinomial outcome, binary logistic plots - - multi_configs_from_tune <- - testthat_cal_multiclass() |> cal_plot_logistic(smooth = FALSE) - expect_s3_class(multi_configs_from_tune, "ggplot") - # should be faceted by .config and class - expect_true(inherits(multi_configs_from_tune$facet, "FacetGrid")) - - - multi_configs_from_df <- - mnl_with_configs() |> cal_plot_logistic(truth = obs, estimate = c(VF:L)) - expect_s3_class(multi_configs_from_df, "ggplot") - # should be faceted by .config and class - expect_true(inherits(multi_configs_from_df$facet, "FacetGrid")) -}) - -test_that("Binary logistic functions work with group argument", { - skip_if_not_installed("ggplot2", minimum_version = "3.5.2.9000") - res <- segment_logistic |> - dplyr::mutate(id = dplyr::row_number() %% 2) |> - cal_plot_logistic(Class, .pred_good, .by = id) - - expect_s3_class( - res, - "ggplot" - ) - expect_true(has_facet(res)) - - expect_s3_class(res, "ggplot") - - expect_equal( - res$data[0,], - dplyr::tibble( - id = factor(0, levels = paste(0:1)), - estimate = double(), prob = double(), lower = double(), upper = double() - ) - ) - - expect_equal( - rlang::expr_text(res$mapping$x), - "~estimate" - ) - expect_equal( - rlang::expr_text(res$mapping$colour), - "~id" - ) - expect_equal( - rlang::expr_text(res$mapping$fill), - "~id" - ) - - expect_snapshot(get_labs(res)) - - expect_equal(length(res$layers), 3) - - expect_snapshot_error( - segment_logistic |> - dplyr::mutate(group1 = 1, group2 = 2) |> - cal_plot_logistic(Class, .pred_good, .by = c(group1, group2)) - ) - - lgst_configs <- - bin_with_configs() |> cal_plot_logistic(truth = Class, estimate = .pred_good) - expect_true(has_facet(lgst_configs)) -}) - -test_that("logistic plot function errors - grouped_df", { - expect_snapshot_error( - cal_plot_logistic(dplyr::group_by(mtcars, vs)) - ) -}) - -test_that("Binary windowed functions work", { - skip_if_not_installed("modeldata") - - x30 <- .cal_table_windowed( - segment_logistic, - truth = Class, - estimate = .pred_good, - step_size = 0.11, - window_size = 0.10 - ) - - x30_1 <- segment_logistic |> - dplyr::mutate(x = dplyr::case_when( - .pred_good <= 0.05 ~ 1, - .pred_good >= 0.06 & .pred_good <= 0.16 ~ 2, - .pred_good >= 0.17 & .pred_good <= 0.27 ~ 3, - .pred_good >= 0.28 & .pred_good <= 0.38 ~ 4, - .pred_good >= 0.39 & .pred_good <= 0.49 ~ 5, - .pred_good >= 0.50 & .pred_good <= 0.60 ~ 6, - .pred_good >= 0.61 & .pred_good <= 0.71 ~ 7, - .pred_good >= 0.72 & .pred_good <= 0.82 ~ 8, - .pred_good >= 0.83 & .pred_good <= 0.93 ~ 9, - .pred_good >= 0.94 & .pred_good <= 1 ~ 10, - )) |> - dplyr::filter(!is.na(x)) |> - dplyr::count(x) - - expect_equal( - x30$total, - x30_1$n - ) - - x31 <- cal_plot_windowed(segment_logistic, Class, .pred_good) - - expect_s3_class(x31, "ggplot") - expect_false(has_facet(x31)) - - x32 <- .cal_table_windowed( - testthat_cal_binary(), - step_size = 0.11, - window_size = 0.10 - ) - - x32_1 <- testthat_cal_binary() |> - tune::collect_predictions(summarize = TRUE) |> - dplyr::mutate(x = dplyr::case_when( - .pred_class_1 <= 0.05 ~ 1, - .pred_class_1 >= 0.06 & .pred_class_1 <= 0.16 ~ 2, - .pred_class_1 >= 0.17 & .pred_class_1 <= 0.27 ~ 3, - .pred_class_1 >= 0.28 & .pred_class_1 <= 0.38 ~ 4, - .pred_class_1 >= 0.39 & .pred_class_1 <= 0.49 ~ 5, - .pred_class_1 >= 0.50 & .pred_class_1 <= 0.60 ~ 6, - .pred_class_1 >= 0.61 & .pred_class_1 <= 0.71 ~ 7, - .pred_class_1 >= 0.72 & .pred_class_1 <= 0.82 ~ 8, - .pred_class_1 >= 0.83 & .pred_class_1 <= 0.93 ~ 9, - .pred_class_1 >= 0.94 & .pred_class_1 <= 1 ~ 10, - )) |> - dplyr::filter(!is.na(x)) |> - dplyr::count(.config, x) - - expect_equal( - x32$total, - x32_1$n - ) - - x33 <- cal_plot_windowed(testthat_cal_binary()) - - expect_s3_class(x33, "ggplot") - expect_true(has_facet(x33)) - - win_configs <- - bin_with_configs() |> cal_plot_windowed(truth = Class, estimate = .pred_good) - expect_true(has_facet(win_configs)) - - - # ------------------------------------------------------------------------------ - # multinomial outcome, binary windowed plots - - multi_configs_from_tune <- - testthat_cal_multiclass() |> cal_plot_windowed() - expect_s3_class(multi_configs_from_tune, "ggplot") - # should be faceted by .config and class - expect_true(inherits(multi_configs_from_tune$facet, "FacetGrid")) - - - multi_configs_from_df <- - mnl_with_configs() |> cal_plot_windowed(truth = obs, estimate = c(VF:L)) - expect_s3_class(multi_configs_from_df, "ggplot") - # should be faceted by .config and class - expect_true(inherits(multi_configs_from_df$facet, "FacetGrid")) -}) - -test_that("windowed plot function errors - grouped_df", { - expect_snapshot_error( - cal_plot_windowed(dplyr::group_by(mtcars, vs)) - ) -}) - -test_that("Event level handling works", { - x7 <- .cal_table_breaks(segment_logistic, Class, .pred_good, event_level = "second") - expect_equal( - which(x7$predicted_midpoint == min(x7$predicted_midpoint)), - which(x7$event_rate == max(x7$event_rate)) - ) - - expect_snapshot_error( - .cal_table_breaks(segment_logistic, Class, .pred_good, event_level = "invalid") - ) -}) - - -test_that("Groups are respected", { - preds <- segment_logistic |> - dplyr::mutate(source = "logistic") |> - dplyr::bind_rows(segment_naive_bayes) |> - dplyr::mutate(source = ifelse(is.na(source), "nb", source)) |> - dplyr::group_by(source) - - x40 <- .cal_table_breaks(preds, Class, .pred_good) - - expect_equal(as.integer(table(x40$source)), c(10, 10)) - - expect_equal(unique(x40$source), c("logistic", "nb")) - - x41 <- .cal_table_logistic(preds, Class, .pred_good) - - expect_equal(as.integer(table(x41$source)), c(101, 101)) - - expect_equal(unique(x41$source), c("logistic", "nb")) - - x42 <- .cal_table_windowed(preds, Class, .pred_good) - - expect_equal(as.integer(table(x42$source)), c(21, 21)) - - expect_equal(unique(x42$source), c("logistic", "nb")) -}) - -test_that("Groupings that may not match work", { - model <- glm(Class ~ .pred_good, segment_logistic, family = "binomial") - - preds <- 1 - predict(model, segment_logistic, type = "response") - - combined <- dplyr::bind_rows( - dplyr::mutate(segment_logistic, source = "original"), - dplyr::mutate(segment_logistic, .pred_good = preds, source = "glm") - ) - - x50 <- combined |> - dplyr::group_by(source) |> - .cal_table_breaks(Class, .pred_good) - - expect_equal( - unique(x50$predicted_midpoint), - seq(0.05, 0.95, by = 0.10) - ) - - x51 <- combined |> - dplyr::group_by(source) |> - .cal_table_windowed( - truth = Class, - estimate = .pred_good, - step_size = 0.11, - window_size = 0.10 - ) - - x51_1 <- combined |> - dplyr::mutate(x = dplyr::case_when( - .pred_good <= 0.05 ~ 1, - .pred_good >= 0.06 & .pred_good <= 0.16 ~ 2, - .pred_good >= 0.17 & .pred_good <= 0.27 ~ 3, - .pred_good >= 0.28 & .pred_good <= 0.38 ~ 4, - .pred_good >= 0.39 & .pred_good <= 0.49 ~ 5, - .pred_good >= 0.50 & .pred_good <= 0.60 ~ 6, - .pred_good >= 0.61 & .pred_good <= 0.71 ~ 7, - .pred_good >= 0.72 & .pred_good <= 0.82 ~ 8, - .pred_good >= 0.83 & .pred_good <= 0.93 ~ 9, - .pred_good >= 0.94 & .pred_good <= 1 ~ 10, - )) |> - dplyr::filter(!is.na(x)) |> - dplyr::count(source, x) - - expect_equal( - x51$total, - x51_1$n - ) -}) - -test_that("Numeric groups are supported", { - grp_df <- segment_logistic - grp_df$num_group <- rep(c(1, 2), times = 505) - - p <- grp_df |> - cal_plot_breaks(Class, .pred_good, .by = num_group) - - expect_s3_class(p, "ggplot") -}) - -test_that("Some general exceptions", { - expect_error( - .cal_table_breaks(tune::ames_grid_search), - "The `tune_results` object does not contain columns with predictions" - ) - expect_warning( - cal_plot_breaks(segment_logistic, Class), - ) -}) - -# ------------------------------------------------------------------------------ - -test_that("regression functions work", { - skip_if_not_installed("ggplot2", minimum_version = "3.5.2.9000") - skip_if(R.version[["arch"]] != "aarch64") # see note below - - obj <- testthat_cal_reg() - - res <- cal_plot_regression(boosting_predictions_oob, outcome, .pred) - expect_s3_class(res, "ggplot") - - expect_equal( - res$data[0,], - dplyr::tibble(outcome = numeric(0), .pred = numeric(0), id = character(0)) - ) - - expect_equal( - rlang::expr_text(res$mapping$x), - "~outcome" - ) - expect_equal( - rlang::expr_text(res$mapping$y), - "~.pred" - ) - expect_null(res$mapping$colour) - expect_null(res$mapping$fill) - - expect_snapshot(get_labs(res)) - - expect_equal(length(res$layers), 3) - - res <- cal_plot_regression(boosting_predictions_oob, outcome, .pred, .by = id) - expect_s3_class(res, "ggplot") - - expect_equal( - res$data[0,], - dplyr::tibble(outcome = numeric(0), .pred = numeric(0), id = character(0)) - ) - - expect_equal( - rlang::expr_text(res$mapping$x), - "~outcome" - ) - expect_equal( - rlang::expr_text(res$mapping$y), - "~.pred" - ) - expect_null(res$mapping$colour) - expect_null(res$mapping$fill) - - expect_snapshot(get_labs(res)) - - expect_equal(length(res$layers), 3) - - res <- cal_plot_regression(obj) - expect_s3_class(res, "ggplot") - - skip_if_not_installed("tune", "1.2.0") - expect_equal( - res$data[0,], - dplyr::tibble(.pred = numeric(0), .row = numeric(0), - predictor_01 = integer(0), outcome = numeric(0), - .config = character()) - ) - - expect_equal( - rlang::expr_text(res$mapping$x), - "~outcome" - ) - expect_equal( - rlang::expr_text(res$mapping$y), - "~.pred" - ) - expect_null(res$mapping$colour) - expect_null(res$mapping$fill) - - expect_snapshot(get_labs(res)) - - expect_equal(length(res$layers), 3) - - res <- print(cal_plot_regression(obj), alpha = 1 / 5, smooth = FALSE) - expect_s3_class(res, "ggplot") - - skip_if_not_installed("tune", "1.2.0") - expect_equal( - res$data[0,], - dplyr::tibble(.pred = numeric(0), .row = numeric(0), - predictor_01 = integer(0), outcome = numeric(0), - .config = character()) - ) - - expect_equal( - rlang::expr_text(res$mapping$x), - "~outcome" - ) - expect_equal( - rlang::expr_text(res$mapping$y), - "~.pred" - ) - expect_null(res$mapping$colour) - expect_null(res$mapping$fill) - - expect_snapshot(get_labs(res)) - - expect_equal(length(res$layers), 3) - - res <- cal_plot_regression(boosting_predictions_oob, outcome, .pred, smooth = FALSE) - expect_s3_class(res, "ggplot") - - expect_equal( - res$data[0,], - dplyr::tibble(outcome = numeric(0), .pred = numeric(0), - id = character()) - ) - - expect_equal( - rlang::expr_text(res$mapping$x), - "~outcome" - ) - expect_equal( - rlang::expr_text(res$mapping$y), - "~.pred" - ) - expect_null(res$mapping$colour) - expect_null(res$mapping$fill) - - expect_snapshot(get_labs(res)) - - expect_equal(length(res$layers), 3) -}) - -test_that("regression plot function errors - grouped_df", { - expect_snapshot_error( - cal_plot_regression(dplyr::group_by(mtcars, vs)) - ) -}) - -# ------------------------------------------------------------------------------ - -test_that("don't facet if there is only one .config", { - class_data <- testthat_cal_binary() - - class_data$.predictions <- lapply( - class_data$.predictions, - function(x) dplyr::filter(x, .config == "Preprocessor1_Model1") - ) - - res_breaks <- cal_plot_breaks(class_data) - - expect_null(res_breaks$data[[".config"]]) - expect_s3_class(res_breaks, "ggplot") - - res_logistic <- cal_plot_logistic(class_data) - - expect_null(res_logistic$data[[".config"]]) - expect_s3_class(res_logistic, "ggplot") - - res_windowed <- cal_plot_windowed(class_data) - - expect_null(res_windowed$data[[".config"]]) - expect_s3_class(res_windowed, "ggplot") - - reg_data <- testthat_cal_reg() - - reg_data$.predictions <- lapply( - reg_data$.predictions, - function(x) dplyr::filter(x, .config == "Preprocessor01_Model1") - ) - - res_regression <- cal_plot_regression(reg_data) - - expect_null(res_regression$data[[".config"]]) - expect_s3_class(res_regression, "ggplot") -}) - -test_that("custom names for cal_plot_breaks()", { - data(segment_logistic) - segment_logistic_1 <- dplyr::rename(segment_logistic, good_prob = .pred_good) - p <- cal_plot_breaks(segment_logistic_1, Class, good_prob) - expect_s3_class(p, "ggplot") -}) From fa4a2b0dc5ef031ef98926089aae36a763f2dd66 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 8 Oct 2025 11:45:33 -0700 Subject: [PATCH 6/6] move tests around in conformal_infer files --- DESCRIPTION | 2 +- ...nformal_infer.R => conformal_infer_full.R} | 0 man/control_conformal_full.Rd | 2 +- man/int_conformal_full.Rd | 2 +- man/predict.int_conformal_full.Rd | 10 +- man/required_pkgs.int_conformal_cv.Rd | 10 +- tests/testthat/_snaps/conformal_infer_cv.md | 73 +++++++ ...l-intervals.md => conformal_infer_full.md} | 73 ------- ...uantile.md => conformal_infer_quantile.md} | 0 ...vals-split.md => conformal_infer_split.md} | 0 ...{make-class-pred.md => make_class_pred.md} | 0 .../{threshold-perf.md => threshold_perf.md} | 0 ...-intervals.R => test-conformal_infer_cv.R} | 106 ---------- tests/testthat/test-conformal_infer_full.R | 192 ++++++++++++++++++ ...tile.R => test-conformal_infer_quantile.R} | 0 ...s-split.R => test-conformal_infer_split.R} | 0 ...ke-class-pred.R => test-make_class_pred.R} | 0 ...threshold-perf.R => test-threshold_perf.R} | 0 18 files changed, 278 insertions(+), 192 deletions(-) rename R/{conformal_infer.R => conformal_infer_full.R} (100%) create mode 100644 tests/testthat/_snaps/conformal_infer_cv.md rename tests/testthat/_snaps/{conformal-intervals.md => conformal_infer_full.md} (61%) rename tests/testthat/_snaps/{conformal-intervals-quantile.md => conformal_infer_quantile.md} (100%) rename tests/testthat/_snaps/{conformal-intervals-split.md => conformal_infer_split.md} (100%) rename tests/testthat/_snaps/{make-class-pred.md => make_class_pred.md} (100%) rename tests/testthat/_snaps/{threshold-perf.md => threshold_perf.md} (100%) rename tests/testthat/{test-conformal-intervals.R => test-conformal_infer_cv.R} (65%) create mode 100644 tests/testthat/test-conformal_infer_full.R rename tests/testthat/{test-conformal-intervals-quantile.R => test-conformal_infer_quantile.R} (100%) rename tests/testthat/{test-conformal-intervals-split.R => test-conformal_infer_split.R} (100%) rename tests/testthat/{test-make-class-pred.R => test-make_class_pred.R} (100%) rename tests/testthat/{test-threshold-perf.R => test-threshold_perf.R} (100%) diff --git a/DESCRIPTION b/DESCRIPTION index 0709eaa..d4f286e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -84,8 +84,8 @@ Collate: 'cal-utils.R' 'cal-validate.R' 'class-pred.R' - 'conformal_infer.R' 'conformal_infer_cv.R' + 'conformal_infer_full.R' 'conformal_infer_quantile.R' 'conformal_infer_split.R' 'data.R' diff --git a/R/conformal_infer.R b/R/conformal_infer_full.R similarity index 100% rename from R/conformal_infer.R rename to R/conformal_infer_full.R diff --git a/man/control_conformal_full.Rd b/man/control_conformal_full.Rd index a6e1e1d..a730fad 100644 --- a/man/control_conformal_full.Rd +++ b/man/control_conformal_full.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/conformal_infer.R +% Please edit documentation in R/conformal_infer_full.R \name{control_conformal_full} \alias{control_conformal_full} \title{Controlling the numeric details for conformal inference} diff --git a/man/int_conformal_full.Rd b/man/int_conformal_full.Rd index e950fe7..c0cc784 100644 --- a/man/int_conformal_full.Rd +++ b/man/int_conformal_full.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/conformal_infer.R +% Please edit documentation in R/conformal_infer_full.R \name{int_conformal_full} \alias{int_conformal_full} \alias{int_conformal_full.default} diff --git a/man/predict.int_conformal_full.Rd b/man/predict.int_conformal_full.Rd index 6a3a746..2d1c068 100644 --- a/man/predict.int_conformal_full.Rd +++ b/man/predict.int_conformal_full.Rd @@ -1,17 +1,17 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/conformal_infer.R, R/conformal_infer_cv.R, +% Please edit documentation in R/conformal_infer_cv.R, R/conformal_infer_full.R, % R/conformal_infer_quantile.R, R/conformal_infer_split.R -\name{predict.int_conformal_full} -\alias{predict.int_conformal_full} +\name{predict.int_conformal_cv} \alias{predict.int_conformal_cv} +\alias{predict.int_conformal_full} \alias{predict.int_conformal_quantile} \alias{predict.int_conformal_split} \title{Prediction intervals from conformal methods} \usage{ -\method{predict}{int_conformal_full}(object, new_data, level = 0.95, ...) - \method{predict}{int_conformal_cv}(object, new_data, level = 0.95, ...) +\method{predict}{int_conformal_full}(object, new_data, level = 0.95, ...) + \method{predict}{int_conformal_quantile}(object, new_data, ...) \method{predict}{int_conformal_split}(object, new_data, level = 0.95, ...) diff --git a/man/required_pkgs.int_conformal_cv.Rd b/man/required_pkgs.int_conformal_cv.Rd index e5d2d84..050be6c 100644 --- a/man/required_pkgs.int_conformal_cv.Rd +++ b/man/required_pkgs.int_conformal_cv.Rd @@ -1,18 +1,18 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/conformal_infer.R, R/conformal_infer_cv.R, +% Please edit documentation in R/conformal_infer_cv.R, R/conformal_infer_full.R, % R/conformal_infer_quantile.R, R/conformal_infer_split.R -\name{required_pkgs.int_conformal_full} -\alias{required_pkgs.int_conformal_full} +\name{required_pkgs.int_conformal_cv} \alias{required_pkgs.int_conformal_cv} +\alias{required_pkgs.int_conformal_full} \alias{required_pkgs.int_conformal_quantile} \alias{required_pkgs.int_conformal_split} \title{S3 methods to track which additional packages are needed for prediction intervals via conformal inference} \usage{ -\method{required_pkgs}{int_conformal_full}(x, infra = TRUE, ...) - \method{required_pkgs}{int_conformal_cv}(x, infra = TRUE, ...) +\method{required_pkgs}{int_conformal_full}(x, infra = TRUE, ...) + \method{required_pkgs}{int_conformal_quantile}(x, infra = TRUE, ...) \method{required_pkgs}{int_conformal_split}(x, infra = TRUE, ...) diff --git a/tests/testthat/_snaps/conformal_infer_cv.md b/tests/testthat/_snaps/conformal_infer_cv.md new file mode 100644 index 0000000..8754694 --- /dev/null +++ b/tests/testthat/_snaps/conformal_infer_cv.md @@ -0,0 +1,73 @@ +# bad inputs to conformal intervals + + 2 repeats were used. + i This method was developed for basic V-fold cross-validation. + i Interval coverage is unknown for multiple repeats. + +--- + + The data were resampled using Bootstrap sampling. + i This method was developed for V-fold cross-validation. + i Interval coverage is unknown for your resampling method. + +--- + + Code + basic_cv_obj + Output + Conformal inference via CV+ + preprocessor: formula + model: linear_reg (engine = lm) + number of models: 2 + training set size: 500 + + Use `predict(object, new_data, level)` to compute prediction intervals + +--- + + Code + int_conformal_cv(workflow()) + Condition + Error in `int_conformal_cv()`: + ! No known `int_conformal_cv()` methods for this type of object. + +--- + + Code + int_conformal_cv(dplyr::select(good_res, -.predictions)) + Condition + Error in `int_conformal_cv()`: + ! The output must contain a column called `.predictions` that contains the holdout predictions. See the documentation on the `save_pred` argument of the control function (e.g., `control_grid()` or `control_resamples()`, etc.). + +--- + + Code + int_conformal_cv(dplyr::select(good_res, -.extracts)) + Condition + Error in `int_conformal_cv()`: + ! The output must contain a column called `.extracts` that contains the fitted workflow objects. See the documentation on the `extract` argument of the control function (e.g., `control_grid()` or `control_resamples()`, etc.). + +--- + + Code + predict(basic_cv_obj, sim_new[, 3:5]) + Condition + Error in `map()`: + i In index: 1. + Caused by error in `hardhat::forge()`: + ! The required columns "predictor_01", "predictor_05", "predictor_06", "predictor_07", "predictor_08", "predictor_09", "predictor_10", "predictor_11", "predictor_12", "predictor_13", "predictor_14", "predictor_15", "predictor_16", "predictor_17", "predictor_18", "predictor_19", and "predictor_20" are missing. + +# conformal intervals + + Code + int_conformal_cv(grid_res, two_models) + Condition + Error in `int_conformal_cv()`: + ! The `parameters` argument selected 2 submodels. Only 1 should be selected. + +# group resampling to conformal CV intervals + + The data were resampled using Group 2-fold cross-validation. + i This method was developed for V-fold cross-validation. + i Interval coverage is unknown for your resampling method. + diff --git a/tests/testthat/_snaps/conformal-intervals.md b/tests/testthat/_snaps/conformal_infer_full.md similarity index 61% rename from tests/testthat/_snaps/conformal-intervals.md rename to tests/testthat/_snaps/conformal_infer_full.md index 16842a3..8ebad07 100644 --- a/tests/testthat/_snaps/conformal-intervals.md +++ b/tests/testthat/_snaps/conformal_infer_full.md @@ -68,65 +68,6 @@ Error in `hardhat::forge()`: ! The required columns "predictor_01", "predictor_02", "predictor_03", "predictor_04", "predictor_05", "predictor_06", "predictor_07", "predictor_08", "predictor_09", "predictor_10", "predictor_11", "predictor_12", "predictor_13", "predictor_14", "predictor_15", "predictor_16", "predictor_17", "predictor_18", "predictor_19", and "predictor_20" are missing. ---- - - 2 repeats were used. - i This method was developed for basic V-fold cross-validation. - i Interval coverage is unknown for multiple repeats. - ---- - - The data were resampled using Bootstrap sampling. - i This method was developed for V-fold cross-validation. - i Interval coverage is unknown for your resampling method. - ---- - - Code - basic_cv_obj - Output - Conformal inference via CV+ - preprocessor: formula - model: linear_reg (engine = lm) - number of models: 2 - training set size: 500 - - Use `predict(object, new_data, level)` to compute prediction intervals - ---- - - Code - int_conformal_cv(workflow()) - Condition - Error in `int_conformal_cv()`: - ! No known `int_conformal_cv()` methods for this type of object. - ---- - - Code - int_conformal_cv(dplyr::select(good_res, -.predictions)) - Condition - Error in `int_conformal_cv()`: - ! The output must contain a column called `.predictions` that contains the holdout predictions. See the documentation on the `save_pred` argument of the control function (e.g., `control_grid()` or `control_resamples()`, etc.). - ---- - - Code - int_conformal_cv(dplyr::select(good_res, -.extracts)) - Condition - Error in `int_conformal_cv()`: - ! The output must contain a column called `.extracts` that contains the fitted workflow objects. See the documentation on the `extract` argument of the control function (e.g., `control_grid()` or `control_resamples()`, etc.). - ---- - - Code - predict(basic_cv_obj, sim_new[, 3:5]) - Condition - Error in `map()`: - i In index: 1. - Caused by error in `hardhat::forge()`: - ! The required columns "predictor_01", "predictor_05", "predictor_06", "predictor_07", "predictor_08", "predictor_09", "predictor_10", "predictor_11", "predictor_12", "predictor_13", "predictor_14", "predictor_15", "predictor_16", "predictor_17", "predictor_18", "predictor_19", and "predictor_20" are missing. - --- Code @@ -163,14 +104,6 @@ Output ---- - - Code - int_conformal_cv(grid_res, two_models) - Condition - Error in `int_conformal_cv()`: - ! The `parameters` argument selected 2 submodels. Only 1 should be selected. - # conformal control Code @@ -197,9 +130,3 @@ Error in `control_conformal_full()`: ! `method` must be one of "iterative" or "grid", not "rock-paper-scissors". -# group resampling to conformal CV intervals - - The data were resampled using Group 2-fold cross-validation. - i This method was developed for V-fold cross-validation. - i Interval coverage is unknown for your resampling method. - diff --git a/tests/testthat/_snaps/conformal-intervals-quantile.md b/tests/testthat/_snaps/conformal_infer_quantile.md similarity index 100% rename from tests/testthat/_snaps/conformal-intervals-quantile.md rename to tests/testthat/_snaps/conformal_infer_quantile.md diff --git a/tests/testthat/_snaps/conformal-intervals-split.md b/tests/testthat/_snaps/conformal_infer_split.md similarity index 100% rename from tests/testthat/_snaps/conformal-intervals-split.md rename to tests/testthat/_snaps/conformal_infer_split.md diff --git a/tests/testthat/_snaps/make-class-pred.md b/tests/testthat/_snaps/make_class_pred.md similarity index 100% rename from tests/testthat/_snaps/make-class-pred.md rename to tests/testthat/_snaps/make_class_pred.md diff --git a/tests/testthat/_snaps/threshold-perf.md b/tests/testthat/_snaps/threshold_perf.md similarity index 100% rename from tests/testthat/_snaps/threshold-perf.md rename to tests/testthat/_snaps/threshold_perf.md diff --git a/tests/testthat/test-conformal-intervals.R b/tests/testthat/test-conformal_infer_cv.R similarity index 65% rename from tests/testthat/test-conformal-intervals.R rename to tests/testthat/test-conformal_infer_cv.R index 5233d2e..2544ac2 100644 --- a/tests/testthat/test-conformal-intervals.R +++ b/tests/testthat/test-conformal_infer_cv.R @@ -44,64 +44,6 @@ test_that("bad inputs to conformal intervals", { # ---------------------------------------------------------------------------- - set.seed(121212) - sim_cls_data <- sim_classification(100) - wflow_cls <- - workflow() |> - add_model(parsnip::logistic_reg()) |> - add_formula(class ~ .) |> - fit(sim_cls_data) - - sim_cls_new <- sim_classification(2) - - # ---------------------------------------------------------------------------- - - # When the gam for variance fails: - expect_snapshot( - error = TRUE, - int_conformal_full(wflow, sim_new) - ) - - expect_snapshot( - error = TRUE, - int_conformal_full( - wflow, - sim_data, - control = control_conformal_full(required_pkgs = "boop") - ) - ) - - basic_obj <- int_conformal_full(wflow, train_data = sim_data) - expect_snapshot(basic_obj) - expect_s3_class(basic_obj, "int_conformal_full") - - expect_snapshot( - error = TRUE, - int_conformal_full(workflow(), sim_new) - ) - - expect_snapshot( - error = TRUE, - int_conformal_full(wflow |> extract_fit_parsnip(), sim_new) - ) - - expect_snapshot( - error = TRUE, - int_conformal_full(wflow_cls, sim_cls_new) - ) - - expect_snapshot( - error = TRUE, - predict(basic_obj, sim_new[, 3:5]) - ) - - expect_snapshot( - error = TRUE, - int_conformal_full(wflow, train_data = sim_cls_data) - ) - - # ---------------------------------------------------------------------------- - basic_cv_obj <- int_conformal_cv(good_res) expect_snapshot_warning( @@ -132,10 +74,6 @@ test_that("bad inputs to conformal intervals", { error = TRUE, predict(basic_cv_obj, sim_new[, 3:5]) ) - - expect_snapshot( - probably:::get_root(try(stop("I made you stop"), silent = TRUE), control_conformal_full()) - ) }) test_that("conformal intervals", { @@ -168,16 +106,6 @@ test_that("conformal intervals", { set.seed(182) sim_new <- sim_regression(2) - ctrl_grid <- control_conformal_full(method = "grid", seed = 1) - basic_obj <- int_conformal_full(wflow, train_data = sim_data, control = ctrl_grid) - - ctrl_hard <- control_conformal_full( - progress = TRUE, seed = 1, - max_iter = 2, tolerance = 0.000001 - ) - smol_obj <- int_conformal_full(wflow_small, train_data = sim_small, control = ctrl_hard) - - ctrl <- control_resamples(save_pred = TRUE, extract = I) set.seed(382) cv <- vfold_cv(sim_data, v = 2) @@ -190,22 +118,6 @@ test_that("conformal intervals", { # ---------------------------------------------------------------------------- - expect_snapshot( - res_small <- predict(smol_obj, sim_new) - ) - expect_equal(names(res_small), c(".pred_lower", ".pred_upper")) - expect_equal(nrow(res_small), 2) - expect_true(mean(complete.cases(res_small)) < 1) - - # ---------------------------------------------------------------------------- - - res <- predict(basic_obj, sim_new[1, ]) - expect_equal(names(res), c(".pred_lower", ".pred_upper")) - expect_equal(nrow(res), 1) - expect_true(mean(complete.cases(res)) == 1) - - # ---------------------------------------------------------------------------- - cv_int <- int_conformal_cv(cv_res) cv_bounds <- predict(cv_int, sim_small) cv_bounds_90 <- predict(cv_int, sim_small, level = .9) @@ -230,15 +142,6 @@ test_that("conformal intervals", { all(grid_bounds$.pred_lower < grid_bounds_90$.pred_lower) ) - expect_identical( - required_pkgs(smol_obj), - c(required_pkgs(smol_obj$wflow), "probably") - ) - expect_identical( - required_pkgs(smol_obj, infra = FALSE), - required_pkgs(smol_obj$wflow, infra = FALSE) - ) - expect_identical( required_pkgs(grid_int), c(unique(unlist(map(grid_int$models, required_pkgs))), "probably") @@ -249,13 +152,6 @@ test_that("conformal intervals", { ) }) -test_that("conformal control", { - set.seed(1) - expect_snapshot(dput(control_conformal_full())) - expect_snapshot(dput(control_conformal_full(max_iter = 2))) - expect_snapshot(error = TRUE, control_conformal_full(method = "rock-paper-scissors")) -}) - test_that("group resampling to conformal CV intervals", { skip_if_not_installed("modeldata") @@ -287,6 +183,4 @@ test_that("group resampling to conformal CV intervals", { fit_resamples(group_folds, control = ctrl) expect_snapshot_warning(int_conformal_cv(group_nnet_rs)) - }) - diff --git a/tests/testthat/test-conformal_infer_full.R b/tests/testthat/test-conformal_infer_full.R new file mode 100644 index 0000000..236c9c8 --- /dev/null +++ b/tests/testthat/test-conformal_infer_full.R @@ -0,0 +1,192 @@ +test_that("bad inputs to conformal intervals", { + skip_if_not_installed("modeldata") + skip_if_not_installed("nnet") + + # ---------------------------------------------------------------------------- + + suppressPackageStartupMessages(library(workflows)) + suppressPackageStartupMessages(library(modeldata)) + suppressPackageStartupMessages(library(purrr)) + suppressPackageStartupMessages(library(rsample)) + suppressPackageStartupMessages(library(tune)) + suppressPackageStartupMessages(library(dplyr)) + + # ---------------------------------------------------------------------------- + + set.seed(111) + sim_data <- sim_regression(500) + wflow <- + workflow() |> + add_model(parsnip::linear_reg()) |> + add_formula(outcome ~ .) |> + fit(sim_data) + + set.seed(182) + sim_new <- sim_regression(2) + + + ctrl <- control_resamples(save_pred = TRUE, extract = I) + + set.seed(382) + cv <- vfold_cv(sim_data, v = 2) + good_res <- + parsnip::linear_reg() |> fit_resamples(outcome ~ ., cv, control = ctrl) + + set.seed(382) + cv <- vfold_cv(sim_data, v = 2, repeats = 2) + rep_res <- + parsnip::linear_reg() |> fit_resamples(outcome ~ ., cv, control = ctrl) + + set.seed(382) + bt <- bootstraps(sim_data, times = 2) + bt_res <- + parsnip::linear_reg() |> fit_resamples(outcome ~ ., bt, control = ctrl) + + # ---------------------------------------------------------------------------- + + set.seed(121212) + sim_cls_data <- sim_classification(100) + wflow_cls <- + workflow() |> + add_model(parsnip::logistic_reg()) |> + add_formula(class ~ .) |> + fit(sim_cls_data) + + sim_cls_new <- sim_classification(2) + + # ---------------------------------------------------------------------------- + + # When the gam for variance fails: + expect_snapshot( + error = TRUE, + int_conformal_full(wflow, sim_new) + ) + + expect_snapshot( + error = TRUE, + int_conformal_full( + wflow, + sim_data, + control = control_conformal_full(required_pkgs = "boop") + ) + ) + + basic_obj <- int_conformal_full(wflow, train_data = sim_data) + expect_snapshot(basic_obj) + expect_s3_class(basic_obj, "int_conformal_full") + + expect_snapshot( + error = TRUE, + int_conformal_full(workflow(), sim_new) + ) + + expect_snapshot( + error = TRUE, + int_conformal_full(wflow |> extract_fit_parsnip(), sim_new) + ) + + expect_snapshot( + error = TRUE, + int_conformal_full(wflow_cls, sim_cls_new) + ) + + expect_snapshot( + error = TRUE, + predict(basic_obj, sim_new[, 3:5]) + ) + + expect_snapshot( + error = TRUE, + int_conformal_full(wflow, train_data = sim_cls_data) + ) + + expect_snapshot( + probably:::get_root(try(stop("I made you stop"), silent = TRUE), control_conformal_full()) + ) +}) + +test_that("conformal intervals", { + skip_if_not_installed("modeldata") + skip_on_cran() + + # ---------------------------------------------------------------------------- + + library(workflows) + library(modeldata) + + # ---------------------------------------------------------------------------- + + set.seed(111) + sim_data <- sim_regression(500) + sim_small <- sim_data[1:25, ] + + wflow <- + workflow() |> + add_model(parsnip::linear_reg()) |> + add_formula(outcome ~ .) |> + fit(sim_data) + + wflow_small <- + workflow() |> + add_model(parsnip::linear_reg()) |> + add_formula(outcome ~ .) |> + fit(sim_small) + + set.seed(182) + sim_new <- sim_regression(2) + + ctrl_grid <- control_conformal_full(method = "grid", seed = 1) + basic_obj <- int_conformal_full(wflow, train_data = sim_data, control = ctrl_grid) + + ctrl_hard <- control_conformal_full( + progress = TRUE, seed = 1, + max_iter = 2, tolerance = 0.000001 + ) + smol_obj <- int_conformal_full(wflow_small, train_data = sim_small, control = ctrl_hard) + + + ctrl <- control_resamples(save_pred = TRUE, extract = I) + set.seed(382) + cv <- vfold_cv(sim_data, v = 2) + cv_res <- + parsnip::linear_reg() |> fit_resamples(outcome ~ ., cv, control = ctrl) + grid_res <- + parsnip::mlp(penalty = tune()) |> + parsnip::set_mode("regression") |> + tune_grid(outcome ~ ., cv, grid = 2, control = ctrl) + + # ---------------------------------------------------------------------------- + + expect_snapshot( + res_small <- predict(smol_obj, sim_new) + ) + expect_equal(names(res_small), c(".pred_lower", ".pred_upper")) + expect_equal(nrow(res_small), 2) + expect_true(mean(complete.cases(res_small)) < 1) + + # ---------------------------------------------------------------------------- + + res <- predict(basic_obj, sim_new[1, ]) + expect_equal(names(res), c(".pred_lower", ".pred_upper")) + expect_equal(nrow(res), 1) + expect_true(mean(complete.cases(res)) == 1) + + # ---------------------------------------------------------------------------- + + expect_identical( + required_pkgs(smol_obj), + c(required_pkgs(smol_obj$wflow), "probably") + ) + expect_identical( + required_pkgs(smol_obj, infra = FALSE), + required_pkgs(smol_obj$wflow, infra = FALSE) + ) +}) + +test_that("conformal control", { + set.seed(1) + expect_snapshot(dput(control_conformal_full())) + expect_snapshot(dput(control_conformal_full(max_iter = 2))) + expect_snapshot(error = TRUE, control_conformal_full(method = "rock-paper-scissors")) +}) + diff --git a/tests/testthat/test-conformal-intervals-quantile.R b/tests/testthat/test-conformal_infer_quantile.R similarity index 100% rename from tests/testthat/test-conformal-intervals-quantile.R rename to tests/testthat/test-conformal_infer_quantile.R diff --git a/tests/testthat/test-conformal-intervals-split.R b/tests/testthat/test-conformal_infer_split.R similarity index 100% rename from tests/testthat/test-conformal-intervals-split.R rename to tests/testthat/test-conformal_infer_split.R diff --git a/tests/testthat/test-make-class-pred.R b/tests/testthat/test-make_class_pred.R similarity index 100% rename from tests/testthat/test-make-class-pred.R rename to tests/testthat/test-make_class_pred.R diff --git a/tests/testthat/test-threshold-perf.R b/tests/testthat/test-threshold_perf.R similarity index 100% rename from tests/testthat/test-threshold-perf.R rename to tests/testthat/test-threshold_perf.R