diff --git a/.Rbuildignore b/.Rbuildignore index b8fca22..e857831 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -12,3 +12,5 @@ ^pkgdown$ ^CRAN-SUBMISSION$ ^man-roxygen$ +^[.]?air[.]toml$ +^\.vscode$ diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..344f76e --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,5 @@ +{ + "recommendations": [ + "Posit.air-vscode" + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..a9f69fe --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,10 @@ +{ + "[r]": { + "editor.formatOnSave": true, + "editor.defaultFormatter": "Posit.air-vscode" + }, + "[quarto]": { + "editor.formatOnSave": true, + "editor.defaultFormatter": "quarto.quarto" + } +} diff --git a/R/bound_prediction.R b/R/bound_prediction.R index 7513ac5..4c05f0c 100644 --- a/R/bound_prediction.R +++ b/R/bound_prediction.R @@ -15,13 +15,19 @@ #' #' bound_prediction(solubility_test, lower_limit = -1) #' @export -bound_prediction <- function(x, lower_limit = -Inf, upper_limit = Inf, - call = rlang::current_env()) { +bound_prediction <- function( + x, + lower_limit = -Inf, + upper_limit = Inf, + call = rlang::current_env() +) { check_data_frame(x, call = call) if (!any(names(x) == ".pred")) { - cli::cli_abort("The argument {.arg x} should have a column named {.code .pred}.", - call = call) + cli::cli_abort( + "The argument {.arg x} should have a column named {.code .pred}.", + call = call + ) } if (!is.numeric(x$.pred)) { cli::cli_abort("Column {.code .pred} should be numeric.", call = call) @@ -39,4 +45,3 @@ bound_prediction <- function(x, lower_limit = -Inf, upper_limit = Inf, } x } - diff --git a/R/cal-apply-binary.R b/R/cal-apply-binary.R index ef23499..9116982 100644 --- a/R/cal-apply-binary.R +++ b/R/cal-apply-binary.R @@ -5,10 +5,12 @@ cal_apply_binary <- function(object, .data, pred_class) { } #' @export -cal_apply_binary.cal_estimate_logistic <- function(object, - .data, - pred_class = NULL, - ...) { +cal_apply_binary.cal_estimate_logistic <- function( + object, + .data, + pred_class = NULL, + ... +) { apply_model_predict( object = object, .data = .data @@ -16,10 +18,12 @@ cal_apply_binary.cal_estimate_logistic <- function(object, } #' @export -cal_apply_binary.cal_estimate_logistic_spline <- function(object, - .data, - pred_class = NULL, - ...) { +cal_apply_binary.cal_estimate_logistic_spline <- function( + object, + .data, + pred_class = NULL, + ... +) { apply_model_predict( object = object, .data = .data diff --git a/R/cal-apply-impl.R b/R/cal-apply-impl.R index bed8d82..e8cce30 100644 --- a/R/cal-apply-impl.R +++ b/R/cal-apply-impl.R @@ -107,7 +107,10 @@ apply_beta_column <- function(.data, est_filter, estimates) { } ret <- - purrr::imap(estimates, ~ apply_beta_single(model = .x, df = df, est_name = .y)) + purrr::imap( + estimates, + ~ apply_beta_single(model = .x, df = df, est_name = .y) + ) names_ret <- names(ret) for (i in seq_along(names_ret)) { diff --git a/R/cal-apply.R b/R/cal-apply.R index 22b94d2..a836698 100644 --- a/R/cal-apply.R +++ b/R/cal-apply.R @@ -31,22 +31,26 @@ #' #' cal_apply(segment_logistic, w_calibration) #' @export -cal_apply <- function(.data, - object, - pred_class = NULL, - parameters = NULL, - ...) { +cal_apply <- function( + .data, + object, + pred_class = NULL, + parameters = NULL, + ... +) { rlang::check_dots_empty() UseMethod("cal_apply") } #' @export #' @rdname cal_apply -cal_apply.data.frame <- function(.data, - object, - pred_class = NULL, - parameters = NULL, - ...) { +cal_apply.data.frame <- function( + .data, + object, + pred_class = NULL, + parameters = NULL, + ... +) { cal_pkg_check(required_pkgs(object)) stop_null_parameters(parameters) @@ -60,11 +64,13 @@ cal_apply.data.frame <- function(.data, #' @export #' @rdname cal_apply -cal_apply.tune_results <- function(.data, - object, - pred_class = NULL, - parameters = NULL, - ...) { +cal_apply.tune_results <- function( + .data, + object, + pred_class = NULL, + parameters = NULL, + ... +) { cal_pkg_check(required_pkgs(object)) if (!(".predictions" %in% colnames(.data))) { @@ -99,11 +105,13 @@ cal_apply.tune_results <- function(.data, #' @export #' @rdname cal_apply -cal_apply.cal_object <- function(.data, - object, - pred_class = NULL, - parameters = NULL, - ...) { +cal_apply.cal_object <- function( + .data, + object, + pred_class = NULL, + parameters = NULL, + ... +) { if ("data.frame" %in% class(object)) { cli::cli_abort( c( @@ -140,10 +148,12 @@ cal_adjust.cal_estimate_isotonic_boot <- function(object, .data, pred_class) { } #' @export -cal_adjust.cal_estimate_beta <- function(object, - .data, - pred_class = NULL, - ...) { +cal_adjust.cal_estimate_beta <- function( + object, + .data, + pred_class = NULL, + ... +) { apply_beta_impl( object = object, .data = .data @@ -182,11 +192,13 @@ cal_adjust.cal_estimate_none <- function(object, .data, pred_class) { .data } -cal_adjust_update <- function(.data, - object, - pred_class = NULL, - parameters = NULL, - ...) { +cal_adjust_update <- function( + .data, + object, + pred_class = NULL, + parameters = NULL, + ... +) { if (object$type != "regression") { pred_class <- enquo(pred_class) } else { diff --git a/R/cal-estimate-beta.R b/R/cal-estimate-beta.R index f1bcf19..8830a20 100644 --- a/R/cal-estimate-beta.R +++ b/R/cal-estimate-beta.R @@ -22,13 +22,13 @@ #' } #' @export cal_estimate_beta <- function( - .data, - truth = NULL, - shape_params = 2, - location_params = 1, - estimate = dplyr::starts_with(".pred_"), - parameters = NULL, - ... + .data, + truth = NULL, + shape_params = 2, + location_params = 1, + estimate = dplyr::starts_with(".pred_"), + parameters = NULL, + ... ) { UseMethod("cal_estimate_beta") } @@ -36,14 +36,14 @@ cal_estimate_beta <- function( #' @export #' @rdname cal_estimate_beta cal_estimate_beta.data.frame <- function( - .data, - truth = NULL, - shape_params = 2, - location_params = 1, - estimate = dplyr::starts_with(".pred_"), - parameters = NULL, - ..., - .by = NULL + .data, + truth = NULL, + shape_params = 2, + location_params = 1, + estimate = dplyr::starts_with(".pred_"), + parameters = NULL, + ..., + .by = NULL ) { stop_null_parameters(parameters) @@ -70,13 +70,13 @@ cal_estimate_beta.data.frame <- function( #' @export #' @rdname cal_estimate_beta cal_estimate_beta.tune_results <- function( - .data, - truth = NULL, - shape_params = 2, - location_params = 1, - estimate = dplyr::starts_with(".pred_"), - parameters = NULL, - ... + .data, + truth = NULL, + shape_params = 2, + location_params = 1, + estimate = dplyr::starts_with(".pred_"), + parameters = NULL, + ... ) { info <- get_tune_data(.data, parameters) @@ -96,13 +96,13 @@ cal_estimate_beta.tune_results <- function( #' @export #' @rdname cal_estimate_beta cal_estimate_beta.grouped_df <- function( - .data, - truth = NULL, - shape_params = 2, - location_params = 1, - estimate = NULL, - parameters = NULL, - ... + .data, + truth = NULL, + shape_params = 2, + location_params = 1, + estimate = NULL, + parameters = NULL, + ... ) { abort_if_grouped_df() } @@ -137,12 +137,12 @@ beta_fit_over_groups <- function(info, shape_params, location_params, ...) { fit_all_beta_models <- function( - .data, - truth = NULL, - shape = 2, - location = 1, - estimate = NULL, - ... + .data, + truth = NULL, + shape = 2, + location = 1, + estimate = NULL, + ... ) { lvls <- levels(.data[[truth]]) num_lvls <- length(lvls) @@ -176,12 +176,12 @@ fit_all_beta_models <- function( fit_beta_model <- function( - .data, - truth = NULL, - shape = 2, - location = 1, - estimate = NULL, - ... + .data, + truth = NULL, + shape = 2, + location = 1, + estimate = NULL, + ... ) { outcome_data <- .data[[truth]] lvls <- levels(outcome_data) @@ -255,6 +255,8 @@ check_cal_groups <- function(group, .data, call = rlang::env_parent()) { #' @export print.betacal <- function(x, ...) { - cli::cli_inform("Beta calibration ({x$parameters}) using {x$model$df.null} samples") + cli::cli_inform( + "Beta calibration ({x$parameters}) using {x$model$df.null} samples" + ) invisible(x) } diff --git a/R/cal-estimate-isotonic.R b/R/cal-estimate-isotonic.R index 6e9914d..537a0e0 100644 --- a/R/cal-estimate-isotonic.R +++ b/R/cal-estimate-isotonic.R @@ -175,7 +175,6 @@ cal_estimate_isotonic_boot.data.frame <- function( source_class = cal_class_name(.data), additional_classes = "cal_estimate_isotonic_boot" ) - } #' @export @@ -188,7 +187,6 @@ cal_estimate_isotonic_boot.tune_results <- function( parameters = NULL, ... ) { - info <- get_tune_data(.data, parameters) model <- isoreg_fit_over_groups(info, times = times, ...) @@ -202,7 +200,6 @@ cal_estimate_isotonic_boot.tune_results <- function( source_class = cal_class_name(.data), additional_classes = "cal_estimate_isotonic_boot" ) - } #' @export @@ -241,13 +238,12 @@ isoreg_fit_over_groups <- function(info, times = 1, ...) { } fit_ensemble_isoreg_models <- function( - .data, - truth = NULL, - estimate = NULL, - times = 1, - ... + .data, + truth = NULL, + estimate = NULL, + times = 1, + ... ) { - is_sampled <- times > 1 iso_models <- purrr::map( @@ -269,11 +265,11 @@ fit_ensemble_isoreg_models <- function( } fit_all_isoreg_models <- function( - .data, - truth = NULL, - estimate = NULL, - sampled = FALSE, - ... + .data, + truth = NULL, + estimate = NULL, + sampled = FALSE, + ... ) { lvls <- levels(.data[[truth]]) num_lvls <- length(lvls) @@ -310,7 +306,6 @@ fit_isoreg_model <- function( sampled = FALSE, ... ) { - estimate <- estimate[1] sorted_data <- dplyr::arrange(.data, !!rlang::syms(estimate)) @@ -322,8 +317,8 @@ fit_isoreg_model <- function( ) } - x <- sorted_data[[ estimate ]] - y <- sorted_data[[ truth ]] + x <- sorted_data[[estimate]] + y <- sorted_data[[truth]] if (is.factor(y)) { lvls <- levels(y) diff --git a/R/cal-estimate-linear.R b/R/cal-estimate-linear.R index 577f7f2..6142d1b 100644 --- a/R/cal-estimate-linear.R +++ b/R/cal-estimate-linear.R @@ -65,25 +65,29 @@ #' These methods estimate the relationship in the unmodified predicted values #' and then remove that trend when [cal_apply()] is invoked. #' @export -cal_estimate_linear <- function(.data, - truth = NULL, - estimate = dplyr::matches("^.pred$"), - smooth = TRUE, - parameters = NULL, - ..., - .by = NULL) { +cal_estimate_linear <- function( + .data, + truth = NULL, + estimate = dplyr::matches("^.pred$"), + smooth = TRUE, + parameters = NULL, + ..., + .by = NULL +) { UseMethod("cal_estimate_linear") } #' @export #' @rdname cal_estimate_linear -cal_estimate_linear.data.frame <- function(.data, - truth = NULL, - estimate = dplyr::matches("^.pred$"), - smooth = TRUE, - parameters = NULL, - ..., - .by = NULL) { +cal_estimate_linear.data.frame <- function( + .data, + truth = NULL, + estimate = dplyr::matches("^.pred$"), + smooth = TRUE, + parameters = NULL, + ..., + .by = NULL +) { stop_null_parameters(parameters) info <- get_prediction_data( @@ -114,17 +118,18 @@ cal_estimate_linear.data.frame <- function(.data, additional_class = additional_class, source_class = cal_class_name(.data) ) - } #' @export #' @rdname cal_estimate_linear -cal_estimate_linear.tune_results <- function(.data, - truth = NULL, - estimate = dplyr::matches("^.pred$"), - smooth = TRUE, - parameters = NULL, - ...) { +cal_estimate_linear.tune_results <- function( + .data, + truth = NULL, + estimate = dplyr::matches("^.pred$"), + smooth = TRUE, + parameters = NULL, + ... +) { info <- get_tune_data(.data, parameters) model_fit <- lin_reg_fit_over_groups(info, smooth, ...) @@ -152,12 +157,14 @@ cal_estimate_linear.tune_results <- function(.data, #' @export #' @rdname cal_estimate_linear -cal_estimate_linear.grouped_df <- function(.data, - truth = NULL, - estimate = NULL, - smooth = TRUE, - parameters = NULL, - ...) { +cal_estimate_linear.grouped_df <- function( + .data, + truth = NULL, + estimate = NULL, + smooth = TRUE, + parameters = NULL, + ... +) { abort_if_grouped_df() } @@ -198,7 +205,9 @@ fit_regression_model <- function(.data, smooth, estimate, outcome, ...) { lin_reg_fit_over_groups <- function(info, smooth = TRUE, ...) { if (length(info$levels) == 2) { - cli::cli_abort("This function is meant to be used with multi-class outcomes only.") + cli::cli_abort( + "This function is meant to be used with multi-class outcomes only." + ) } grp_df <- make_group_df(info$predictions, group = info$group) diff --git a/R/cal-estimate-logistic.R b/R/cal-estimate-logistic.R index 9b6d7e6..ee3b52c 100644 --- a/R/cal-estimate-logistic.R +++ b/R/cal-estimate-logistic.R @@ -94,7 +94,6 @@ cal_estimate_logistic.data.frame <- function( source_class = cal_class_name(.data), type = "binary" ) - } #' @export @@ -107,7 +106,6 @@ cal_estimate_logistic.tune_results <- function( parameters = NULL, ... ) { - info <- get_tune_data(.data, parameters) model <- logistic_fit_over_groups(info, smooth, ...) @@ -172,11 +170,10 @@ fit_logistic_model <- function(.data, smooth, estimate, outcome, ...) { } else { # TODO check for failures model <- glm(f, data = .data, family = "binomial", ...) - -} + } model <- butcher::butcher(model) model - } +} logistic_fit_over_groups <- function(info, smooth = TRUE, ...) { if (length(info$levels) > 2) { @@ -202,4 +199,3 @@ logistic_fit_over_groups <- function(info, smooth = TRUE, ...) { purrr::map2(fits, fltrs, ~ list(filter = .y, estimate = .x)) } - diff --git a/R/cal-estimate-multinomial.R b/R/cal-estimate-multinomial.R index 694e7c9..33626bd 100644 --- a/R/cal-estimate-multinomial.R +++ b/R/cal-estimate-multinomial.R @@ -46,25 +46,29 @@ #' cal_plot_windowed(new_test_pred, truth = class, window_size = 0.1, step_size = 0.03) #' #' @export -cal_estimate_multinomial <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred_"), - smooth = TRUE, - parameters = NULL, - ...) { +cal_estimate_multinomial <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred_"), + smooth = TRUE, + parameters = NULL, + ... +) { UseMethod("cal_estimate_multinomial") } #' @export #' @rdname cal_estimate_multinomial cal_estimate_multinomial.data.frame <- - function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred_"), - smooth = TRUE, - parameters = NULL, - ..., - .by = NULL) { + function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred_"), + smooth = TRUE, + parameters = NULL, + ..., + .by = NULL + ) { stop_null_parameters(parameters) info <- get_prediction_data( @@ -78,8 +82,10 @@ cal_estimate_multinomial.data.frame <- if (smooth) { method <- "Generalized additive model calibration" - additional_class <- c("cal_estimate_multinomial_spline", - "cal_estimate_multinomial") + additional_class <- c( + "cal_estimate_multinomial_spline", + "cal_estimate_multinomial" + ) } else { method <- "Multinomial regression calibration" additional_class <- "cal_estimate_multinomial" @@ -100,13 +106,14 @@ cal_estimate_multinomial.data.frame <- #' @export #' @rdname cal_estimate_multinomial cal_estimate_multinomial.tune_results <- - function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred_"), - smooth = TRUE, - parameters = NULL, - ...) { - + function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred_"), + smooth = TRUE, + parameters = NULL, + ... + ) { info <- get_tune_data(.data, parameters) model <- mtnml_fit_over_groups(info, smooth, ...) @@ -131,12 +138,14 @@ cal_estimate_multinomial.tune_results <- #' @export #' @rdname cal_estimate_multinomial -cal_estimate_multinomial.grouped_df <- function(.data, - truth = NULL, - estimate = NULL, - smooth = TRUE, - parameters = NULL, - ...) { +cal_estimate_multinomial.grouped_df <- function( + .data, + truth = NULL, + estimate = NULL, + smooth = TRUE, + parameters = NULL, + ... +) { abort_if_grouped_df() } @@ -194,7 +203,9 @@ fit_multinomial_model <- function(.data, smooth, estimate, outcome, ...) { mtnml_fit_over_groups <- function(info, smooth = TRUE, ...) { if (length(info$levels) == 2) { - cli::cli_abort("This function is meant to be used with multi-class outcomes only.") + cli::cli_abort( + "This function is meant to be used with multi-class outcomes only." + ) } grp_df <- make_group_df(info$predictions, group = info$group) diff --git a/R/cal-estimate-none.R b/R/cal-estimate-none.R index 1dbc669..e7a1301 100644 --- a/R/cal-estimate-none.R +++ b/R/cal-estimate-none.R @@ -31,29 +31,35 @@ #' segment_logistic #' ) #' @export -cal_estimate_none <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - parameters = NULL, - ...) { +cal_estimate_none <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + parameters = NULL, + ... +) { UseMethod("cal_estimate_none") } #' @export #' @rdname cal_estimate_none -cal_estimate_none.data.frame <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - parameters = NULL, - ..., - .by = NULL) { +cal_estimate_none.data.frame <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + parameters = NULL, + ..., + .by = NULL +) { stop_null_parameters(parameters) rlang::check_dots_empty() - info <- get_prediction_data(.data, - truth = {{ truth }}, - estimate = {{ estimate }}, - .by = {{ .by }}) + info <- get_prediction_data( + .data, + truth = {{ truth }}, + estimate = {{ estimate }}, + .by = {{ .by }} + ) model <- nothing_over_groups(info, ...) @@ -77,11 +83,13 @@ cal_estimate_none.data.frame <- function(.data, #' @export #' @rdname cal_estimate_none -cal_estimate_none.tune_results <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - parameters = NULL, - ...) { +cal_estimate_none.tune_results <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + parameters = NULL, + ... +) { rlang::check_dots_empty() info <- get_tune_data(.data, parameters) @@ -107,17 +115,19 @@ cal_estimate_none.tune_results <- function(.data, #' @export #' @rdname cal_estimate_none -cal_estimate_none.grouped_df <- function(.data, - truth = NULL, - estimate = NULL, - parameters = NULL, - ...) { +cal_estimate_none.grouped_df <- function( + .data, + truth = NULL, + estimate = NULL, + parameters = NULL, + ... +) { abort_if_grouped_df() } #------------------------------ Implementation --------------------------------- -nothing_over_groups <- function(info, ...) { +nothing_over_groups <- function(info, ...) { grp_df <- make_group_df(info$predictions, group = info$group) nst_df <- vctrs::vec_split(x = info$predictions, by = grp_df) fltrs <- make_cal_filters(nst_df$key) @@ -153,4 +163,3 @@ print.no_calibration <- function(x, ...) { required_pkgs.cal_estimate_none <- function(x, ...) { c("probably") } - diff --git a/R/cal-estimate-utils.R b/R/cal-estimate-utils.R index 17562ae..2db341e 100644 --- a/R/cal-estimate-utils.R +++ b/R/cal-estimate-utils.R @@ -22,7 +22,8 @@ print.cal_regression <- function(x, ...) { print_cls_cal <- function(x, upv = FALSE, ...) { print_type <- - switch(x$type, + switch( + x$type, "binary" = "Binary", "multiclass" = "Multiclass", "one_vs_all" = "Multiclass (1 v All)", @@ -30,11 +31,13 @@ print_cls_cal <- function(x, upv = FALSE, ...) { NA_character_ ) - cli::cli_div(theme = list( - span.val0 = list(color = "blue"), - span.val1 = list(color = "yellow"), - span.val2 = list(color = "darkgreen") - )) + cli::cli_div( + theme = list( + span.val0 = list(color = "blue"), + span.val1 = list(color = "yellow"), + span.val2 = list(color = "darkgreen") + ) + ) rows <- prettyNum(x$rows, ",") cli::cli_h3("Probability Calibration") cli::cli_text("Method: {.val2 {x$method}}") @@ -65,11 +68,13 @@ print_cls_cal <- function(x, upv = FALSE, ...) { print_reg_cal <- function(x, upv = FALSE, ...) { - cli::cli_div(theme = list( - span.val0 = list(color = "blue"), - span.val1 = list(color = "yellow"), - span.val2 = list(color = "darkgreen") - )) + cli::cli_div( + theme = list( + span.val0 = list(color = "blue"), + span.val1 = list(color = "yellow"), + span.val2 = list(color = "darkgreen") + ) + ) rows <- prettyNum(x$rows, ",") cli::cli_h3("Regression Calibration") cli::cli_text("Method: {.val2 {x$method}}") @@ -121,7 +126,6 @@ cal_class_name.rset <- function(x) { # ------------------------------- Data Ingestion ------------------------------- - get_tune_data <- function(x, parameters = NULL) { .data <- collect_predictions( x, @@ -161,10 +165,10 @@ get_tune_data <- function(x, parameters = NULL) { } get_prediction_data <- function( - .data, - truth = NULL, - estimate = dplyr::starts_with(".pred_"), - .by = NULL + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred_"), + .by = NULL ) { if (!inherits(.data, "tbl_df")) { .data <- dplyr::as_tibble(.data) @@ -176,7 +180,7 @@ get_prediction_data <- function( # So that we ignore non-numeric columns that are accidentally selected such # as `.pred_class` - is_num_est <- purrr::map_lgl(.data[,estimate], is.numeric) + is_num_est <- purrr::map_lgl(.data[, estimate], is.numeric) estimate <- estimate[is_num_est] lvls <- levels(.data[[truth]]) @@ -231,7 +235,6 @@ check_tm_format <- function(estimate, lvls) { tm_nms <- paste0(".pred_", lvls) if (identical(sort(estimate), sort(tm_nms))) { - estimate <- tm_nms } estimate @@ -239,14 +242,15 @@ check_tm_format <- function(estimate, lvls) { # ------------------------------- Utils ---------------------------------------- -as_regression_cal_object <- function(estimate, - truth, - levels, - method, - rows, - additional_class = NULL, - source_class = NULL) { - +as_regression_cal_object <- function( + estimate, + truth, + levels, + method, + rows, + additional_class = NULL, + source_class = NULL +) { as_cal_object( estimate = estimate, truth = truth, @@ -259,14 +263,16 @@ as_regression_cal_object <- function(estimate, ) } -as_cal_object <- function(estimate, - truth, - levels, - method, - rows, - additional_classes = NULL, - source_class = NULL, - type = NULL) { +as_cal_object <- function( + estimate, + truth, + levels, + method, + rows, + additional_classes = NULL, + source_class = NULL, + type = NULL +) { if (length(levels) == 1) { type <- "regression" obj_class <- "cal_regression" @@ -317,7 +323,10 @@ split_dplyr_groups <- function(.data) { grp_keys <- purrr::map(grp_keys, as.character) grp_var <- .data |> dplyr::group_vars() grp_data <- .data |> tidyr::nest() - grp_filters <- purrr::map(grp_keys[[1]], ~ expr(!!parse_expr(grp_var) == !!.x)) + grp_filters <- purrr::map( + grp_keys[[1]], + ~ expr(!!parse_expr(grp_var) == !!.x) + ) grp_n <- purrr::map_int(grp_data$data, nrow) res <- vector(mode = "list", length = length(grp_filters)) for (i in seq_along(res)) { @@ -333,7 +342,9 @@ split_dplyr_groups <- function(.data) { stop_null_parameters <- function(x) { if (!is.null(x)) { - cli::cli_abort("The {.arg parameters} argument is only valid for {.code tune_results}.") + cli::cli_abort( + "The {.arg parameters} argument is only valid for {.code tune_results}." + ) } } @@ -389,7 +400,7 @@ make_cal_filters <- function(key) { if (i == 1) { res <- tmp } else { - res <- purrr::map2(res, tmp, ~ rlang::expr(!!.x & !!.y)) + res <- purrr::map2(res, tmp, ~ rlang::expr(!!.x & !!.y)) } } @@ -434,7 +445,12 @@ multinomial_f_from_str <- function(y, x) { res } -turn_off_smooth_if_too_few_unique <- function(.data, estimate, smooth, min_vals = 10) { +turn_off_smooth_if_too_few_unique <- function( + .data, + estimate, + smooth, + min_vals = 10 +) { predictors <- .data[, estimate] if (smooth) { n_unique <- purrr::map_int(predictors, vctrs::vec_unique_count) @@ -452,7 +468,7 @@ turn_off_smooth_if_too_few_unique <- function(.data, estimate, smooth, min_vals # ------------------------------ 1 versus all helpers -------------------------- fit_over_classes <- function(.fn, .data, truth, estimate, ...) { - lvls <- levels(.data[[ truth ]]) + lvls <- levels(.data[[truth]]) prob_cols <- estimate res <- purrr::map2( @@ -470,12 +486,11 @@ fit_over_classes <- function(.fn, .data, truth, estimate, ...) { } fit_1_vs_all <- function(class, prob_col, .fn, .data, truth, estimate, ...) { - # Redefine the outcome class as the current class level - outcome <- .data[[ truth ]] + outcome <- .data[[truth]] new_class <- ifelse(outcome == class, class, ".other") new_class <- factor(new_class, levels = c(class, ".other")) - .data[[ truth ]] <- new_class + .data[[truth]] <- new_class res <- .fn(.data, truth = truth, estimate = prob_col, ...) res diff --git a/R/cal-plot-breaks.R b/R/cal-plot-breaks.R index 73569ea..e5b6490 100644 --- a/R/cal-plot-breaks.R +++ b/R/cal-plot-breaks.R @@ -89,32 +89,36 @@ #' theme(legend.position = "") #' @seealso [cal_plot_logistic()], [cal_plot_windowed()] #' @export -cal_plot_breaks <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - num_breaks = 10, - conf_level = 0.90, - include_ribbon = TRUE, - include_rug = TRUE, - include_points = TRUE, - event_level = c("auto", "first", "second"), - ...) { +cal_plot_breaks <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + num_breaks = 10, + conf_level = 0.90, + include_ribbon = TRUE, + include_rug = TRUE, + include_points = TRUE, + event_level = c("auto", "first", "second"), + ... +) { UseMethod("cal_plot_breaks") } #' @export #' @rdname cal_plot_breaks -cal_plot_breaks.data.frame <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - num_breaks = 10, - conf_level = 0.90, - include_ribbon = TRUE, - include_rug = TRUE, - include_points = TRUE, - event_level = c("auto", "first", "second"), - ..., - .by = NULL) { +cal_plot_breaks.data.frame <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + num_breaks = 10, + conf_level = 0.90, + include_ribbon = TRUE, + include_rug = TRUE, + include_points = TRUE, + event_level = c("auto", "first", "second"), + ..., + .by = NULL +) { group <- get_group_argument({{ .by }}, .data) .data <- dplyr::group_by(.data, dplyr::across({{ group }})) @@ -134,16 +138,18 @@ cal_plot_breaks.data.frame <- function(.data, } #' @export #' @rdname cal_plot_breaks -cal_plot_breaks.tune_results <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - num_breaks = 10, - conf_level = 0.90, - include_ribbon = TRUE, - include_rug = TRUE, - include_points = TRUE, - event_level = c("auto", "first", "second"), - ...) { +cal_plot_breaks.tune_results <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + num_breaks = 10, + conf_level = 0.90, + include_ribbon = TRUE, + include_rug = TRUE, + include_points = TRUE, + event_level = c("auto", "first", "second"), + ... +) { tune_args <- tune_results_args( .data = .data, truth = {{ truth }}, @@ -169,34 +175,42 @@ cal_plot_breaks.tune_results <- function(.data, #' @export #' @rdname cal_plot_breaks -cal_plot_breaks.grouped_df <- function(.data, - truth = NULL, - estimate = NULL, - num_breaks = 10, - conf_level = 0.90, - include_ribbon = TRUE, - include_rug = TRUE, - include_points = TRUE, - event_level = c("auto", "first", "second"), - ...) { +cal_plot_breaks.grouped_df <- function( + .data, + truth = NULL, + estimate = NULL, + num_breaks = 10, + conf_level = 0.90, + include_ribbon = TRUE, + include_rug = TRUE, + include_points = TRUE, + event_level = c("auto", "first", "second"), + ... +) { abort_if_grouped_df() } #--------------------------- >> Implementation --------------------------------- -cal_plot_breaks_impl <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - group = NULL, - num_breaks = 10, - conf_level = 0.90, - include_ribbon = TRUE, - include_rug = TRUE, - include_points = TRUE, - event_level = c("auto", "first", "second"), - is_tune_results = FALSE, - ...) { - rlang::arg_match0(event_level, c("auto", "first", "second"), error_call = NULL) +cal_plot_breaks_impl <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + group = NULL, + num_breaks = 10, + conf_level = 0.90, + include_ribbon = TRUE, + include_rug = TRUE, + include_points = TRUE, + event_level = c("auto", "first", "second"), + is_tune_results = FALSE, + ... +) { + rlang::arg_match0( + event_level, + c("auto", "first", "second"), + error_call = NULL + ) truth <- enquo(truth) estimate <- enquo(estimate) group <- enquo(group) @@ -269,27 +283,31 @@ cal_plot_breaks_impl <- function(.data, #' @rdname cal_binary_tables #' @export #' @keywords internal -.cal_table_breaks <- function(.data, - truth = NULL, - estimate = NULL, - .by = NULL, - num_breaks = 10, - conf_level = 0.90, - event_level = c("auto", "first", "second"), - ...) { +.cal_table_breaks <- function( + .data, + truth = NULL, + estimate = NULL, + .by = NULL, + num_breaks = 10, + conf_level = 0.90, + event_level = c("auto", "first", "second"), + ... +) { UseMethod(".cal_table_breaks") } #' @export #' @keywords internal -.cal_table_breaks.data.frame <- function(.data, - truth = NULL, - estimate = NULL, - .by = NULL, - num_breaks = 10, - conf_level = 0.90, - event_level = c("auto", "first", "second"), - ...) { +.cal_table_breaks.data.frame <- function( + .data, + truth = NULL, + estimate = NULL, + .by = NULL, + num_breaks = 10, + conf_level = 0.90, + event_level = c("auto", "first", "second"), + ... +) { .cal_table_breaks_impl( .data = .data, truth = {{ truth }}, @@ -303,14 +321,16 @@ cal_plot_breaks_impl <- function(.data, #' @export #' @keywords internal -.cal_table_breaks.tune_results <- function(.data, - truth = NULL, - estimate = NULL, - .by = NULL, - num_breaks = 10, - conf_level = 0.90, - event_level = c("auto", "first", "second"), - ...) { +.cal_table_breaks.tune_results <- function( + .data, + truth = NULL, + estimate = NULL, + .by = NULL, + num_breaks = 10, + conf_level = 0.90, + event_level = c("auto", "first", "second"), + ... +) { tune_args <- tune_results_args( .data = .data, truth = {{ truth }}, @@ -331,14 +351,16 @@ cal_plot_breaks_impl <- function(.data, } #--------------------------- >> Implementation --------------------------------- -.cal_table_breaks_impl <- function(.data, - truth, - estimate, - group, - num_breaks = 10, - conf_level = 0.90, - event_level = c("auto", "first", "second"), - ...) { +.cal_table_breaks_impl <- function( + .data, + truth, + estimate, + group, + num_breaks = 10, + conf_level = 0.90, + event_level = c("auto", "first", "second"), + ... +) { truth <- enquo(truth) estimate <- enquo(estimate) group <- enquo(group) @@ -373,14 +395,16 @@ cal_plot_breaks_impl <- function(.data, res } -.cal_table_breaks_grp <- function(.data, - truth, - group, - num_breaks = 10, - conf_level = 0.90, - event_level = c("auto", "first", "second"), - levels, - ...) { +.cal_table_breaks_grp <- function( + .data, + truth, + group, + num_breaks = 10, + conf_level = 0.90, + event_level = c("auto", "first", "second"), + levels, + ... +) { side <- seq(0, 1, by = 1 / num_breaks) cuts <- list( diff --git a/R/cal-plot-logistic.R b/R/cal-plot-logistic.R index e257d9e..ae994dd 100644 --- a/R/cal-plot-logistic.R +++ b/R/cal-plot-logistic.R @@ -39,30 +39,34 @@ #' ) #' @seealso [cal_plot_breaks()], [cal_plot_windowed()] #' @export -cal_plot_logistic <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - conf_level = 0.90, - smooth = TRUE, - include_rug = TRUE, - include_ribbon = TRUE, - event_level = c("auto", "first", "second"), - ...) { +cal_plot_logistic <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + conf_level = 0.90, + smooth = TRUE, + include_rug = TRUE, + include_ribbon = TRUE, + event_level = c("auto", "first", "second"), + ... +) { UseMethod("cal_plot_logistic") } #' @export #' @rdname cal_plot_logistic -cal_plot_logistic.data.frame <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - conf_level = 0.90, - smooth = TRUE, - include_rug = TRUE, - include_ribbon = TRUE, - event_level = c("auto", "first", "second"), - ..., - .by = NULL) { +cal_plot_logistic.data.frame <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + conf_level = 0.90, + smooth = TRUE, + include_rug = TRUE, + include_ribbon = TRUE, + event_level = c("auto", "first", "second"), + ..., + .by = NULL +) { group <- get_group_argument({{ .by }}, .data) .data <- dplyr::group_by(.data, dplyr::across({{ group }})) @@ -81,15 +85,17 @@ cal_plot_logistic.data.frame <- function(.data, } #' @export #' @rdname cal_plot_logistic -cal_plot_logistic.tune_results <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - conf_level = 0.90, - smooth = TRUE, - include_rug = TRUE, - include_ribbon = TRUE, - event_level = c("auto", "first", "second"), - ...) { +cal_plot_logistic.tune_results <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + conf_level = 0.90, + smooth = TRUE, + include_rug = TRUE, + include_ribbon = TRUE, + event_level = c("auto", "first", "second"), + ... +) { tune_args <- tune_results_args( .data = .data, truth = {{ truth }}, @@ -114,31 +120,39 @@ cal_plot_logistic.tune_results <- function(.data, #' @export #' @rdname cal_plot_logistic -cal_plot_logistic.grouped_df <- function(.data, - truth = NULL, - estimate = NULL, - conf_level = 0.90, - smooth = TRUE, - include_rug = TRUE, - include_ribbon = TRUE, - event_level = c("auto", "first", "second"), - ...) { +cal_plot_logistic.grouped_df <- function( + .data, + truth = NULL, + estimate = NULL, + conf_level = 0.90, + smooth = TRUE, + include_rug = TRUE, + include_ribbon = TRUE, + event_level = c("auto", "first", "second"), + ... +) { abort_if_grouped_df() } #--------------------------- >> Implementation --------------------------------- -cal_plot_logistic_impl <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - group = NULL, - conf_level = 0.90, - smooth = TRUE, - include_rug = TRUE, - include_ribbon = TRUE, - event_level = c("auto", "first", "second"), - is_tune_results = FALSE, - ...) { - rlang::arg_match0(event_level, c("auto", "first", "second"), error_call = NULL) +cal_plot_logistic_impl <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + group = NULL, + conf_level = 0.90, + smooth = TRUE, + include_rug = TRUE, + include_ribbon = TRUE, + event_level = c("auto", "first", "second"), + is_tune_results = FALSE, + ... +) { + rlang::arg_match0( + event_level, + c("auto", "first", "second"), + error_call = NULL + ) truth <- enquo(truth) estimate <- enquo(estimate) group <- enquo(group) @@ -175,27 +189,31 @@ cal_plot_logistic_impl <- function(.data, #' @rdname cal_binary_tables #' @export #' @keywords internal -.cal_table_logistic <- function(.data, - truth = NULL, - estimate = NULL, - .by = NULL, - conf_level = 0.90, - smooth = TRUE, - event_level = c("auto", "first", "second"), - ...) { +.cal_table_logistic <- function( + .data, + truth = NULL, + estimate = NULL, + .by = NULL, + conf_level = 0.90, + smooth = TRUE, + event_level = c("auto", "first", "second"), + ... +) { UseMethod(".cal_table_logistic") } #' @export #' @keywords internal -.cal_table_logistic.data.frame <- function(.data, - truth = NULL, - estimate = NULL, - .by = NULL, - conf_level = 0.90, - smooth = TRUE, - event_level = c("auto", "first", "second"), - ...) { +.cal_table_logistic.data.frame <- function( + .data, + truth = NULL, + estimate = NULL, + .by = NULL, + conf_level = 0.90, + smooth = TRUE, + event_level = c("auto", "first", "second"), + ... +) { .cal_table_logistic_impl( .data = .data, truth = {{ truth }}, @@ -209,14 +227,16 @@ cal_plot_logistic_impl <- function(.data, #' @export #' @keywords internal -.cal_table_logistic.tune_results <- function(.data, - truth = NULL, - estimate = NULL, - .by = NULL, - conf_level = 0.90, - smooth = TRUE, - event_level = c("auto", "first", "second"), - ...) { +.cal_table_logistic.tune_results <- function( + .data, + truth = NULL, + estimate = NULL, + .by = NULL, + conf_level = 0.90, + smooth = TRUE, + event_level = c("auto", "first", "second"), + ... +) { tune_args <- tune_results_args( .data = .data, truth = {{ truth }}, @@ -237,14 +257,16 @@ cal_plot_logistic_impl <- function(.data, } #--------------------------- >> Implementation --------------------------------- -.cal_table_logistic_impl <- function(.data, - truth = NULL, - estimate = NULL, - group = NULL, - conf_level = 0.90, - event_level = c("auto", "first", "second"), - smooth = TRUE, - ...) { +.cal_table_logistic_impl <- function( + .data, + truth = NULL, + estimate = NULL, + group = NULL, + conf_level = 0.90, + event_level = c("auto", "first", "second"), + smooth = TRUE, + ... +) { truth <- enquo(truth) estimate <- enquo(estimate) group <- enquo(group) diff --git a/R/cal-plot-regression.R b/R/cal-plot-regression.R index 2aea42b..c9aca11 100644 --- a/R/cal-plot-regression.R +++ b/R/cal-plot-regression.R @@ -30,20 +30,24 @@ #' alpha = 1 / 6, cex = 3, smooth = FALSE #' ) #' @export -cal_plot_regression <- function(.data, - truth = NULL, - estimate = NULL, - smooth = TRUE, - ...) { +cal_plot_regression <- function( + .data, + truth = NULL, + estimate = NULL, + smooth = TRUE, + ... +) { UseMethod("cal_plot_regression") } -cal_plot_regression_impl <- function(.data, - truth = NULL, - estimate = NULL, - smooth = TRUE, - ..., - .by = NULL) { +cal_plot_regression_impl <- function( + .data, + truth = NULL, + estimate = NULL, + smooth = TRUE, + ..., + .by = NULL +) { group <- get_group_argument({{ .by }}, .data) truth <- enquo(truth) @@ -68,11 +72,13 @@ cal_plot_regression.data.frame <- cal_plot_regression_impl #' @export #' @rdname cal_plot_regression -cal_plot_regression.tune_results <- function(.data, - truth = NULL, - estimate = NULL, - smooth = TRUE, - ...) { +cal_plot_regression.tune_results <- function( + .data, + truth = NULL, + estimate = NULL, + smooth = TRUE, + ... +) { tune_args <- tune_results_args( .data = .data, truth = {{ truth }}, @@ -92,16 +98,17 @@ cal_plot_regression.tune_results <- function(.data, #' @export #' @rdname cal_plot_regression -cal_plot_regression.grouped_df <- function(.data, - truth = NULL, - estimate = NULL, - smooth = TRUE, - ...) { +cal_plot_regression.grouped_df <- function( + .data, + truth = NULL, + estimate = NULL, + smooth = TRUE, + ... +) { abort_if_grouped_df() } -regression_plot_impl <- function(.data, truth, estimate, group, - smooth, ...) { +regression_plot_impl <- function(.data, truth, estimate, group, smooth, ...) { truth <- enquo(truth) estimate <- enquo(estimate) group <- enquo(group) @@ -144,7 +151,8 @@ regression_plot_impl <- function(.data, truth, estimate, group, if (smooth) { res <- - res + geom_smooth( + res + + geom_smooth( se = FALSE, col = "blue", method = "gam", @@ -152,7 +160,8 @@ regression_plot_impl <- function(.data, truth, estimate, group, ) } else { res <- - res + geom_smooth( + res + + geom_smooth( se = FALSE, col = "blue", method = "lm", @@ -168,7 +177,6 @@ regression_plot_impl <- function(.data, truth, estimate, group, } - assert_truth_numeric <- function(.data, truth) { truth <- enquo(truth) if (!quo_is_null(truth)) { diff --git a/R/cal-plot-utils.R b/R/cal-plot-utils.R index 2fe6ef1..f33a6f7 100644 --- a/R/cal-plot-utils.R +++ b/R/cal-plot-utils.R @@ -2,8 +2,16 @@ # This function iterates through each of the class levels. For binary it selects # the appropriate one based on the `event_level` selected -.cal_class_grps <- function(.data, truth, cuts, levels, event_level, conf_level, - method = "breaks", smooth = NULL) { +.cal_class_grps <- function( + .data, + truth, + cuts, + levels, + event_level, + conf_level, + method = "breaks", + smooth = NULL +) { truth <- enquo(truth) lev <- process_level(event_level) @@ -18,7 +26,9 @@ } if (length_levels > 2 & lev == 2) { - cli::cli_abort("Only {.val auto} {.arg event_level} is valid for multi-class models.") + cli::cli_abort( + "Only {.val auto} {.arg event_level} is valid for multi-class models." + ) } no_levels <- levels @@ -73,14 +83,16 @@ res } -.cal_model_grps <- function(.data, - truth = NULL, - estimate = NULL, - conf_level = 0.90, - event_level = c("auto", "first", "second"), - lev, - level, - smooth = TRUE) { +.cal_model_grps <- function( + .data, + truth = NULL, + estimate = NULL, + conf_level = 0.90, + event_level = c("auto", "first", "second"), + lev, + level, + smooth = TRUE +) { truth <- enquo(truth) estimate <- enquo(estimate) @@ -128,8 +140,15 @@ } # This function iterates through each breaks/windows of the plot -.cal_cut_grps <- function(.data, truth, estimate, cuts, - level, lev, conf_level) { +.cal_cut_grps <- function( + .data, + truth, + estimate, + cuts, + level, + lev, + conf_level +) { truth <- enquo(truth) estimate <- enquo(estimate) @@ -202,18 +221,23 @@ process_level <- function(x) { ret <- 2 } if (is.null(ret)) { - cli::cli_abort("Invalid {.arg event_level} entry: {x}. Valid entries are - {.val first}, {.val second}, or {.val auto}.", call = NULL) + cli::cli_abort( + "Invalid {.arg event_level} entry: {x}. Valid entries are + {.val first}, {.val second}, or {.val auto}.", + call = NULL + ) } ret } -tune_results_args <- function(.data, - truth, - estimate, - event_level, - parameters = NULL, - ...) { +tune_results_args <- function( + .data, + truth, + estimate, + event_level, + parameters = NULL, + ... +) { if (!(".predictions" %in% colnames(.data))) { rlang::abort( paste0( @@ -259,11 +283,21 @@ tune_results_args <- function(.data, #--------------------------------- Plot ---------------------------------------- -cal_plot_impl <- function(tbl, x, y, - .data, truth, estimate, group, - x_label, y_label, - include_ribbon, include_rug, include_points, - is_tune_results = FALSE) { +cal_plot_impl <- function( + tbl, + x, + y, + .data, + truth, + estimate, + group, + x_label, + y_label, + include_ribbon, + include_rug, + include_points, + is_tune_results = FALSE +) { truth <- enquo(truth) estimate <- enquo(estimate) group <- enquo(group) @@ -290,7 +324,10 @@ cal_plot_impl <- function(tbl, x, y, dplyr_group <- NULL } - res <- ggplot(data = tbl, aes(x = !!x, color = !!dplyr_group, fill = !!dplyr_group)) + + res <- ggplot( + data = tbl, + aes(x = !!x, color = !!dplyr_group, fill = !!dplyr_group) + ) + geom_abline(col = "#aaaaaa", linetype = 2) + geom_line(aes(y = !!y)) @@ -317,7 +354,9 @@ cal_plot_impl <- function(tbl, x, y, level1 <- levels[[1]] if (length(levels) > 1 & !is_tune_results) { - cli::cli_warn("Multiple class columns identified. Using: {.code {level1}}") + cli::cli_warn( + "Multiple class columns identified. Using: {.code {level1}}" + ) } truth_values <- 1:2 @@ -347,10 +386,11 @@ cal_plot_impl <- function(tbl, x, y, theme(aspect.ratio = 1) if (!quo_is_null(group) & length(tbl_groups)) { - res <- res + facet_grid( - rows = vars(!!group), - cols = vars(!!parse_expr(tbl_groups)) - ) + res <- res + + facet_grid( + rows = vars(!!group), + cols = vars(!!parse_expr(tbl_groups)) + ) } else { if (!quo_is_null(group)) { res <- res + facet_wrap(group) diff --git a/R/cal-plot-windowed.R b/R/cal-plot-windowed.R index 67b7271..8dbadc9 100644 --- a/R/cal-plot-windowed.R +++ b/R/cal-plot-windowed.R @@ -42,34 +42,38 @@ #' @inheritParams cal_plot_breaks #' @seealso [cal_plot_breaks()], [cal_plot_logistic()] #' @export -cal_plot_windowed <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - window_size = 0.1, - step_size = window_size / 2, - conf_level = 0.90, - include_ribbon = TRUE, - include_rug = TRUE, - include_points = TRUE, - event_level = c("auto", "first", "second"), - ...) { +cal_plot_windowed <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + window_size = 0.1, + step_size = window_size / 2, + conf_level = 0.90, + include_ribbon = TRUE, + include_rug = TRUE, + include_points = TRUE, + event_level = c("auto", "first", "second"), + ... +) { UseMethod("cal_plot_windowed") } #' @export #' @rdname cal_plot_windowed -cal_plot_windowed.data.frame <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - window_size = 0.1, - step_size = window_size / 2, - conf_level = 0.90, - include_ribbon = TRUE, - include_rug = TRUE, - include_points = TRUE, - event_level = c("auto", "first", "second"), - ..., - .by = NULL) { +cal_plot_windowed.data.frame <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + window_size = 0.1, + step_size = window_size / 2, + conf_level = 0.90, + include_ribbon = TRUE, + include_rug = TRUE, + include_points = TRUE, + event_level = c("auto", "first", "second"), + ..., + .by = NULL +) { group <- get_group_argument({{ .by }}, .data) .data <- dplyr::group_by(.data, dplyr::across({{ group }})) @@ -91,17 +95,19 @@ cal_plot_windowed.data.frame <- function(.data, #' @export #' @rdname cal_plot_windowed -cal_plot_windowed.tune_results <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - window_size = 0.1, - step_size = window_size / 2, - conf_level = 0.90, - include_ribbon = TRUE, - include_rug = TRUE, - include_points = TRUE, - event_level = c("auto", "first", "second"), - ...) { +cal_plot_windowed.tune_results <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + window_size = 0.1, + step_size = window_size / 2, + conf_level = 0.90, + include_ribbon = TRUE, + include_rug = TRUE, + include_points = TRUE, + event_level = c("auto", "first", "second"), + ... +) { tune_args <- tune_results_args( .data = .data, truth = {{ truth }}, @@ -128,35 +134,43 @@ cal_plot_windowed.tune_results <- function(.data, #' @export #' @rdname cal_plot_windowed -cal_plot_windowed.grouped_df <- function(.data, - truth = NULL, - estimate = NULL, - window_size = 0.1, - step_size = window_size / 2, - conf_level = 0.90, - include_ribbon = TRUE, - include_rug = TRUE, - include_points = TRUE, - event_level = c("auto", "first", "second"), - ...) { +cal_plot_windowed.grouped_df <- function( + .data, + truth = NULL, + estimate = NULL, + window_size = 0.1, + step_size = window_size / 2, + conf_level = 0.90, + include_ribbon = TRUE, + include_rug = TRUE, + include_points = TRUE, + event_level = c("auto", "first", "second"), + ... +) { abort_if_grouped_df() } #--------------------------- >> Implementation --------------------------------- -cal_plot_windowed_impl <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - group = NULL, - window_size = 0.1, - step_size = window_size / 2, - conf_level = 0.90, - include_ribbon = TRUE, - include_rug = TRUE, - include_points = TRUE, - event_level = c("auto", "first", "second"), - is_tune_results = FALSE, - ...) { - rlang::arg_match0(event_level, c("auto", "first", "second"), error_call = NULL) +cal_plot_windowed_impl <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + group = NULL, + window_size = 0.1, + step_size = window_size / 2, + conf_level = 0.90, + include_ribbon = TRUE, + include_rug = TRUE, + include_points = TRUE, + event_level = c("auto", "first", "second"), + is_tune_results = FALSE, + ... +) { + rlang::arg_match0( + event_level, + c("auto", "first", "second"), + error_call = NULL + ) truth <- enquo(truth) estimate <- enquo(estimate) group <- enquo(group) @@ -194,29 +208,33 @@ cal_plot_windowed_impl <- function(.data, #' @rdname cal_binary_tables #' @export #' @keywords internal -.cal_table_windowed <- function(.data, - truth = NULL, - estimate = NULL, - .by = NULL, - window_size = 0.1, - step_size = window_size / 2, - conf_level = 0.90, - event_level = c("auto", "first", "second"), - ...) { +.cal_table_windowed <- function( + .data, + truth = NULL, + estimate = NULL, + .by = NULL, + window_size = 0.1, + step_size = window_size / 2, + conf_level = 0.90, + event_level = c("auto", "first", "second"), + ... +) { UseMethod(".cal_table_windowed") } #' @export #' @keywords internal -.cal_table_windowed.data.frame <- function(.data, - truth = NULL, - estimate = NULL, - .by = NULL, - window_size = 0.1, - step_size = window_size / 2, - conf_level = 0.90, - event_level = c("auto", "first", "second"), - ...) { +.cal_table_windowed.data.frame <- function( + .data, + truth = NULL, + estimate = NULL, + .by = NULL, + window_size = 0.1, + step_size = window_size / 2, + conf_level = 0.90, + event_level = c("auto", "first", "second"), + ... +) { .cal_table_windowed_impl( .data = .data, truth = {{ truth }}, @@ -231,15 +249,17 @@ cal_plot_windowed_impl <- function(.data, #' @export #' @keywords internal -.cal_table_windowed.tune_results <- function(.data, - truth = NULL, - estimate = NULL, - .by = NULL, - window_size = 0.1, - step_size = window_size / 2, - conf_level = 0.90, - event_level = c("auto", "first", "second"), - ...) { +.cal_table_windowed.tune_results <- function( + .data, + truth = NULL, + estimate = NULL, + .by = NULL, + window_size = 0.1, + step_size = window_size / 2, + conf_level = 0.90, + event_level = c("auto", "first", "second"), + ... +) { tune_args <- tune_results_args( .data = .data, truth = {{ truth }}, @@ -261,15 +281,17 @@ cal_plot_windowed_impl <- function(.data, } #--------------------------- >> Implementation --------------------------------- -.cal_table_windowed_impl <- function(.data, - truth = NULL, - estimate = NULL, - group = NULL, - window_size = 0.1, - step_size = window_size / 2, - conf_level = 0.90, - event_level = c("auto", "first", "second"), - ...) { +.cal_table_windowed_impl <- function( + .data, + truth = NULL, + estimate = NULL, + group = NULL, + window_size = 0.1, + step_size = window_size / 2, + conf_level = 0.90, + event_level = c("auto", "first", "second"), + ... +) { truth <- enquo(truth) estimate <- enquo(estimate) group <- enquo(group) @@ -305,14 +327,16 @@ cal_plot_windowed_impl <- function(.data, res } -.cal_table_windowed_grp <- function(.data, - truth, - window_size = 0.1, - step_size = window_size / 2, - conf_level = 0.90, - event_level = c("auto", "first", "second"), - levels = levels, - ...) { +.cal_table_windowed_grp <- function( + .data, + truth, + window_size = 0.1, + step_size = window_size / 2, + conf_level = 0.90, + event_level = c("auto", "first", "second"), + levels = levels, + ... +) { steps <- seq(0, 1, by = step_size) cuts <- list() cuts$lower_cut <- steps - (window_size / 2) diff --git a/R/cal-utils.R b/R/cal-utils.R index 9a7b48e..3649f9f 100644 --- a/R/cal-utils.R +++ b/R/cal-utils.R @@ -40,7 +40,10 @@ truth_estimate_map <- function(.data, truth, estimate, validate = FALSE) { if (length(estimate_str) == 1) { est_map <- list(sym(estimate_str), NULL) } else { - est_map <- purrr::map(seq_along(truth_levels), ~ sym(estimate_str[[.x]])) + est_map <- purrr::map( + seq_along(truth_levels), + ~ sym(estimate_str[[.x]]) + ) } } if (validate) { diff --git a/R/cal-validate.R b/R/cal-validate.R index 77fc804..54c3b7e 100644 --- a/R/cal-validate.R +++ b/R/cal-validate.R @@ -50,29 +50,35 @@ #' @param .data An `rset` object or the results of [tune::fit_resamples()] with #' a `.predictions` column. #' @export -cal_validate_logistic <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred_"), - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_logistic <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred_"), + metrics = NULL, + save_pred = FALSE, + ... +) { UseMethod("cal_validate_logistic") } #' @export #' @rdname cal_validate_logistic cal_validate_logistic.resample_results <- - function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred_"), - metrics = NULL, - save_pred = FALSE, - ...) { + function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred_"), + metrics = NULL, + save_pred = FALSE, + ... + ) { cl <- match.call() validation_check(.data, cl) if (!is.null(truth)) { - cli::cli_warn("{.arg truth} is automatically set when this type of object is used.") + cli::cli_warn( + "{.arg truth} is automatically set when this type of object is used." + ) } truth <- tune::.get_tune_outcome_names(.data) # Change splits$data to be the predictions instead of the original @@ -91,12 +97,14 @@ cal_validate_logistic.resample_results <- #' @export #' @rdname cal_validate_logistic -cal_validate_logistic.rset <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred_"), - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_logistic.rset <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred_"), + metrics = NULL, + save_pred = FALSE, + ... +) { cal_validate( rset = .data, truth = {{ truth }}, @@ -110,12 +118,14 @@ cal_validate_logistic.rset <- function(.data, #' @export #' @rdname cal_validate_logistic -cal_validate_logistic.tune_results <- function(.data, - truth = NULL, - estimate = NULL, - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_logistic.tune_results <- function( + .data, + truth = NULL, + estimate = NULL, + metrics = NULL, + save_pred = FALSE, + ... +) { abort_if_tune_result() } @@ -136,29 +146,35 @@ cal_validate_logistic.tune_results <- function(.data, #' cal_validate_isotonic(Class) #' #' @export -cal_validate_isotonic <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_isotonic <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + metrics = NULL, + save_pred = FALSE, + ... +) { UseMethod("cal_validate_isotonic") } #' @export #' @rdname cal_validate_isotonic cal_validate_isotonic.resample_results <- - function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - metrics = NULL, - save_pred = FALSE, - ...) { + function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + metrics = NULL, + save_pred = FALSE, + ... + ) { cl <- match.call() validation_check(.data, cl) if (!is.null(truth)) { - cli::cli_warn("{.arg truth} is automatically set when this type of object is used.") + cli::cli_warn( + "{.arg truth} is automatically set when this type of object is used." + ) } truth <- tune::.get_tune_outcome_names(.data) # Change splits$data to be the predictions instead of the original @@ -177,12 +193,14 @@ cal_validate_isotonic.resample_results <- #' @export #' @rdname cal_validate_isotonic -cal_validate_isotonic.rset <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_isotonic.rset <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + metrics = NULL, + save_pred = FALSE, + ... +) { cal_validate( rset = .data, truth = {{ truth }}, @@ -196,12 +214,14 @@ cal_validate_isotonic.rset <- function(.data, #' @export #' @rdname cal_validate_isotonic -cal_validate_isotonic.tune_results <- function(.data, - truth = NULL, - estimate = NULL, - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_isotonic.tune_results <- function( + .data, + truth = NULL, + estimate = NULL, + metrics = NULL, + save_pred = FALSE, + ... +) { abort_if_tune_result() } @@ -224,29 +244,35 @@ cal_validate_isotonic.tune_results <- function(.data, #' cal_validate_isotonic_boot(Class) #' #' @export -cal_validate_isotonic_boot <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_isotonic_boot <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + metrics = NULL, + save_pred = FALSE, + ... +) { UseMethod("cal_validate_isotonic_boot") } #' @export #' @rdname cal_validate_isotonic_boot cal_validate_isotonic_boot.resample_results <- - function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - metrics = NULL, - save_pred = FALSE, - ...) { + function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + metrics = NULL, + save_pred = FALSE, + ... + ) { cl <- match.call() validation_check(.data, cl) if (!is.null(truth)) { - cli::cli_warn("{.arg truth} is automatically set when this type of object is used.") + cli::cli_warn( + "{.arg truth} is automatically set when this type of object is used." + ) } truth <- tune::.get_tune_outcome_names(.data) # Change splits$data to be the predictions instead of the original @@ -265,12 +291,14 @@ cal_validate_isotonic_boot.resample_results <- #' @export #' @rdname cal_validate_isotonic_boot -cal_validate_isotonic_boot.rset <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_isotonic_boot.rset <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + metrics = NULL, + save_pred = FALSE, + ... +) { cal_validate( rset = .data, truth = {{ truth }}, @@ -284,12 +312,14 @@ cal_validate_isotonic_boot.rset <- function(.data, #' @export #' @rdname cal_validate_isotonic_boot -cal_validate_isotonic_boot.tune_results <- function(.data, - truth = NULL, - estimate = NULL, - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_isotonic_boot.tune_results <- function( + .data, + truth = NULL, + estimate = NULL, + metrics = NULL, + save_pred = FALSE, + ... +) { abort_if_tune_result() } @@ -312,24 +342,28 @@ cal_validate_isotonic_boot.tune_results <- function(.data, #' cal_validate_beta(Class) #' } #' @export -cal_validate_beta <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred_"), - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_beta <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred_"), + metrics = NULL, + save_pred = FALSE, + ... +) { UseMethod("cal_validate_beta") } #' @export #' @rdname cal_validate_beta cal_validate_beta.resample_results <- - function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred_"), - metrics = NULL, - save_pred = FALSE, - ...) { + function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred_"), + metrics = NULL, + save_pred = FALSE, + ... + ) { cl <- match.call() validation_check(.data, cl) @@ -337,7 +371,9 @@ cal_validate_beta.resample_results <- validation_check(.data, cl) if (!is.null(truth)) { - cli::cli_warn("{.arg truth} is automatically set when this type of object is used.") + cli::cli_warn( + "{.arg truth} is automatically set when this type of object is used." + ) } truth <- tune::.get_tune_outcome_names(.data) # Change splits$data to be the predictions instead of the original @@ -356,12 +392,14 @@ cal_validate_beta.resample_results <- #' @export #' @rdname cal_validate_beta -cal_validate_beta.rset <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred_"), - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_beta.rset <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred_"), + metrics = NULL, + save_pred = FALSE, + ... +) { cal_validate( rset = .data, truth = {{ truth }}, @@ -375,12 +413,14 @@ cal_validate_beta.rset <- function(.data, #' @export #' @rdname cal_validate_beta -cal_validate_beta.tune_results <- function(.data, - truth = NULL, - estimate = NULL, - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_beta.tune_results <- function( + .data, + truth = NULL, + estimate = NULL, + metrics = NULL, + save_pred = FALSE, + ... +) { abort_if_tune_result() } @@ -398,29 +438,35 @@ cal_validate_beta.tune_results <- function(.data, #' cal_validate_multinomial(Species) #' #' @export -cal_validate_multinomial <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred_"), - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_multinomial <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred_"), + metrics = NULL, + save_pred = FALSE, + ... +) { UseMethod("cal_validate_multinomial") } #' @export #' @rdname cal_validate_multinomial cal_validate_multinomial.resample_results <- - function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred_"), - metrics = NULL, - save_pred = FALSE, - ...) { + function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred_"), + metrics = NULL, + save_pred = FALSE, + ... + ) { cl <- match.call() validation_check(.data, cl) if (!is.null(truth)) { - cli::cli_warn("{.arg truth} is automatically set when this type of object is used.") + cli::cli_warn( + "{.arg truth} is automatically set when this type of object is used." + ) } truth <- tune::.get_tune_outcome_names(.data) # Change splits$data to be the predictions instead of the original @@ -439,12 +485,14 @@ cal_validate_multinomial.resample_results <- #' @export #' @rdname cal_validate_multinomial -cal_validate_multinomial.rset <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred_"), - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_multinomial.rset <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred_"), + metrics = NULL, + save_pred = FALSE, + ... +) { cal_validate( rset = .data, truth = {{ truth }}, @@ -458,12 +506,14 @@ cal_validate_multinomial.rset <- function(.data, #' @export #' @rdname cal_validate_multinomial -cal_validate_multinomial.tune_results <- function(.data, - truth = NULL, - estimate = NULL, - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_multinomial.tune_results <- function( + .data, + truth = NULL, + estimate = NULL, + metrics = NULL, + save_pred = FALSE, + ... +) { abort_if_tune_result() } @@ -500,7 +550,9 @@ check_validation_metrics <- function(metrics, model_mode) { } else if (model_mode == "classification") { allowed <- c("prob_metric", "class_metric") if (any(!(metric_info$class %in% allowed))) { - cli::cli_abort("Metric type should be {.val prob_metric} or {.val class_metric}.") + cli::cli_abort( + "Metric type should be {.val prob_metric} or {.val class_metric}." + ) } } else { cli::cli_abort("Unknown mode {.val {model_mode}}") @@ -510,13 +562,15 @@ check_validation_metrics <- function(metrics, model_mode) { } -cal_validate <- function(rset, - truth = NULL, - estimate = NULL, - cal_function = NULL, - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate <- function( + rset, + truth = NULL, + estimate = NULL, + cal_function = NULL, + metrics = NULL, + save_pred = FALSE, + ... +) { truth <- enquo(truth) estimate <- enquo(estimate) @@ -524,7 +578,8 @@ cal_validate <- function(rset, cli::cli_abort("No calibration function provided") } - outcomes <- dplyr::select(rset$splits[[1]]$data, {{ truth }}) |> purrr::pluck(1) + outcomes <- dplyr::select(rset$splits[[1]]$data, {{ truth }}) |> + purrr::pluck(1) model_mode <- get_problem_type(outcomes) metrics <- check_validation_metrics(metrics, model_mode) @@ -563,7 +618,8 @@ cal_validate <- function(rset, rset$.metrics <- NULL metric_res <- - purrr::map2_dfr(cals, + purrr::map2_dfr( + cals, predictions_out, compute_cal_metrics, metrics = metrics, @@ -592,14 +648,24 @@ pull_pred <- function(x, analysis = TRUE) { preds <- purrr::map(x$splits, as.data.frame, data = what) if (!has_dot_row) { - rows <- purrr::map(x$splits, ~ dplyr::tibble(.row = as.integer(.x, data = what))) + rows <- purrr::map( + x$splits, + ~ dplyr::tibble(.row = as.integer(.x, data = what)) + ) preds <- purrr::map2(preds, rows, ~ dplyr::bind_cols(.x, .y)) } } preds } -compute_cal_metrics <- function(calib, preds, metrics, truth_col, est_cols, pred = FALSE) { +compute_cal_metrics <- function( + calib, + preds, + metrics, + truth_col, + est_cols, + pred = FALSE +) { if (has_configs(preds)) { configs <- preds$.config } else { @@ -630,7 +696,6 @@ compute_cal_metrics <- function(calib, preds, metrics, truth_col, est_cols, pred } - #' @importFrom pillar type_sum #' @export type_sum.cal_object <- function(x, ...) { @@ -661,29 +726,35 @@ type_sum.cal_object <- function(x, ...) { #' vfold_cv() |> #' cal_validate_linear(truth = outcome, smooth = FALSE, metrics = reg_stats) #' @export -cal_validate_linear <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_linear <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + metrics = NULL, + save_pred = FALSE, + ... +) { UseMethod("cal_validate_linear") } #' @export #' @rdname cal_validate_linear cal_validate_linear.resample_results <- - function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - metrics = NULL, - save_pred = FALSE, - ...) { + function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + metrics = NULL, + save_pred = FALSE, + ... + ) { cl <- match.call() validation_check(.data, cl) if (!is.null(truth)) { - cli::cli_warn("{.arg truth} is automatically set when this type of object is used.") + cli::cli_warn( + "{.arg truth} is automatically set when this type of object is used." + ) } truth <- tune::.get_tune_outcome_names(.data) # Change splits$data to be the predictions instead of the original @@ -702,12 +773,14 @@ cal_validate_linear.resample_results <- #' @export #' @rdname cal_validate_linear -cal_validate_linear.rset <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred"), - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_linear.rset <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred"), + metrics = NULL, + save_pred = FALSE, + ... +) { cal_validate( rset = .data, truth = {{ truth }}, @@ -757,29 +830,35 @@ cal_validate_linear.rset <- function(.data, #' collect_metrics() #' #' @export -cal_validate_none <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred_"), - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_none <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred_"), + metrics = NULL, + save_pred = FALSE, + ... +) { UseMethod("cal_validate_none") } #' @export #' @rdname cal_validate_none cal_validate_none.resample_results <- - function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred_"), - metrics = NULL, - save_pred = FALSE, - ...) { + function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred_"), + metrics = NULL, + save_pred = FALSE, + ... + ) { cl <- match.call() validation_check(.data, cl) if (!is.null(truth)) { - cli::cli_warn("{.arg truth} is automatically set when this type of object is used.") + cli::cli_warn( + "{.arg truth} is automatically set when this type of object is used." + ) } truth <- tune::.get_tune_outcome_names(.data) # Change splits$data to be the predictions instead of the original @@ -798,12 +877,14 @@ cal_validate_none.resample_results <- #' @export #' @rdname cal_validate_none -cal_validate_none.rset <- function(.data, - truth = NULL, - estimate = dplyr::starts_with(".pred_"), - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_none.rset <- function( + .data, + truth = NULL, + estimate = dplyr::starts_with(".pred_"), + metrics = NULL, + save_pred = FALSE, + ... +) { cal_validate( rset = .data, truth = {{ truth }}, @@ -817,12 +898,14 @@ cal_validate_none.rset <- function(.data, #' @export #' @rdname cal_validate_none -cal_validate_none.tune_results <- function(.data, - truth = NULL, - estimate = NULL, - metrics = NULL, - save_pred = FALSE, - ...) { +cal_validate_none.tune_results <- function( + .data, + truth = NULL, + estimate = NULL, + metrics = NULL, + save_pred = FALSE, + ... +) { abort_if_tune_result() } @@ -941,4 +1024,3 @@ validation_check <- function(x, cl = NULL, call = rlang::caller_env()) { } invisible(NULL) } - diff --git a/R/class-pred.R b/R/class-pred.R index ed6b85c..011ae16 100644 --- a/R/class-pred.R +++ b/R/class-pred.R @@ -1,7 +1,14 @@ # ------------------------------------------------------------------------------ # Creation -new_class_pred <- function(x, labels, ordered = FALSE, equivocal = "[EQ]", ..., subclass = NULL) { +new_class_pred <- function( + x, + labels, + ordered = FALSE, + equivocal = "[EQ]", + ..., + subclass = NULL +) { stopifnot(is.integer(x)) stopifnot(is.character(labels)) stopifnot(is.logical(ordered) && length(ordered) == 1L) @@ -77,8 +84,10 @@ class_pred <- function(x = factor(), which = integer(), equivocal = "[EQ]") { # Check for `equivocal` in labels. Not allowed. if (equivocal %in% labs) { - cli::cli_abort("{.arg equivocal} is reserved for equivocal values and must - not already be a level.") + cli::cli_abort( + "{.arg equivocal} is reserved for equivocal values and must + not already be a level." + ) } # rip out the underlying integer structure diff --git a/R/conformal_infer_cv.R b/R/conformal_infer_cv.R index 79c2ebb..4538962 100644 --- a/R/conformal_infer_cv.R +++ b/R/conformal_infer_cv.R @@ -76,7 +76,9 @@ int_conformal_cv <- function(object, ...) { #' @export #' @rdname int_conformal_cv int_conformal_cv.default <- function(object, ...) { - cli::cli_abort("No known {.fn int_conformal_cv} methods for this type of object.") + cli::cli_abort( + "No known {.fn int_conformal_cv} methods for this type of object." + ) } #' @export @@ -106,7 +108,11 @@ int_conformal_cv.tune_results <- function(object, parameters, ...) { y_name <- tune::.get_tune_outcome_names(object) resids <- - tune::collect_predictions(object, parameters = parameters, summarize = TRUE) |> + tune::collect_predictions( + object, + parameters = parameters, + summarize = TRUE + ) |> dplyr::mutate(.abs_resid = abs(.pred - !!rlang::sym(y_name))) new_infer_cv(model_list, resids$.abs_resid) @@ -144,11 +150,13 @@ print.int_conformal_cv <- function(x, ...) { cat("number of models:", format(length(x$models), big.mark = ","), "\n") cat("training set size:", format(length(x$abs_resid), big.mark = ","), "\n\n") - cat("Use `predict(object, new_data, level)` to compute prediction intervals\n") + cat( + "Use `predict(object, new_data, level)` to compute prediction intervals\n" + ) invisible(x) } -#' S3 methods to track which additional packages are needed for prediction +#' S3 methods to track which additional packages are needed for prediction #' intervals via conformal inference #' @param x a conformal interval object #' @inheritParams generics::required_pkgs @@ -156,11 +164,11 @@ print.int_conformal_cv <- function(x, ...) { required_pkgs.int_conformal_cv <- function(x, infra = TRUE, ...) { model_pkgs <- map(x$models, required_pkgs, infra = infra) model_pkgs <- unlist(model_pkgs) - + if (infra) { model_pkgs <- c(model_pkgs, "probably") } - + model_pkgs <- unique(model_pkgs) model_pkgs } @@ -182,7 +190,9 @@ new_infer_cv <- function(models, resid) { } is_wflow <- purrr::map_lgl(models, workflows::is_trained_workflow) if (all(!is_wflow)) { - cli::cli_abort("The {.arg .extracts} argument does not contain fitted workflows.") + cli::cli_abort( + "The {.arg .extracts} argument does not contain fitted workflows." + ) } if (any(!is_wflow)) { models <- models[is_wflow] diff --git a/R/conformal_infer_full.R b/R/conformal_infer_full.R index 6511193..954e616 100644 --- a/R/conformal_infer_full.R +++ b/R/conformal_infer_full.R @@ -69,7 +69,9 @@ int_conformal_full <- function(object, ...) { #' @export #' @rdname int_conformal_full int_conformal_full.default <- function(object, ...) { - cli::cli_abort("No known {.fn int_conformal_full} methods for this type of object.") + cli::cli_abort( + "No known {.fn int_conformal_full} methods for this type of object." + ) } #' @export @@ -84,7 +86,14 @@ int_conformal_full.workflow <- # check req packages pkgs <- required_pkgs(object) - pkgs <- unique(c(pkgs, "workflows", "parsnip", "probably", "mgcv", control$required_pkgs)) + pkgs <- unique(c( + pkgs, + "workflows", + "parsnip", + "probably", + "mgcv", + control$required_pkgs + )) rlang::check_installed(pkgs) control$required_pkgs <- pkgs @@ -109,7 +118,9 @@ print.int_conformal_full <- function(x, ...) { cat("model:", .get_fit_type(x$wflow), "\n") cat("training set size:", format(nrow(x$training), big.mark = ","), "\n\n") - cat("Use `predict(object, new_data, level)` to compute prediction intervals\n") + cat( + "Use `predict(object, new_data, level)` to compute prediction intervals\n" + ) invisible(x) } @@ -122,7 +133,7 @@ required_pkgs.int_conformal_full <- function(x, infra = TRUE, ...) { if (infra) { model_pkgs <- c(model_pkgs, "probably") } - + model_pkgs <- unique(model_pkgs) model_pkgs } @@ -167,9 +178,21 @@ predict.int_conformal_full <- function(object, new_data, level = 0.95, ...) { # compute intervals if (object$control$method == "grid") { - res <- grid_all(new_nest$data, object$wflow, object$training, level, object$control) + res <- grid_all( + new_nest$data, + object$wflow, + object$training, + level, + object$control + ) } else { - res <- optimize_all(new_nest$data, object$wflow, object$training, level, object$control) + res <- optimize_all( + new_nest$data, + object$wflow, + object$training, + level, + object$control + ) } if (object$control$progress) { cat("\n") @@ -209,7 +232,10 @@ get_mode <- function(x) { check_workflow <- function(x, call = rlang::caller_env()) { if (!workflows::is_trained_workflow(x)) { - cli::cli_abort("{.arg object} should be a fitted workflow object.", call = call) + cli::cli_abort( + "{.arg object} should be a fitted workflow object.", + call = call + ) } if (get_mode(x) != "regression") { cli::cli_abort("{.arg object} should be a regression model.", call = call) @@ -239,7 +265,8 @@ var_model <- function(object, train_data, call = caller_env()) { # deviation at a given prediction. var_mod <- try( - mgcv::gam(sq ~ s(.pred), + mgcv::gam( + sq ~ s(.pred), data = train_res, family = stats::Gamma(link = "log") ), @@ -337,16 +364,16 @@ grid_one <- function(new_data, model, train_data, level, ctrl) { new_data[[y_name]] <- NA_real_ trial_data <- dplyr::bind_rows(train_data, new_data) - trial_vals <- seq(pred_val - bound, pred_val + bound, length.out = ctrl$trial_points) + trial_vals <- seq( + pred_val - bound, + pred_val + bound, + length.out = ctrl$trial_points + ) res <- purrr::map_dfr( trial_vals, - ~ trial_fit(.x, - trial_data = trial_data, - wflow = model, - level = level - ) + ~ trial_fit(.x, trial_data = trial_data, wflow = model, level = level) ) compute_bound(res, pred_val) @@ -412,20 +439,30 @@ optimize_one <- function(new_data, model, train_data, level, ctrl) { upper <- try( - stats::uniroot(get_diff, c(pred_val, pred_val + bound), - maxiter = ctrl$max_iter, tol = ctrl$tolerance, + stats::uniroot( + get_diff, + c(pred_val, pred_val + bound), + maxiter = ctrl$max_iter, + tol = ctrl$tolerance, extendInt = "upX", - trial_data, model, level + trial_data, + model, + level ), silent = TRUE ) lower <- try( - stats::uniroot(get_diff, c(pred_val - bound, pred_val), - maxiter = ctrl$max_iter, tol = ctrl$tolerance, + stats::uniroot( + get_diff, + c(pred_val - bound, pred_val), + maxiter = ctrl$max_iter, + tol = ctrl$tolerance, extendInt = "downX", - trial_data, model, level + trial_data, + model, + level ), silent = TRUE ) @@ -473,9 +510,16 @@ get_root <- function(x, ctrl) { #' @return A list object with the options given by the user. #' @export control_conformal_full <- - function(method = "iterative", trial_points = 100, var_multiplier = 10, - max_iter = 100, tolerance = .Machine$double.eps^0.25, progress = FALSE, - required_pkgs = character(0), seed = sample.int(10^5, 1)) { + function( + method = "iterative", + trial_points = 100, + var_multiplier = 10, + max_iter = 100, + tolerance = .Machine$double.eps^0.25, + progress = FALSE, + required_pkgs = character(0), + seed = sample.int(10^5, 1) + ) { method <- rlang::arg_match0(method, c("iterative", "grid")) list( diff --git a/R/conformal_infer_quantile.R b/R/conformal_infer_quantile.R index 3f7803c..c3a033f 100644 --- a/R/conformal_infer_quantile.R +++ b/R/conformal_infer_quantile.R @@ -124,7 +124,7 @@ required_pkgs.int_conformal_quantile <- function(x, infra = TRUE, ...) { if (infra) { model_pkgs <- c(model_pkgs, "probably") } - + model_pkgs <- unique(model_pkgs) model_pkgs } diff --git a/R/conformal_infer_split.R b/R/conformal_infer_split.R index a8f38c3..696f0e0 100644 --- a/R/conformal_infer_split.R +++ b/R/conformal_infer_split.R @@ -65,7 +65,9 @@ int_conformal_split <- function(object, ...) { #' @export #' @rdname int_conformal_split int_conformal_split.default <- function(object, ...) { - cli::cli_abort("No known {.fn int_conformal_split} methods for this type of object.") + cli::cli_abort( + "No known {.fn int_conformal_split} methods for this type of object." + ) } #' @export @@ -77,7 +79,11 @@ int_conformal_split.workflow <- function(object, cal_data, ...) { y_name <- names(hardhat::extract_mold(object)$outcomes) cal_pred <- generics::augment(object, cal_data) cal_pred$.resid <- cal_pred[[y_name]] - cal_pred$.pred - res <- list(resid = sort(abs(cal_pred$.resid)), wflow = object, n = nrow(cal_pred)) + res <- list( + resid = sort(abs(cal_pred$.resid)), + wflow = object, + n = nrow(cal_pred) + ) class(res) <- c("conformal_reg_split", "int_conformal_split") res } @@ -90,7 +96,9 @@ print.int_conformal_split <- function(x, ...) { cat("model:", .get_fit_type(x$wflow), "\n") cat("calibration set size:", format(x$n, big.mark = ","), "\n\n") - cat("Use `predict(object, new_data, level)` to compute prediction intervals\n") + cat( + "Use `predict(object, new_data, level)` to compute prediction intervals\n" + ) invisible(x) } @@ -103,7 +111,7 @@ required_pkgs.int_conformal_split <- function(x, infra = TRUE, ...) { if (infra) { model_pkgs <- c(model_pkgs, "probably") } - + model_pkgs <- unique(model_pkgs) model_pkgs } diff --git a/R/make_class_pred.R b/R/make_class_pred.R index 2116266..40967f5 100644 --- a/R/make_class_pred.R +++ b/R/make_class_pred.R @@ -69,10 +69,12 @@ #' ) #' #' @export -make_class_pred <- function(..., - levels, - ordered = FALSE, - min_prob = 1 / length(levels)) { +make_class_pred <- function( + ..., + levels, + ordered = FALSE, + min_prob = 1 / length(levels) +) { dots <- rlang::quos(...) probs <- lapply(dots, rlang::eval_tidy) @@ -129,11 +131,13 @@ make_class_pred <- function(..., #' @rdname make_class_pred #' @export -make_two_class_pred <- function(estimate, - levels, - threshold = 0.5, - ordered = FALSE, - buffer = NULL) { +make_two_class_pred <- function( + estimate, + levels, + threshold = 0.5, + ordered = FALSE, + buffer = NULL +) { if (length(levels) != 2 || !is.character(levels)) { cli::cli_abort("{.arg levels} must be a character vector of length 2.") } @@ -221,12 +225,14 @@ make_two_class_pred <- function(estimate, #' ) #' #' @export -append_class_pred <- function(.data, - ..., - levels, - ordered = FALSE, - min_prob = 1 / length(levels), - name = ".class_pred") { +append_class_pred <- function( + .data, + ..., + levels, + ordered = FALSE, + min_prob = 1 / length(levels), + name = ".class_pred" +) { if (!is.data.frame(.data) && ncol(.data) < 2) { cli::cli_abort( "{.arg .data} should be a data frame or tibble with at least 2 columns." diff --git a/R/printing.R b/R/printing.R index 5e86d11..742b56b 100644 --- a/R/printing.R +++ b/R/printing.R @@ -37,7 +37,6 @@ cat_levels <- function(x, width = getOption("width")) { drop <- n_lev > maxl cat( - # Print number of levels if we had to drop some if (drop) { paste(format(n_lev), "") @@ -46,14 +45,11 @@ cat_levels <- function(x, width = getOption("width")) { # Print `Levels: ` header, paste( - # `first levels ... last levels` if (drop) { c(lev[1L:max(1, maxl - 1)], "...", if (maxl > 1) lev[n_lev]) - } - - # print all levels - else { + } else { + # print all levels lev }, collapse = colsep diff --git a/R/probably-package.R b/R/probably-package.R index 7387a08..1f5fcb8 100644 --- a/R/probably-package.R +++ b/R/probably-package.R @@ -13,8 +13,23 @@ NULL utils::globalVariables(c( - ".bin", ".is_val", "event_rate", "events", "lower", - "predicted_midpoint", "total", "upper", ".config", - ".adj_estimate", ".rounded", ".pred", ".bound", "pred_val", ".extracts", - ".x", ".type", ".metrics", "cal_data" + ".bin", + ".is_val", + "event_rate", + "events", + "lower", + "predicted_midpoint", + "total", + "upper", + ".config", + ".adj_estimate", + ".rounded", + ".pred", + ".bound", + "pred_val", + ".extracts", + ".x", + ".type", + ".metrics", + "cal_data" )) diff --git a/R/threshold_perf.R b/R/threshold_perf.R index 41b4376..0cfc178 100644 --- a/R/threshold_perf.R +++ b/R/threshold_perf.R @@ -93,14 +93,16 @@ threshold_perf <- function(.data, ...) { #' @rdname threshold_perf #' @export -threshold_perf.data.frame <- function(.data, - truth, - estimate, - thresholds = NULL, - metrics = NULL, - na_rm = TRUE, - event_level = "first", - ...) { +threshold_perf.data.frame <- function( + .data, + truth, + estimate, + thresholds = NULL, + metrics = NULL, + na_rm = TRUE, + event_level = "first", + ... +) { if (is.null(thresholds)) { thresholds <- seq(0.5, 1, length = 21) } @@ -187,7 +189,6 @@ threshold_perf.data.frame <- function(.data, .data <- .data |> dplyr::group_by(.threshold) } - .data_metrics <- metrics( .data, truth = truth, @@ -239,7 +240,7 @@ check_thresholded_metrics <- function(x) { # check to see if sensitivity and specificity are in the lists has_sens <- any(y$metric %in% c("sens", "sensitivity")) & - any(y$metric %in% c("spec", "specificity")) + any(y$metric %in% c("spec", "specificity")) has_sens } diff --git a/R/utils.R b/R/utils.R index 896aa89..cfb144c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -15,7 +15,9 @@ quote_collapse <- function(x, quote = "`", collapse = ", ") { abort_default <- function(x, fn) { cls <- quote_collapse(class(x)) - cli::cli_abort("No implementation of {.fn {fn}} for {.obj_type_friendly {cls}}.") + cli::cli_abort( + "No implementation of {.fn {fn}} for {.obj_type_friendly {cls}}." + ) } # Check if a class_pred object came from an ordered factor diff --git a/R/vctrs-compat.R b/R/vctrs-compat.R index 1a86333..94730a3 100644 --- a/R/vctrs-compat.R +++ b/R/vctrs-compat.R @@ -81,7 +81,13 @@ vec_cast.ordered.class_pred <- function(x, to, ...) { } #' @export -vec_cast.class_pred.character <- function(x, to, ..., x_arg = "x", to_arg = "to") { +vec_cast.class_pred.character <- function( + x, + to, + ..., + x_arg = "x", + to_arg = "to" +) { # first cast character -> factor # then add class pred attributes diff --git a/air.toml b/air.toml new file mode 100644 index 0000000..e69de29 diff --git a/tests/testthat/helper-cal.R b/tests/testthat/helper-cal.R index 57f7564..1700691 100644 --- a/tests/testthat/helper-cal.R +++ b/tests/testthat/helper-cal.R @@ -238,15 +238,22 @@ testthat_cal_fit_rs <- function() { modeldata::sim_classification(100)[, 1:3] |> dplyr::rename(outcome = class) |> vfold_cv() |> - fit_resamples(logistic_reg(), outcome ~ ., resamples = ., control = ctrl) + fit_resamples( + logistic_reg(), + outcome ~ ., + resamples = ., + control = ctrl + ) set.seed(112) rs_mlt <- sim_multinom_df(500) |> dplyr::rename(outcome = class) |> vfold_cv() |> - fit_resamples(mlp() |> set_mode("classification"), + fit_resamples( + mlp() |> set_mode("classification"), outcome ~ ., - resamples = ., control = ctrl + resamples = ., + control = ctrl ) set.seed(113) rs_reg <- @@ -333,7 +340,9 @@ are_groups_configs <- function(x) { bin_with_configs <- function() { set.seed(1) segment_logistic |> - dplyr::mutate(.config = sample(letters[1:2], nrow(segment_logistic), replace = TRUE)) + dplyr::mutate( + .config = sample(letters[1:2], nrow(segment_logistic), replace = TRUE) + ) } mnl_with_configs <- function() { @@ -341,7 +350,9 @@ mnl_with_configs <- function() { set.seed(1) modeldata::hpc_cv |> - dplyr::mutate(.config = sample(letters[1:2], nrow(modeldata::hpc_cv), replace = TRUE)) + dplyr::mutate( + .config = sample(letters[1:2], nrow(modeldata::hpc_cv), replace = TRUE) + ) } reg_with_configs <- function() { @@ -350,7 +361,13 @@ reg_with_configs <- function() { set.seed(1) modeldata::solubility_test |> - dplyr::mutate(.config = sample(letters[1:2], nrow(modeldata::solubility_test), replace = TRUE)) + dplyr::mutate( + .config = sample( + letters[1:2], + nrow(modeldata::solubility_test), + replace = TRUE + ) + ) } holdout_length <- function(x) { diff --git a/tests/testthat/test-bound_prediction.R b/tests/testthat/test-bound_prediction.R index c5dfacf..1ee197b 100644 --- a/tests/testthat/test-bound_prediction.R +++ b/tests/testthat/test-bound_prediction.R @@ -7,15 +7,19 @@ test_that("lower_limit bounds for numeric predictions", { # ------------------------------------------------------------------------------ - expect_snapshot(bound_prediction(modeldata::solubility_test, lower_limit = 2), error = TRUE) + expect_snapshot( + bound_prediction(modeldata::solubility_test, lower_limit = 2), + error = TRUE + ) expect_snapshot( modeldata::solubility_test |> mutate(.pred = format(prediction)) |> bound_prediction(lower_limit = 2), - error = TRUE) + error = TRUE + ) sol <- modeldata::solubility_test |> set_names(c("solubility", ".pred")) - + expect_equal(bound_prediction(sol), sol) expect_equal(bound_prediction(sol, lower_limit = NA), sol) @@ -24,7 +28,10 @@ test_that("lower_limit bounds for numeric predictions", { expect_equal(res_1$.pred[sol$.pred >= -1], sol$.pred[sol$.pred >= -1]) expect_snapshot(bound_prediction(sol, lower_limit = tune2()), error = TRUE) - expect_snapshot(bound_prediction(as.matrix(sol), lower_limit = 1), error = TRUE) + expect_snapshot( + bound_prediction(as.matrix(sol), lower_limit = 1), + error = TRUE + ) }) test_that("upper_limit bounds for numeric predictions", { @@ -36,12 +43,16 @@ test_that("upper_limit bounds for numeric predictions", { # ------------------------------------------------------------------------------ - expect_snapshot(bound_prediction(modeldata::solubility_test, lower_limit = 2), error = TRUE) + expect_snapshot( + bound_prediction(modeldata::solubility_test, lower_limit = 2), + error = TRUE + ) expect_snapshot( modeldata::solubility_test |> mutate(.pred = format(prediction)) |> bound_prediction(lower_limit = 2), - error = TRUE) + error = TRUE + ) sol <- modeldata::solubility_test |> set_names(c("solubility", ".pred")) diff --git a/tests/testthat/test-cal-apply-binary.R b/tests/testthat/test-cal-apply-binary.R index 5dcfdd6..cf53374 100644 --- a/tests/testthat/test-cal-apply-binary.R +++ b/tests/testthat/test-cal-apply-binary.R @@ -38,4 +38,4 @@ test_that("Logistic spline apply works - tune_results", { testthat_cal_binary_count(), nrow(tap_gam) ) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-cal-apply-multi.R b/tests/testthat/test-cal-apply-multi.R index 4a4d208..ff2d193 100644 --- a/tests/testthat/test-cal-apply-multi.R +++ b/tests/testthat/test-cal-apply-multi.R @@ -1,5 +1,9 @@ test_that("Multinomial apply works - data.frame", { - sl_multinomial <- cal_estimate_multinomial(species_probs, Species, smooth = FALSE) + sl_multinomial <- cal_estimate_multinomial( + species_probs, + Species, + smooth = FALSE + ) ap_multinomial <- cal_apply(species_probs, sl_multinomial) pred_bobcat <- ap_multinomial$.pred_bobcat @@ -13,7 +17,7 @@ test_that("Logistic apply works - tune_results", { tct <- testthat_cal_multiclass() tl_multinomial <- cal_estimate_multinomial(tct, smooth = FALSE) tap_multinomial <- cal_apply(tct, tl_multinomial) - + expect_equal( testthat_cal_multiclass_count(), nrow(tap_multinomial) diff --git a/tests/testthat/test-cal-apply-regression.R b/tests/testthat/test-cal-apply-regression.R index 132099f..5517b09 100644 --- a/tests/testthat/test-cal-apply-regression.R +++ b/tests/testthat/test-cal-apply-regression.R @@ -1,5 +1,9 @@ test_that("Linear apply works - data.frame", { - sl_linear <- cal_estimate_linear(boosting_predictions_oob, outcome, smooth = FALSE) + sl_linear <- cal_estimate_linear( + boosting_predictions_oob, + outcome, + smooth = FALSE + ) ap_linear <- cal_apply(boosting_predictions_oob, sl_linear) pred <- ap_linear$.pred @@ -34,4 +38,4 @@ test_that("Linear spline apply works - tune_results", { testthat_cal_reg_count(), nrow(tap_gam) ) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test-cal-estimate-beta.R b/tests/testthat/test-cal-estimate-beta.R index 411819f..d1a7ed6 100644 --- a/tests/testthat/test-cal-estimate-beta.R +++ b/tests/testthat/test-cal-estimate-beta.R @@ -24,7 +24,6 @@ test_that("Beta estimates work - data.frame", { dplyr::mutate(group1 = 1, group2 = 2) |> cal_estimate_beta(Class, smooth = FALSE, .by = c(group1, group2)) ) - }) test_that("Beta estimates work - tune_results", { diff --git a/tests/testthat/test-cal-estimate-isotonic.R b/tests/testthat/test-cal-estimate-isotonic.R index f70b489..fa7f7c0 100644 --- a/tests/testthat/test-cal-estimate-isotonic.R +++ b/tests/testthat/test-cal-estimate-isotonic.R @@ -24,7 +24,6 @@ test_that("Isotonic estimates work - data.frame", { dplyr::mutate(group1 = 1, group2 = 2) |> cal_estimate_isotonic(Class, .by = c(group1, group2)) ) - }) test_that("Isotonic estimates work - tune_results", { @@ -66,7 +65,11 @@ test_that("Isotonic linear estimates work - data.frame", { skip_if_not_installed("modeldata") set.seed(2983) - sl_logistic <- cal_estimate_isotonic(boosting_predictions_oob, outcome, estimate = .pred) + sl_logistic <- cal_estimate_isotonic( + boosting_predictions_oob, + outcome, + estimate = .pred + ) expect_cal_type(sl_logistic, "regression") expect_cal_method(sl_logistic, "Isotonic regression calibration") expect_cal_rows(sl_logistic, 2000) @@ -103,7 +106,10 @@ test_that("Isotonic Bootstrapped estimates work - data.frame", { cal_estimate_isotonic_boot(Class, .by = group) expect_cal_type(sl_boot_group, "binary") - expect_cal_method(sl_boot_group, "Bootstrapped isotonic regression calibration") + expect_cal_method( + sl_boot_group, + "Bootstrapped isotonic regression calibration" + ) expect_snapshot(print(sl_boot_group)) expect_snapshot_error( @@ -111,7 +117,6 @@ test_that("Isotonic Bootstrapped estimates work - data.frame", { dplyr::mutate(group1 = 1, group2 = 2) |> cal_estimate_isotonic_boot(Class, .by = c(group1, group2)) ) - }) test_that("Isotonic Bootstrapped estimates work - tune_results", { @@ -134,7 +139,10 @@ test_that("Isotonic Bootstrapped estimates work - tune_results", { set.seed(100) mtnl_isotonic <- cal_estimate_isotonic_boot(testthat_cal_multiclass()) expect_cal_type(mtnl_isotonic, "one_vs_all") - expect_cal_method(mtnl_isotonic, "Bootstrapped isotonic regression calibration") + expect_cal_method( + mtnl_isotonic, + "Bootstrapped isotonic regression calibration" + ) expect_snapshot(print(mtnl_isotonic)) expect_equal( @@ -174,11 +182,14 @@ test_that("non-standard column names", { seg <- segment_logistic |> rename_with(~ paste0(.x, "-1"), matches(".pred")) |> mutate( - Class = paste0(Class,"-1"), + Class = paste0(Class, "-1"), Class = factor(Class), .pred_class = ifelse(`.pred_poor-1` >= 0.5, "poor-1", "good-1") ) calib <- cal_estimate_isotonic(seg, Class) new_pred <- cal_apply(seg, calib, pred_class = .pred_class) - expect_named(new_pred, c(".pred_poor-1", ".pred_good-1", "Class", ".pred_class")) + expect_named( + new_pred, + c(".pred_poor-1", ".pred_good-1", "Class", ".pred_class") + ) }) diff --git a/tests/testthat/test-cal-estimate-linear.R b/tests/testthat/test-cal-estimate-linear.R index dd2a6e2..9b64a3a 100644 --- a/tests/testthat/test-cal-estimate-linear.R +++ b/tests/testthat/test-cal-estimate-linear.R @@ -1,7 +1,11 @@ test_that("Linear estimates work - data.frame", { skip_if_not_installed("modeldata") - sl_linear <- cal_estimate_linear(boosting_predictions_oob, outcome, smooth = FALSE) + sl_linear <- cal_estimate_linear( + boosting_predictions_oob, + outcome, + smooth = FALSE + ) expect_cal_type(sl_linear, "regression") expect_cal_method(sl_linear, "Linear calibration") expect_cal_estimate(sl_linear, "butchered_glm") @@ -27,7 +31,6 @@ test_that("Linear estimates work - data.frame", { dplyr::mutate(group1 = 1, group2 = 2) |> cal_estimate_linear(outcome, smooth = FALSE, .by = c(group1, group2)) ) - }) test_that("Linear estimates work - tune_results", { @@ -36,7 +39,6 @@ test_that("Linear estimates work - tune_results", { expect_cal_method(tl_linear, "Linear calibration") expect_cal_estimate(tl_linear, "butchered_glm") expect_snapshot(print(tl_linear)) - }) test_that("Linear estimates errors - grouped_df", { @@ -55,7 +57,7 @@ test_that("Linear spline estimates work - data.frame", { expect_cal_estimate(sl_gam, "butchered_gam") expect_cal_rows(sl_gam, 2000) expect_snapshot(print(sl_gam)) - expect_equal( + expect_equal( required_pkgs(sl_gam), c("mgcv", "probably") ) @@ -99,9 +101,17 @@ test_that("Linear spline switches to linear if too few unique", { ) expect_snapshot( - sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome, smooth = TRUE) + sl_gam <- cal_estimate_linear( + boosting_predictions_oob, + outcome, + smooth = TRUE + ) + ) + sl_lm <- cal_estimate_linear( + boosting_predictions_oob, + outcome, + smooth = FALSE ) - sl_lm <- cal_estimate_linear(boosting_predictions_oob, outcome, smooth = FALSE) expect_identical( sl_gam$estimates, @@ -109,13 +119,22 @@ test_that("Linear spline switches to linear if too few unique", { ) expect_snapshot( - sl_gam <- cal_estimate_linear(boosting_predictions_oob, outcome, .by = id, smooth = TRUE) + sl_gam <- cal_estimate_linear( + boosting_predictions_oob, + outcome, + .by = id, + smooth = TRUE + ) + ) + sl_lm <- cal_estimate_linear( + boosting_predictions_oob, + outcome, + .by = id, + smooth = FALSE ) - sl_lm <- cal_estimate_linear(boosting_predictions_oob, outcome, .by = id, smooth = FALSE) expect_identical( sl_gam$estimates, sl_lm$estimates ) }) - diff --git a/tests/testthat/test-cal-estimate-logistic.R b/tests/testthat/test-cal-estimate-logistic.R index 39e0bb0..187f5e6 100644 --- a/tests/testthat/test-cal-estimate-logistic.R +++ b/tests/testthat/test-cal-estimate-logistic.R @@ -9,7 +9,8 @@ test_that("Logistic estimates work - data.frame", { expect_snapshot(print(sl_logistic)) expect_snapshot_error( - segment_logistic |> cal_estimate_logistic(truth = Class, estimate = .pred_poor) + segment_logistic |> + cal_estimate_logistic(truth = Class, estimate = .pred_poor) ) data(hpc_cv, package = "yardstick") @@ -31,7 +32,6 @@ test_that("Logistic estimates work - data.frame", { "probably" ) - expect_snapshot_error( segment_logistic |> dplyr::mutate(group1 = 1, group2 = 2) |> @@ -41,12 +41,20 @@ test_that("Logistic estimates work - data.frame", { # ------------------------------------------------------------------------------ data(two_class_example, package = "modeldata") - two_cls_plist <- two_class_example[0,] + two_cls_plist <- two_class_example[0, ] two_cls_mod <- - cal_estimate_logistic(two_class_example, truth = truth, estimate = c(Class1, Class2)) - - two_cls_res <- cal_apply(two_class_example, two_cls_mod, pred_class = predicted) - expect_equal(two_cls_res[0,], two_cls_plist) + cal_estimate_logistic( + two_class_example, + truth = truth, + estimate = c(Class1, Class2) + ) + + two_cls_res <- cal_apply( + two_class_example, + two_cls_mod, + pred_class = predicted + ) + expect_equal(two_cls_res[0, ], two_cls_plist) expect_equal( required_pkgs(two_cls_mod), c("mgcv", "probably") @@ -107,7 +115,7 @@ test_that("Logistic spline estimates work - tune_results", { expect_cal_method(tl_gam, "Generalized additive model calibration") expect_cal_estimate(tl_gam, "butchered_gam") expect_snapshot(print(tl_gam)) - expect_equal( + expect_equal( required_pkgs(tl_gam), c("mgcv", "probably") ) @@ -141,9 +149,19 @@ test_that("Logistic spline switches to linear if too few unique", { length.out = nrow(segment_logistic) ) expect_snapshot( - sl_gam <- cal_estimate_logistic(segment_logistic, Class, .by = id, smooth = TRUE) + sl_gam <- cal_estimate_logistic( + segment_logistic, + Class, + .by = id, + smooth = TRUE + ) + ) + sl_lm <- cal_estimate_logistic( + segment_logistic, + Class, + .by = id, + smooth = FALSE ) - sl_lm <- cal_estimate_logistic(segment_logistic, Class, .by = id, smooth = FALSE) expect_identical( sl_gam$estimates, diff --git a/tests/testthat/test-cal-estimate-multinomial.R b/tests/testthat/test-cal-estimate-multinomial.R index 5228f22..c84b8fd 100644 --- a/tests/testthat/test-cal-estimate-multinomial.R +++ b/tests/testthat/test-cal-estimate-multinomial.R @@ -12,7 +12,11 @@ test_that("Multinomial estimates work - data.frame", { c("nnet", "probably") ) - sp_smth_multi <- cal_estimate_multinomial(species_probs, Species, smooth = TRUE) + sp_smth_multi <- cal_estimate_multinomial( + species_probs, + Species, + smooth = TRUE + ) expect_cal_type(sp_smth_multi, "multiclass") expect_cal_method(sp_smth_multi, "Generalized additive model calibration") expect_cal_rows(sp_smth_multi, n = 110) @@ -46,7 +50,10 @@ test_that("Multinomial estimates work - tune_results", { skip_if_not_installed("modeldata") skip_if_not_installed("nnet") - tl_multi <- cal_estimate_multinomial(testthat_cal_multiclass(), smooth = FALSE) + tl_multi <- cal_estimate_multinomial( + testthat_cal_multiclass(), + smooth = FALSE + ) expect_cal_type(tl_multi, "multiclass") expect_cal_method(tl_multi, "Multinomial regression calibration") expect_snapshot(print(tl_multi)) @@ -60,7 +67,10 @@ test_that("Multinomial estimates work - tune_results", { nrow() ) - tl_smth_multi <- cal_estimate_multinomial(testthat_cal_multiclass(), smooth = TRUE) + tl_smth_multi <- cal_estimate_multinomial( + testthat_cal_multiclass(), + smooth = TRUE + ) expect_cal_type(tl_smth_multi, "multiclass") expect_cal_method(tl_smth_multi, "Generalized additive model calibration") expect_snapshot(print(tl_smth_multi)) @@ -98,9 +108,17 @@ test_that("Multinomial spline switches to linear if too few unique", { dplyr::slice_head(n = 2, by = Species) expect_snapshot( - sl_gam <- cal_estimate_multinomial(smol_species_probs, Species, smooth = TRUE) + sl_gam <- cal_estimate_multinomial( + smol_species_probs, + Species, + smooth = TRUE + ) + ) + sl_glm <- cal_estimate_multinomial( + smol_species_probs, + Species, + smooth = FALSE ) - sl_glm <- cal_estimate_multinomial(smol_species_probs, Species, smooth = FALSE) expect_identical( sl_gam$estimates, @@ -113,9 +131,19 @@ test_that("Multinomial spline switches to linear if too few unique", { dplyr::mutate(id = rep(1:2, 6)) expect_snapshot( - sl_gam <- cal_estimate_multinomial(smol_by_species_probs, Species, .by = id, smooth = TRUE) + sl_gam <- cal_estimate_multinomial( + smol_by_species_probs, + Species, + .by = id, + smooth = TRUE + ) + ) + sl_glm <- cal_estimate_multinomial( + smol_by_species_probs, + Species, + .by = id, + smooth = FALSE ) - sl_glm <- cal_estimate_multinomial(smol_by_species_probs, Species, .by = id, smooth = FALSE) expect_identical( sl_gam$estimates, diff --git a/tests/testthat/test-cal-estimate-none.R b/tests/testthat/test-cal-estimate-none.R index aff989c..2bd010c 100644 --- a/tests/testthat/test-cal-estimate-none.R +++ b/tests/testthat/test-cal-estimate-none.R @@ -75,7 +75,6 @@ test_that("no calibration works - data.frame", { dplyr::mutate(group1 = 1, group2 = 2) |> cal_estimate_none(Species, .by = c(group1, group2)) ) - }) test_that("no calibration works - tune_results", { @@ -119,14 +118,10 @@ test_that("no calibration works - tune_results", { cal_apply(multi_pred, nope_multi), multi_pred ) - }) test_that("no calibration fails - grouped_df", { - expect_snapshot_error( cal_estimate_none(dplyr::group_by(mtcars, vs)) ) - }) - diff --git a/tests/testthat/test-cal-plot-breaks.R b/tests/testthat/test-cal-plot-breaks.R index d2ed6e0..62d7fa4 100644 --- a/tests/testthat/test-cal-plot-breaks.R +++ b/tests/testthat/test-cal-plot-breaks.R @@ -1,5 +1,10 @@ test_that("Binary breaks functions work", { - x10 <- .cal_table_breaks(segment_logistic, Class, .pred_good, event_level = "first") + x10 <- .cal_table_breaks( + segment_logistic, + Class, + .pred_good, + event_level = "first" + ) expect_equal( x10$predicted_midpoint, @@ -38,11 +43,15 @@ test_that("Binary breaks functions work with group argument", { expect_s3_class(res, "ggplot") expect_equal( - res$data[0,], + res$data[0, ], dplyr::tibble( id = factor(0, levels = paste(0:1)), - predicted_midpoint = double(), event_rate = double(), events = double(), - total = integer(), lower = double(), upper = double() + predicted_midpoint = double(), + event_rate = double(), + events = double(), + total = integer(), + lower = double(), + upper = double() ) ) @@ -166,14 +175,24 @@ test_that("custom names for cal_plot_breaks()", { }) test_that("Event level handling works", { - x7 <- .cal_table_breaks(segment_logistic, Class, .pred_good, event_level = "second") + x7 <- .cal_table_breaks( + segment_logistic, + Class, + .pred_good, + event_level = "second" + ) expect_equal( which(x7$predicted_midpoint == min(x7$predicted_midpoint)), which(x7$event_rate == max(x7$event_rate)) ) expect_snapshot_error( - .cal_table_breaks(segment_logistic, Class, .pred_good, event_level = "invalid") + .cal_table_breaks( + segment_logistic, + Class, + .pred_good, + event_level = "invalid" + ) ) }) diff --git a/tests/testthat/test-cal-plot-logistic.R b/tests/testthat/test-cal-plot-logistic.R index 394cb53..75f8968 100644 --- a/tests/testthat/test-cal-plot-logistic.R +++ b/tests/testthat/test-cal-plot-logistic.R @@ -3,12 +3,14 @@ test_that("Binary logistic functions work", { x20 <- .cal_table_logistic(segment_logistic, Class, .pred_good) - model20 <- mgcv::gam(Class ~ s(.pred_good, k = 10), + model20 <- mgcv::gam( + Class ~ s(.pred_good, k = 10), data = segment_logistic, family = binomial() ) - preds20 <- predict(model20, + preds20 <- predict( + model20, data.frame(.pred_good = seq(0, 1, by = .01)), type = "response" ) @@ -23,22 +25,24 @@ test_that("Binary logistic functions work", { x22 <- .cal_table_logistic(testthat_cal_binary()) - x22_1 <- testthat_cal_binary() |> tune::collect_predictions(summarize = TRUE) |> dplyr::group_by(.config) |> - dplyr::group_map(~ { - model <- mgcv::gam( - class ~ s(.pred_class_1, k = 10), - data = .x, - family = binomial() - ) - preds <- predict(model, - data.frame(.pred_class_1 = seq(0, 1, by = .01)), - type = "response" - ) - 1 - preds - }) |> + dplyr::group_map( + ~ { + model <- mgcv::gam( + class ~ s(.pred_class_1, k = 10), + data = .x, + family = binomial() + ) + preds <- predict( + model, + data.frame(.pred_class_1 = seq(0, 1, by = .01)), + type = "response" + ) + 1 - preds + } + ) |> purrr::reduce(c) expect_equal(sd(x22$prob), sd(x22_1), tolerance = 0.000001) @@ -49,11 +53,21 @@ test_that("Binary logistic functions work", { expect_s3_class(x23, "ggplot") expect_true(has_facet(x23)) - x24 <- .cal_table_logistic(segment_logistic, Class, .pred_good, smooth = FALSE) + x24 <- .cal_table_logistic( + segment_logistic, + Class, + .pred_good, + smooth = FALSE + ) - model24 <- stats::glm(Class ~ .pred_good, data = segment_logistic, family = binomial()) + model24 <- stats::glm( + Class ~ .pred_good, + data = segment_logistic, + family = binomial() + ) - preds24 <- predict(model24, + preds24 <- predict( + model24, data.frame(.pred_good = seq(0, 1, by = .01)), type = "response" ) @@ -74,7 +88,8 @@ test_that("Binary logistic functions work", { ) lgst_configs <- - bin_with_configs() |> cal_plot_logistic(truth = Class, estimate = .pred_good) + bin_with_configs() |> + cal_plot_logistic(truth = Class, estimate = .pred_good) expect_true(has_facet(lgst_configs)) # ------------------------------------------------------------------------------ @@ -86,7 +101,6 @@ test_that("Binary logistic functions work", { # should be faceted by .config and class expect_true(inherits(multi_configs_from_tune$facet, "FacetGrid")) - multi_configs_from_df <- mnl_with_configs() |> cal_plot_logistic(truth = obs, estimate = c(VF:L)) expect_s3_class(multi_configs_from_df, "ggplot") @@ -109,10 +123,13 @@ test_that("Binary logistic functions work with group argument", { expect_s3_class(res, "ggplot") expect_equal( - res$data[0,], + res$data[0, ], dplyr::tibble( id = factor(0, levels = paste(0:1)), - estimate = double(), prob = double(), lower = double(), upper = double() + estimate = double(), + prob = double(), + lower = double(), + upper = double() ) ) @@ -140,7 +157,8 @@ test_that("Binary logistic functions work with group argument", { ) lgst_configs <- - bin_with_configs() |> cal_plot_logistic(truth = Class, estimate = .pred_good) + bin_with_configs() |> + cal_plot_logistic(truth = Class, estimate = .pred_good) expect_true(has_facet(lgst_configs)) }) diff --git a/tests/testthat/test-cal-plot-regression.R b/tests/testthat/test-cal-plot-regression.R index cc49aa6..7af9904 100644 --- a/tests/testthat/test-cal-plot-regression.R +++ b/tests/testthat/test-cal-plot-regression.R @@ -8,7 +8,7 @@ test_that("regression functions work", { expect_s3_class(res, "ggplot") expect_equal( - res$data[0,], + res$data[0, ], dplyr::tibble(outcome = numeric(0), .pred = numeric(0), id = character(0)) ) @@ -31,7 +31,7 @@ test_that("regression functions work", { expect_s3_class(res, "ggplot") expect_equal( - res$data[0,], + res$data[0, ], dplyr::tibble(outcome = numeric(0), .pred = numeric(0), id = character(0)) ) @@ -55,10 +55,14 @@ test_that("regression functions work", { skip_if_not_installed("tune", "1.2.0") expect_equal( - res$data[0,], - dplyr::tibble(.pred = numeric(0), .row = numeric(0), - predictor_01 = integer(0), outcome = numeric(0), - .config = character()) + res$data[0, ], + dplyr::tibble( + .pred = numeric(0), + .row = numeric(0), + predictor_01 = integer(0), + outcome = numeric(0), + .config = character() + ) ) expect_equal( @@ -81,10 +85,14 @@ test_that("regression functions work", { skip_if_not_installed("tune", "1.2.0") expect_equal( - res$data[0,], - dplyr::tibble(.pred = numeric(0), .row = numeric(0), - predictor_01 = integer(0), outcome = numeric(0), - .config = character()) + res$data[0, ], + dplyr::tibble( + .pred = numeric(0), + .row = numeric(0), + predictor_01 = integer(0), + outcome = numeric(0), + .config = character() + ) ) expect_equal( @@ -102,13 +110,17 @@ test_that("regression functions work", { expect_equal(length(res$layers), 3) - res <- cal_plot_regression(boosting_predictions_oob, outcome, .pred, smooth = FALSE) + res <- cal_plot_regression( + boosting_predictions_oob, + outcome, + .pred, + smooth = FALSE + ) expect_s3_class(res, "ggplot") expect_equal( - res$data[0,], - dplyr::tibble(outcome = numeric(0), .pred = numeric(0), - id = character()) + res$data[0, ], + dplyr::tibble(outcome = numeric(0), .pred = numeric(0), id = character()) ) expect_equal( diff --git a/tests/testthat/test-cal-plot-windowed.R b/tests/testthat/test-cal-plot-windowed.R index ce77ac1..e9243ab 100644 --- a/tests/testthat/test-cal-plot-windowed.R +++ b/tests/testthat/test-cal-plot-windowed.R @@ -1,4 +1,3 @@ - test_that("Binary windowed functions work", { skip_if_not_installed("modeldata") @@ -11,18 +10,20 @@ test_that("Binary windowed functions work", { ) x30_1 <- segment_logistic |> - dplyr::mutate(x = dplyr::case_when( - .pred_good <= 0.05 ~ 1, - .pred_good >= 0.06 & .pred_good <= 0.16 ~ 2, - .pred_good >= 0.17 & .pred_good <= 0.27 ~ 3, - .pred_good >= 0.28 & .pred_good <= 0.38 ~ 4, - .pred_good >= 0.39 & .pred_good <= 0.49 ~ 5, - .pred_good >= 0.50 & .pred_good <= 0.60 ~ 6, - .pred_good >= 0.61 & .pred_good <= 0.71 ~ 7, - .pred_good >= 0.72 & .pred_good <= 0.82 ~ 8, - .pred_good >= 0.83 & .pred_good <= 0.93 ~ 9, - .pred_good >= 0.94 & .pred_good <= 1 ~ 10, - )) |> + dplyr::mutate( + x = dplyr::case_when( + .pred_good <= 0.05 ~ 1, + .pred_good >= 0.06 & .pred_good <= 0.16 ~ 2, + .pred_good >= 0.17 & .pred_good <= 0.27 ~ 3, + .pred_good >= 0.28 & .pred_good <= 0.38 ~ 4, + .pred_good >= 0.39 & .pred_good <= 0.49 ~ 5, + .pred_good >= 0.50 & .pred_good <= 0.60 ~ 6, + .pred_good >= 0.61 & .pred_good <= 0.71 ~ 7, + .pred_good >= 0.72 & .pred_good <= 0.82 ~ 8, + .pred_good >= 0.83 & .pred_good <= 0.93 ~ 9, + .pred_good >= 0.94 & .pred_good <= 1 ~ 10, + ) + ) |> dplyr::filter(!is.na(x)) |> dplyr::count(x) @@ -44,18 +45,20 @@ test_that("Binary windowed functions work", { x32_1 <- testthat_cal_binary() |> tune::collect_predictions(summarize = TRUE) |> - dplyr::mutate(x = dplyr::case_when( - .pred_class_1 <= 0.05 ~ 1, - .pred_class_1 >= 0.06 & .pred_class_1 <= 0.16 ~ 2, - .pred_class_1 >= 0.17 & .pred_class_1 <= 0.27 ~ 3, - .pred_class_1 >= 0.28 & .pred_class_1 <= 0.38 ~ 4, - .pred_class_1 >= 0.39 & .pred_class_1 <= 0.49 ~ 5, - .pred_class_1 >= 0.50 & .pred_class_1 <= 0.60 ~ 6, - .pred_class_1 >= 0.61 & .pred_class_1 <= 0.71 ~ 7, - .pred_class_1 >= 0.72 & .pred_class_1 <= 0.82 ~ 8, - .pred_class_1 >= 0.83 & .pred_class_1 <= 0.93 ~ 9, - .pred_class_1 >= 0.94 & .pred_class_1 <= 1 ~ 10, - )) |> + dplyr::mutate( + x = dplyr::case_when( + .pred_class_1 <= 0.05 ~ 1, + .pred_class_1 >= 0.06 & .pred_class_1 <= 0.16 ~ 2, + .pred_class_1 >= 0.17 & .pred_class_1 <= 0.27 ~ 3, + .pred_class_1 >= 0.28 & .pred_class_1 <= 0.38 ~ 4, + .pred_class_1 >= 0.39 & .pred_class_1 <= 0.49 ~ 5, + .pred_class_1 >= 0.50 & .pred_class_1 <= 0.60 ~ 6, + .pred_class_1 >= 0.61 & .pred_class_1 <= 0.71 ~ 7, + .pred_class_1 >= 0.72 & .pred_class_1 <= 0.82 ~ 8, + .pred_class_1 >= 0.83 & .pred_class_1 <= 0.93 ~ 9, + .pred_class_1 >= 0.94 & .pred_class_1 <= 1 ~ 10, + ) + ) |> dplyr::filter(!is.na(x)) |> dplyr::count(.config, x) @@ -70,10 +73,10 @@ test_that("Binary windowed functions work", { expect_true(has_facet(x33)) win_configs <- - bin_with_configs() |> cal_plot_windowed(truth = Class, estimate = .pred_good) + bin_with_configs() |> + cal_plot_windowed(truth = Class, estimate = .pred_good) expect_true(has_facet(win_configs)) - # ------------------------------------------------------------------------------ # multinomial outcome, binary windowed plots @@ -83,7 +86,6 @@ test_that("Binary windowed functions work", { # should be faceted by .config and class expect_true(inherits(multi_configs_from_tune$facet, "FacetGrid")) - multi_configs_from_df <- mnl_with_configs() |> cal_plot_windowed(truth = obs, estimate = c(VF:L)) expect_s3_class(multi_configs_from_df, "ggplot") @@ -132,18 +134,20 @@ test_that("Groupings that may not match work", { ) x51_1 <- combined |> - dplyr::mutate(x = dplyr::case_when( - .pred_good <= 0.05 ~ 1, - .pred_good >= 0.06 & .pred_good <= 0.16 ~ 2, - .pred_good >= 0.17 & .pred_good <= 0.27 ~ 3, - .pred_good >= 0.28 & .pred_good <= 0.38 ~ 4, - .pred_good >= 0.39 & .pred_good <= 0.49 ~ 5, - .pred_good >= 0.50 & .pred_good <= 0.60 ~ 6, - .pred_good >= 0.61 & .pred_good <= 0.71 ~ 7, - .pred_good >= 0.72 & .pred_good <= 0.82 ~ 8, - .pred_good >= 0.83 & .pred_good <= 0.93 ~ 9, - .pred_good >= 0.94 & .pred_good <= 1 ~ 10, - )) |> + dplyr::mutate( + x = dplyr::case_when( + .pred_good <= 0.05 ~ 1, + .pred_good >= 0.06 & .pred_good <= 0.16 ~ 2, + .pred_good >= 0.17 & .pred_good <= 0.27 ~ 3, + .pred_good >= 0.28 & .pred_good <= 0.38 ~ 4, + .pred_good >= 0.39 & .pred_good <= 0.49 ~ 5, + .pred_good >= 0.50 & .pred_good <= 0.60 ~ 6, + .pred_good >= 0.61 & .pred_good <= 0.71 ~ 7, + .pred_good >= 0.72 & .pred_good <= 0.82 ~ 8, + .pred_good >= 0.83 & .pred_good <= 0.93 ~ 9, + .pred_good >= 0.94 & .pred_good <= 1 ~ 10, + ) + ) |> dplyr::filter(!is.na(x)) |> dplyr::count(source, x) diff --git a/tests/testthat/test-cal-validate-multiclass.R b/tests/testthat/test-cal-validate-multiclass.R index bc8c81b..7afc279 100644 --- a/tests/testthat/test-cal-validate-multiclass.R +++ b/tests/testthat/test-cal-validate-multiclass.R @@ -46,12 +46,28 @@ test_that("Isotonic validation with `fit_resamples` - Multiclass", { expect_equal(nrow(val_with_pred), nrow(res$multin)) expect_equal( names(val_with_pred), - c("splits", "id", ".notes", ".predictions", ".metrics", ".metrics_cal", ".predictions_cal") + c( + "splits", + "id", + ".notes", + ".predictions", + ".metrics", + ".metrics_cal", + ".predictions_cal" + ) ) skip_if_not_installed("tune", "1.2.0") expect_equal( sort(names(val_with_pred$.predictions_cal[[1]])), - sort(c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class")) + sort(c( + ".pred_one", + ".pred_two", + ".pred_three", + ".row", + "outcome", + ".config", + ".pred_class" + )) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), diff --git a/tests/testthat/test-cal-validate.R b/tests/testthat/test-cal-validate.R index 31ea5a3..7ad1f78 100644 --- a/tests/testthat/test-cal-validate.R +++ b/tests/testthat/test-cal-validate.R @@ -3,7 +3,12 @@ test_that("Logistic validation with data frame input", { df <- testthat_cal_sampled() val_obj <- cal_validate_logistic(df, Class) - val_with_pred <- cal_validate_logistic(df, Class, save_pred = TRUE, smooth = TRUE) + val_with_pred <- cal_validate_logistic( + df, + Class, + save_pred = TRUE, + smooth = TRUE + ) expect_s3_class(val_obj, "data.frame") expect_s3_class(val_obj, "cal_rset") @@ -52,7 +57,12 @@ test_that("Logistic validation with data frame input", { expect_equal( names(pred_rs), c( - ".pred_class", ".pred_poor", ".pred_good", "Class", ".row", ".config", + ".pred_class", + ".pred_poor", + ".pred_good", + "Class", + ".row", + ".config", ".type" ) ) @@ -155,7 +165,12 @@ test_that("Multinomial classification validation with data frame input", { df <- rsample::vfold_cv(testthat_cal_sim_multi()) val_obj <- cal_validate_multinomial(df, class) - val_with_pred <- cal_validate_multinomial(df, class, save_pred = TRUE, smooth = TRUE) + val_with_pred <- cal_validate_multinomial( + df, + class, + save_pred = TRUE, + smooth = TRUE + ) expect_s3_class(val_obj, "data.frame") expect_s3_class(val_obj, "cal_rset") @@ -233,7 +248,12 @@ test_that("Validation without calibration with data frame input", { expect_equal( names(pred_rs), c( - ".pred_class", ".pred_poor", ".pred_good", "Class", ".row", ".config", + ".pred_class", + ".pred_poor", + ".pred_good", + "Class", + ".row", + ".config", ".type" ) ) @@ -245,7 +265,12 @@ test_that("Validation without calibration with data frame input", { test_that("Linear validation with data frame input", { df <- testthat_cal_reg_sampled() val_obj <- cal_validate_linear(df, outcome) - val_with_pred <- cal_validate_linear(df, outcome, save_pred = TRUE, smooth = FALSE) + val_with_pred <- cal_validate_linear( + df, + outcome, + save_pred = TRUE, + smooth = FALSE + ) expect_s3_class(val_obj, "data.frame") expect_s3_class(val_obj, "cal_rset") @@ -328,7 +353,11 @@ test_that("Bootstrapped Isotonic regression validation with data frame input", { test_that("Logistic validation with `fit_resamples`", { res <- testthat_cal_fit_rs() val_obj <- cal_validate_logistic(res$binary) - val_with_pred <- cal_validate_logistic(res$binary, save_pred = TRUE, smooth = TRUE) + val_with_pred <- cal_validate_logistic( + res$binary, + save_pred = TRUE, + smooth = TRUE + ) expect_s3_class(val_obj, "data.frame") expect_s3_class(val_obj, "cal_rset") @@ -343,13 +372,28 @@ test_that("Logistic validation with `fit_resamples`", { expect_equal(nrow(val_with_pred), nrow(res$binary)) expect_equal( names(val_with_pred), - c("splits", "id", ".notes", ".predictions", ".metrics", ".metrics_cal", ".predictions_cal") + c( + "splits", + "id", + ".notes", + ".predictions", + ".metrics", + ".metrics_cal", + ".predictions_cal" + ) ) skip_if_not_installed("tune", "1.2.0") expect_equal( sort(names(val_with_pred$.predictions_cal[[1]])), - sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")) + sort(c( + ".pred_class_1", + ".pred_class_2", + ".row", + "outcome", + ".config", + ".pred_class" + )) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), @@ -375,13 +419,28 @@ test_that("Isotonic classification validation with `fit_resamples`", { expect_equal(nrow(val_with_pred), nrow(res$binary)) expect_equal( names(val_with_pred), - c("splits", "id", ".notes", ".predictions", ".metrics", ".metrics_cal", ".predictions_cal") + c( + "splits", + "id", + ".notes", + ".predictions", + ".metrics", + ".metrics_cal", + ".predictions_cal" + ) ) skip_if_not_installed("tune", "1.2.0") expect_equal( sort(names(val_with_pred$.predictions_cal[[1]])), - sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")) + sort(c( + ".pred_class_1", + ".pred_class_2", + ".row", + "outcome", + ".config", + ".pred_class" + )) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), @@ -408,13 +467,28 @@ test_that("Bootstrapped isotonic classification validation with `fit_resamples`" expect_equal(nrow(val_with_pred), nrow(res$binary)) expect_equal( names(val_with_pred), - c("splits", "id", ".notes", ".predictions", ".metrics", ".metrics_cal", ".predictions_cal") + c( + "splits", + "id", + ".notes", + ".predictions", + ".metrics", + ".metrics_cal", + ".predictions_cal" + ) ) skip_if_not_installed("tune", "1.2.0") expect_equal( sort(names(val_with_pred$.predictions_cal[[1]])), - sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")) + sort(c( + ".pred_class_1", + ".pred_class_2", + ".row", + "outcome", + ".config", + ".pred_class" + )) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), @@ -441,13 +515,28 @@ test_that("Beta calibration validation with `fit_resamples`", { expect_equal(nrow(val_with_pred), nrow(res$binary)) expect_equal( names(val_with_pred), - c("splits", "id", ".notes", ".predictions", ".metrics", ".metrics_cal", ".predictions_cal") + c( + "splits", + "id", + ".notes", + ".predictions", + ".metrics", + ".metrics_cal", + ".predictions_cal" + ) ) skip_if_not_installed("tune", "1.2.0") expect_equal( sort(names(val_with_pred$.predictions_cal[[1]])), - sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")) + sort(c( + ".pred_class_1", + ".pred_class_2", + ".row", + "outcome", + ".config", + ".pred_class" + )) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), @@ -461,7 +550,11 @@ test_that("Multinomial calibration validation with `fit_resamples`", { res <- testthat_cal_fit_rs() val_obj <- cal_validate_multinomial(res$multin) - val_with_pred <- cal_validate_multinomial(res$multin, save_pred = TRUE, smooth = TRUE) + val_with_pred <- cal_validate_multinomial( + res$multin, + save_pred = TRUE, + smooth = TRUE + ) expect_s3_class(val_obj, "data.frame") expect_s3_class(val_obj, "cal_rset") @@ -476,13 +569,29 @@ test_that("Multinomial calibration validation with `fit_resamples`", { expect_equal(nrow(val_with_pred), nrow(res$multin)) expect_equal( names(val_with_pred), - c("splits", "id", ".notes", ".predictions", ".metrics", ".metrics_cal", ".predictions_cal") + c( + "splits", + "id", + ".notes", + ".predictions", + ".metrics", + ".metrics_cal", + ".predictions_cal" + ) ) skip_if_not_installed("tune", "1.2.0") expect_equal( sort(names(val_with_pred$.predictions_cal[[1]])), - sort(c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class")) + sort(c( + ".pred_one", + ".pred_two", + ".pred_three", + ".row", + "outcome", + ".config", + ".pred_class" + )) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), @@ -508,13 +617,28 @@ test_that("Validation without calibration with `fit_resamples`", { expect_equal(nrow(val_with_pred), nrow(res$binary)) expect_equal( names(val_with_pred), - c("splits", "id", ".notes", ".predictions", ".metrics", ".metrics_cal", ".predictions_cal") + c( + "splits", + "id", + ".notes", + ".predictions", + ".metrics", + ".metrics_cal", + ".predictions_cal" + ) ) skip_if_not_installed("tune", "1.2.0") expect_equal( sort(names(val_with_pred$.predictions_cal[[1]])), - sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")) + sort(c( + ".pred_class_1", + ".pred_class_2", + ".row", + "outcome", + ".config", + ".pred_class" + )) ) expect_equal( purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)), @@ -543,7 +667,15 @@ test_that("Linear validation with `fit_resamples`", { expect_equal(nrow(val_with_pred), nrow(res$reg)) expect_equal( names(val_with_pred), - c("splits", "id", ".notes", ".predictions", ".metrics", ".metrics_cal", ".predictions_cal") + c( + "splits", + "id", + ".notes", + ".predictions", + ".metrics", + ".metrics_cal", + ".predictions_cal" + ) ) skip_if_not_installed("tune", "1.2.0") @@ -601,7 +733,11 @@ test_that("Isotonic regression validation with `fit_resamples`", { res <- testthat_cal_fit_rs() mtr <- yardstick::metric_set(yardstick::rmse, yardstick::rsq) val_obj <- cal_validate_isotonic(res$reg, estimate = .pred, metrics = mtr) - val_with_pred <- cal_validate_isotonic(res$reg, save_pred = TRUE, smooth = TRUE) + val_with_pred <- cal_validate_isotonic( + res$reg, + save_pred = TRUE, + smooth = TRUE + ) expect_s3_class(val_obj, "data.frame") expect_s3_class(val_obj, "cal_rset") @@ -616,7 +752,15 @@ test_that("Isotonic regression validation with `fit_resamples`", { expect_equal(nrow(val_with_pred), nrow(res$reg)) expect_equal( names(val_with_pred), - c("splits", "id", ".notes", ".predictions", ".metrics", ".metrics_cal", ".predictions_cal") + c( + "splits", + "id", + ".notes", + ".predictions", + ".metrics", + ".metrics_cal", + ".predictions_cal" + ) ) skip_if_not_installed("tune", "1.2.0") @@ -636,8 +780,16 @@ test_that("Isotonic regression validation with `fit_resamples`", { test_that("Isotonic bootstrapped regression validation with `fit_resamples`", { res <- testthat_cal_fit_rs() mtr <- yardstick::metric_set(yardstick::rmse, yardstick::rsq) - val_obj <- cal_validate_isotonic_boot(res$reg, estimate = .pred, metrics = mtr) - val_with_pred <- cal_validate_isotonic_boot(res$reg, save_pred = TRUE, smooth = TRUE) + val_obj <- cal_validate_isotonic_boot( + res$reg, + estimate = .pred, + metrics = mtr + ) + val_with_pred <- cal_validate_isotonic_boot( + res$reg, + save_pred = TRUE, + smooth = TRUE + ) expect_s3_class(val_obj, "data.frame") expect_s3_class(val_obj, "cal_rset") @@ -652,7 +804,15 @@ test_that("Isotonic bootstrapped regression validation with `fit_resamples`", { expect_equal(nrow(val_with_pred), nrow(res$reg)) expect_equal( names(val_with_pred), - c("splits", "id", ".notes", ".predictions", ".metrics", ".metrics_cal", ".predictions_cal") + c( + "splits", + "id", + ".notes", + ".predictions", + ".metrics", + ".metrics_cal", + ".predictions_cal" + ) ) skip_if_not_installed("tune", "1.2.0") diff --git a/tests/testthat/test-class-pred.R b/tests/testthat/test-class-pred.R index 6daef97..4e86984 100644 --- a/tests/testthat/test-class-pred.R +++ b/tests/testthat/test-class-pred.R @@ -152,7 +152,10 @@ test_that("casting factor to class_pred", { expect_equal(vec_cast(fc1, cp1), class_pred(factor(c("a", "b", "b", "b")))) # converting to ordered class_pred maintains orderedness - expect_equal(vec_cast(fc2, cp4), class_pred(factor(c("a", "b", "b", "c"), ordered = TRUE))) + expect_equal( + vec_cast(fc2, cp4), + class_pred(factor(c("a", "b", "b", "c"), ordered = TRUE)) + ) # convert to class_pred with NA already present is not lossy expect_warning(vec_cast(fc3, class_pred()), NA) @@ -160,7 +163,10 @@ test_that("casting factor to class_pred", { # convert ordered factor to class_pred # order-ness depends on class_pred type, not order factor or1 <- as.ordered(fc1) - expect_equal(vec_cast(or1, class_pred()), class_pred(factor(c("a", "b", "b", "b")))) + expect_equal( + vec_cast(or1, class_pred()), + class_pred(factor(c("a", "b", "b", "b"))) + ) }) test_that("casting character to class_pred", { @@ -217,8 +223,14 @@ test_that("unknown casts are handled correctly", { }) test_that("ptype2 checks are handled correctly", { - expect_error(vec_ptype2(manual_creation_eq, numeric()), class = "vctrs_error_incompatible_type") - expect_equal(vec_ptype2(manual_creation_eq, vctrs::unspecified()), vec_ptype(manual_creation_eq)) + expect_error( + vec_ptype2(manual_creation_eq, numeric()), + class = "vctrs_error_incompatible_type" + ) + expect_equal( + vec_ptype2(manual_creation_eq, vctrs::unspecified()), + vec_ptype(manual_creation_eq) + ) expect_equal(vec_ptype2(character(), manual_creation_eq), character()) expect_equal(vec_ptype2(manual_creation_eq, character()), character()) diff --git a/tests/testthat/test-conformal_infer_cv.R b/tests/testthat/test-conformal_infer_cv.R index 2544ac2..2f2960c 100644 --- a/tests/testthat/test-conformal_infer_cv.R +++ b/tests/testthat/test-conformal_infer_cv.R @@ -24,7 +24,6 @@ test_that("bad inputs to conformal intervals", { set.seed(182) sim_new <- sim_regression(2) - ctrl <- control_resamples(save_pred = TRUE, extract = I) set.seed(382) @@ -172,7 +171,10 @@ test_that("group resampling to conformal CV intervals", { set.seed(484) nnet_wflow <- - workflow(y ~ x, parsnip::mlp(hidden_units = 2) |> parsnip::set_mode("regression")) + workflow( + y ~ x, + parsnip::mlp(hidden_units = 2) |> parsnip::set_mode("regression") + ) group_folds <- group_vfold_cv(train_data, group = color) diff --git a/tests/testthat/test-conformal_infer_full.R b/tests/testthat/test-conformal_infer_full.R index 236c9c8..87ea378 100644 --- a/tests/testthat/test-conformal_infer_full.R +++ b/tests/testthat/test-conformal_infer_full.R @@ -24,7 +24,6 @@ test_that("bad inputs to conformal intervals", { set.seed(182) sim_new <- sim_regression(2) - ctrl <- control_resamples(save_pred = TRUE, extract = I) set.seed(382) @@ -101,7 +100,10 @@ test_that("bad inputs to conformal intervals", { ) expect_snapshot( - probably:::get_root(try(stop("I made you stop"), silent = TRUE), control_conformal_full()) + probably:::get_root( + try(stop("I made you stop"), silent = TRUE), + control_conformal_full() + ) ) }) @@ -136,14 +138,23 @@ test_that("conformal intervals", { sim_new <- sim_regression(2) ctrl_grid <- control_conformal_full(method = "grid", seed = 1) - basic_obj <- int_conformal_full(wflow, train_data = sim_data, control = ctrl_grid) + basic_obj <- int_conformal_full( + wflow, + train_data = sim_data, + control = ctrl_grid + ) ctrl_hard <- control_conformal_full( - progress = TRUE, seed = 1, - max_iter = 2, tolerance = 0.000001 + progress = TRUE, + seed = 1, + max_iter = 2, + tolerance = 0.000001 + ) + smol_obj <- int_conformal_full( + wflow_small, + train_data = sim_small, + control = ctrl_hard ) - smol_obj <- int_conformal_full(wflow_small, train_data = sim_small, control = ctrl_hard) - ctrl <- control_resamples(save_pred = TRUE, extract = I) set.seed(382) @@ -187,6 +198,8 @@ test_that("conformal control", { set.seed(1) expect_snapshot(dput(control_conformal_full())) expect_snapshot(dput(control_conformal_full(max_iter = 2))) - expect_snapshot(error = TRUE, control_conformal_full(method = "rock-paper-scissors")) + expect_snapshot( + error = TRUE, + control_conformal_full(method = "rock-paper-scissors") + ) }) - diff --git a/tests/testthat/test-make_class_pred.R b/tests/testthat/test-make_class_pred.R index 164e338..84964e1 100644 --- a/tests/testthat/test-make_class_pred.R +++ b/tests/testthat/test-make_class_pred.R @@ -20,7 +20,13 @@ test_that("two class succeeds with vector interface", { }) test_that("multi class succeeds with vector interface", { - res <- make_class_pred(bobcat, coyote, gray_fox, levels = lvls2, min_prob = 0.5) + res <- make_class_pred( + bobcat, + coyote, + gray_fox, + levels = lvls2, + min_prob = 0.5 + ) fct <- factor(c("gray_fox", "gray_fox", "bobcat", "gray_fox", "coyote")) known <- class_pred(fct, which = 5) @@ -38,17 +44,34 @@ test_that("multi class succeeds with data frame helper", { name = "cp_test" ) - known <- make_class_pred(bobcat, coyote, gray_fox, levels = lvls2, min_prob = 0.5) + known <- make_class_pred( + bobcat, + coyote, + gray_fox, + levels = lvls2, + min_prob = 0.5 + ) expect_s3_class(res, "data.frame") expect_equal(res[["cp_test"]], known) }) - test_that("ordered passes through to class_pred", { - res <- make_class_pred(bobcat, coyote, gray_fox, levels = lvls2, ordered = TRUE) - res2 <- make_class_pred(bobcat, coyote, gray_fox, levels = lvls2, ordered = TRUE) + res <- make_class_pred( + bobcat, + coyote, + gray_fox, + levels = lvls2, + ordered = TRUE + ) + res2 <- make_class_pred( + bobcat, + coyote, + gray_fox, + levels = lvls2, + ordered = TRUE + ) expect_true(is_ordered_class_pred(res)) expect_true(is_ordered_class_pred(res2)) diff --git a/tests/testthat/test-threshold_perf.R b/tests/testthat/test-threshold_perf.R index bf3245b..402e04b 100644 --- a/tests/testthat/test-threshold_perf.R +++ b/tests/testthat/test-threshold_perf.R @@ -59,8 +59,14 @@ test_that("factor from numeric", { tab_1 <- table(new_fac_1) expect_s3_class(new_fac_1, "factor") expect_true(isTRUE(all.equal(levels(new_fac_1), levels(ex_data$outcome)))) - expect_equal(unname(tab_1["Cl1"]), sum(ex_data$prob_est >= ex_data$prob_est[1])) - expect_equal(unname(tab_1["Cl2"]), sum(ex_data$prob_est < ex_data$prob_est[1])) + expect_equal( + unname(tab_1["Cl1"]), + sum(ex_data$prob_est >= ex_data$prob_est[1]) + ) + expect_equal( + unname(tab_1["Cl2"]), + sum(ex_data$prob_est < ex_data$prob_est[1]) + ) # missing data new_fac_2 <- @@ -74,8 +80,14 @@ test_that("factor from numeric", { expect_s3_class(new_fac_2, "factor") cmpl_probs <- ex_data_miss$prob_est[!is.na(ex_data_miss$prob_est)] expect_true(isTRUE(all.equal(is.na(new_fac_2), is.na(ex_data_miss$prob_est)))) - expect_true(isTRUE(all.equal(levels(new_fac_2), levels(ex_data_miss$outcome)))) - expect_equal(unname(tab_2["Cl1"]), sum(cmpl_probs >= ex_data_miss$prob_est[1])) + expect_true(isTRUE(all.equal( + levels(new_fac_2), + levels(ex_data_miss$outcome) + ))) + expect_equal( + unname(tab_2["Cl1"]), + sum(cmpl_probs >= ex_data_miss$prob_est[1]) + ) expect_equal(unname(tab_2["Cl2"]), sum(cmpl_probs < ex_data_miss$prob_est[1])) new_fac_3 <- @@ -88,8 +100,14 @@ test_that("factor from numeric", { tab_3 <- table(new_fac_3) expect_s3_class(new_fac_3, "factor") expect_true(isTRUE(all.equal(levels(new_fac_3), levels(ex_data$outcome)))) - expect_equal(unname(tab_3["Cl1"]), sum(ex_data$prob_est < ex_data$prob_est[1])) - expect_equal(unname(tab_3["Cl2"]), sum(ex_data$prob_est >= ex_data$prob_est[1])) + expect_equal( + unname(tab_3["Cl1"]), + sum(ex_data$prob_est < ex_data$prob_est[1]) + ) + expect_equal( + unname(tab_3["Cl2"]), + sum(ex_data$prob_est >= ex_data$prob_est[1]) + ) }) test_that("single group", {