Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions R/tailor.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
#' with the [tidymodels](https://tidymodels.org) framework; for greatest ease
#' of use, situate tailors in model workflows with `?workflows::add_tailor()`.
#'
#' @param type Character. The model sub-mode. Possible values are
#' `"unknown"`, `"regression"`, `"binary"`, or `"multiclass"`. Only required
#' when used independently of `?workflows::add_tailor()`.
#' @param outcome <[`tidy-select`][dplyr::dplyr_tidy_select]> Only required
#' when used independently of `?workflows::add_tailor()`, and can also be passed
#' at `fit()` time instead. The column name of the outcome variable.
Expand Down Expand Up @@ -64,18 +61,16 @@
#' # adjust hard class predictions
#' predict(tlr_fit, two_class_example) %>% count(predicted)
#' @export
tailor <- function(type = "unknown", outcome = NULL, estimate = NULL,
probabilities = NULL) {
tailor <- function(outcome = NULL, estimate = NULL, probabilities = NULL) {
columns <-
list(
outcome = outcome,
type = type,
estimate = estimate,
probabilities = probabilities
)

new_tailor(
type,
"unknown",
adjustments = list(),
columns = columns,
ptype = tibble::new_tibble(list()),
Expand All @@ -84,8 +79,6 @@ tailor <- function(type = "unknown", outcome = NULL, estimate = NULL,
}

new_tailor <- function(type, adjustments, columns, ptype, call) {
type <- arg_match0(type, c("unknown", "regression", "binary", "multiclass"))

if (!is.list(adjustments)) {
cli_abort("The {.arg adjustments} argument should be a list.", call = call)
}
Expand All @@ -97,8 +90,14 @@ new_tailor <- function(type, adjustments, columns, ptype, call) {
{.val adjustment}: {bad_adjustment}.", call = call)
}

orderings <- adjustment_orderings(adjustments)

if (type == "unknown") {
type <- infer_type(orderings)
}

# validate adjustment order and check duplicates
validate_order(adjustments, type, call)
validate_order(orderings, type, call)

# check columns
res <- list(
Expand Down Expand Up @@ -233,5 +232,5 @@ set_tailor_type <- function(object, y) {
# todo setup eval_time
# todo missing methods:
# todo tune_args
# todo tidy
# todo tidy (this should probably just be `adjustment_orderings()`)
# todo extract_parameter_set_dials
13 changes: 13 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,19 @@ tailor_adjustment_requires_fit <- function(x) {
isTRUE(x$requires_fit)
}

# an tidy-esque method for adjustment lists, used in validating
# compatibility of adjustments
adjustment_orderings <- function(adjustments) {
tibble::new_tibble(list(
name = purrr::map_chr(adjustments, ~ class(.x)[1]),
input = purrr::map_chr(adjustments, ~ .x$inputs),
output_numeric = purrr::map_lgl(adjustments, ~ grepl("numeric", .x$outputs)),
output_prob = purrr::map_lgl(adjustments, ~ grepl("probability", .x$outputs)),
output_class = purrr::map_lgl(adjustments, ~ grepl("class", .x$outputs)),
output_all = purrr::map_lgl(adjustments, ~ grepl("everything", .x$outputs))
))
}

# ad-hoc checking --------------------------------------------------------------
check_tailor <- function(x, calibration_type = NULL, call = caller_env(), arg = caller_arg(x)) {
if (!is_tailor(x)) {
Expand Down
36 changes: 21 additions & 15 deletions R/validation-rules.R
Original file line number Diff line number Diff line change
@@ -1,21 +1,9 @@
validate_order <- function(adjustments, type, call = caller_env()) {
orderings <-
tibble::new_tibble(list(
name = purrr::map_chr(adjustments, ~ class(.x)[1]),
input = purrr::map_chr(adjustments, ~ .x$inputs),
output_numeric = purrr::map_lgl(adjustments, ~ grepl("numeric", .x$outputs)),
output_prob = purrr::map_lgl(adjustments, ~ grepl("probability", .x$outputs)),
output_class = purrr::map_lgl(adjustments, ~ grepl("class", .x$outputs)),
output_all = purrr::map_lgl(adjustments, ~ grepl("everything", .x$outputs))
))

if (length(adjustments) < 2) {
validate_order <- function(orderings, type, call = caller_env()) {
if (nrow(orderings) < 2) {
return(invisible(orderings))
}

if (type == "unknown") {
type <- infer_type(orderings)
}
check_incompatible_types(orderings, call)

switch(
type,
Expand All @@ -27,6 +15,24 @@ validate_order <- function(adjustments, type, call = caller_env()) {
invisible(orderings)
}

check_incompatible_types <- function(orderings, call) {
if (all(c("numeric", "probability") %in% orderings$input)) {
numeric_adjustments <- orderings$name[which(orderings$input == "numeric")]
probability_adjustments <- orderings$name[which(orderings$input == "probability")]
cli_abort(
c(
"Can't compose adjustments for different prediction types.",
"i" = "{cli::qty(numeric_adjustments)}
Adjustment{?s} {.fn {paste0('adjust_', numeric_adjustments)}}
{cli::qty(numeric_adjustments[-1])} operate{?s} on numerics while
{.fn {paste0('adjust_', probability_adjustments)}}
{cli::qty(probability_adjustments[-1])} operate{?s} on probabilities."
),
call = call
)
}
}

check_classification_order <- function(x, call) {
cal_ind <- which(grepl("calibration$", x$name))
eq_ind <- which(grepl("equivocal", x$name))
Expand Down
6 changes: 1 addition & 5 deletions man/tailor.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/_snaps/adjust-equivocal-zone.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Message

-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A binary postprocessor with 1 adjustment:

* Add equivocal zone of size 0.1.

Expand All @@ -16,7 +16,7 @@
Message

-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A binary postprocessor with 1 adjustment:

* Add equivocal zone of optimized size.

12 changes: 2 additions & 10 deletions tests/testthat/_snaps/adjust-numeric-calibration.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Message

-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A regression postprocessor with 1 adjustment:

* Re-calibrate numeric predictions.

Expand All @@ -20,15 +20,7 @@
---

Code
tailor("binary") %>% adjust_numeric_calibration("linear")
Condition
Error in `adjust_numeric_calibration()`:
! A binary tailor is incompatible with the adjustment `adjust_numeric_calibration()`.

---

Code
tailor("regression") %>% adjust_numeric_calibration("binary")
tailor() %>% adjust_numeric_calibration("binary")
Condition
Error in `adjust_numeric_calibration()`:
! `method` must be one of "linear", "isotonic", or "isotonic_boot", not "binary".
Expand Down
8 changes: 4 additions & 4 deletions tests/testthat/_snaps/adjust-numeric-range.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Message

-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A regression postprocessor with 1 adjustment:

* Constrain numeric predictions to be between [-Inf, Inf].

Expand All @@ -16,7 +16,7 @@
Message

-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A regression postprocessor with 1 adjustment:

* Constrain numeric predictions to be between [?, Inf].

Expand All @@ -27,7 +27,7 @@
Message

-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A regression postprocessor with 1 adjustment:

* Constrain numeric predictions to be between [-1, ?].

Expand All @@ -38,7 +38,7 @@
Message

-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A regression postprocessor with 1 adjustment:

* Constrain numeric predictions to be between [?, 1].

12 changes: 2 additions & 10 deletions tests/testthat/_snaps/adjust-probability-calibration.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Message

-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A binary postprocessor with 1 adjustment:

* Re-calibrate classification probabilities.

Expand All @@ -20,15 +20,7 @@
---

Code
tailor("regression") %>% adjust_probability_calibration("binary")
Condition
Error in `adjust_probability_calibration()`:
! A regression tailor is incompatible with the adjustment `adjust_probability_calibration()`.

---

Code
tailor("binary") %>% adjust_probability_calibration("linear")
tailor() %>% adjust_probability_calibration("linear")
Condition
Error in `adjust_probability_calibration()`:
! `method` must be one of "logistic", "multinomial", "beta", "isotonic", or "isotonic_boot", not "linear".
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/adjust-probability-threshold.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Message

-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A binary postprocessor with 1 adjustment:

* Adjust probability threshold to 0.5.

Expand All @@ -16,7 +16,7 @@
Message

-- tailor ----------------------------------------------------------------------
A postprocessor with 1 adjustment:
A binary postprocessor with 1 adjustment:

* Adjust probability threshold to optimized value.

9 changes: 4 additions & 5 deletions tests/testthat/_snaps/tailor.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
---

Code
tailor(type = "binary")
tailor()
Message

-- tailor ----------------------------------------------------------------------
A binary postprocessor with 0 adjustments.
A postprocessor with 0 adjustments.

---

Code
tailor(type = "binary") %>% adjust_probability_threshold(0.2)
tailor() %>% adjust_probability_threshold(0.2)
Message

-- tailor ----------------------------------------------------------------------
Expand All @@ -30,8 +30,7 @@
---

Code
tailor(type = "binary") %>% adjust_probability_threshold(0.2) %>%
adjust_equivocal_zone()
tailor() %>% adjust_probability_threshold(0.2) %>% adjust_equivocal_zone()
Message

-- tailor ----------------------------------------------------------------------
Expand Down
Loading
Loading