Skip to content

Commit

Permalink
changes for #144 (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Jun 5, 2024
1 parent e601d1c commit 29d9fb3
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

* A new function `bound_prediction()` is available to constrain the values of a numeric prediction (#142).

* Bug fix for `cal_plot_breaks()` with binary classification with custom probability column names (#144).

* Fixed an error in `int_conformal_cv()` when grouped resampling was used (#141).

# probably 1.0.3
Expand Down
13 changes: 9 additions & 4 deletions R/cal-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ truth_estimate_map <- function(.data, truth, estimate, validate = FALSE) {
truth_str <- names(truth_str)
}

# Get the name(s) of the column(s) that have the predicted values. For binary
# data, this is a single column name.
estimate_str <- .data %>%
tidyselect_cols({{ estimate }}) %>%
names()
Expand All @@ -21,6 +23,8 @@ truth_estimate_map <- function(.data, truth, estimate, validate = FALSE) {

truth_levels <- levels(.data[[truth_str]])

# `est_map` maps the levels of the outcome to the corresponding column(s) in
# the data
if (length(truth_levels) > 0) {
if (all(substr(estimate_str, 1, 6) == ".pred_")) {
est_map <- purrr::map(
Expand All @@ -33,10 +37,11 @@ truth_estimate_map <- function(.data, truth, estimate, validate = FALSE) {
}
)
} else {
est_map <- purrr::map(
seq_along(truth_levels),
~ sym(estimate_str[[.x]])
)
if (length(estimate_str) == 1) {
est_map <- list(sym(estimate_str), NULL)
} else {
est_map <- purrr::map(seq_along(truth_levels), ~ sym(estimate_str[[.x]]))
}
}
if (validate) {
check_level_consistency(truth_levels, est_map)
Expand Down
7 changes: 7 additions & 0 deletions tests/testthat/test-cal-plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -673,3 +673,10 @@ test_that("don't facet if there is only one .config", {
expect_null(res_regression$data[[".config"]])
expect_s3_class(res_regression, "ggplot")
})

test_that("custom names for cal_plot_breaks()", {
data(segment_logistic)
segment_logistic_1 <- dplyr::rename(segment_logistic, good_prob = .pred_good)
p <- cal_plot_breaks(segment_logistic_1, Class, good_prob)
expect_s3_class(p, "ggplot")
})

0 comments on commit 29d9fb3

Please sign in to comment.