Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ importFrom(rlang,"%||%")
importFrom(rlang,":=")
importFrom(rlang,call2)
importFrom(rlang,call_name)
importFrom(rlang,caller_env)
importFrom(rlang,env_get)
importFrom(rlang,eval_tidy)
importFrom(rlang,expr)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

* Added a new function, `compute_metrics()`, that allows for computing new metrics after evaluating against resamples. The arguments and output formats are closely related to those from `collect_metrics()`, but this function requires that the input be generated with the control option `save_pred = TRUE` and additionally takes a `metrics` argument with a metric set for new metrics to compute. This allows for computing new performance metrics without requiring users to re-fit and re-predict from each model. (#663)

* Improved error message when needed packages aren't installed. (#727)

* `last_fit()` will now error when supplied a fitted workflow. (#678)

* A method for rsample's `int_pctl()` function that will compute percentile confidence intervals on performance metrics for objects produced by `fit_resamples()`, `tune_*()`, and `last_fit()`.
Expand Down
2 changes: 1 addition & 1 deletion R/0_imports.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#' @importFrom purrr map_int
#' @importFrom rlang call2 ns_env is_quosure is_quosures quo_get_expr call_name
#' @importFrom rlang is_false eval_tidy expr sym syms env_get is_function :=
#' @importFrom rlang is_missing %||%
#' @importFrom rlang is_missing %||% caller_env
#' @importFrom glue glue glue_collapse
#' @importFrom dials is_unknown encode_unit
#' @importFrom stats sd qt qnorm dnorm pnorm predict model.matrix setNames
Expand Down
19 changes: 11 additions & 8 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ check_backend_options <- function(backend_options) {

grid_msg <- "`grid` should be a positive integer or a data frame."

check_grid <- function(grid, workflow, pset = NULL) {
check_grid <- function(grid, workflow, pset = NULL, call = caller_env()) {
# `NULL` grid is the signal that we are using `fit_resamples()`
if (is.null(grid)) {
return(grid)
Expand Down Expand Up @@ -111,7 +111,7 @@ check_grid <- function(grid, workflow, pset = NULL) {
if (grid < 1) {
rlang::abort(grid_msg)
}
check_workflow(workflow, pset = pset, check_dials = TRUE)
check_workflow(workflow, pset = pset, check_dials = TRUE, call = call)

grid <- dials::grid_latin_hypercube(pset, size = grid)
grid <- dplyr::distinct(grid)
Expand Down Expand Up @@ -188,7 +188,7 @@ is_installed <- function(pkg) {
res
}

check_installs <- function(x) {
check_installs <- function(x, call = caller_env()) {
if (x$engine == "unknown") {
rlang::abort("Please declare an engine for the model")
} else {
Expand All @@ -201,9 +201,12 @@ check_installs <- function(x) {
if (length(deps) > 0) {
is_inst <- purrr::map_lgl(deps, is_installed)
if (any(!is_inst)) {
rlang::abort(c("Some package installs are required: ",
paste0("'", deps[!is_inst], "'", collapse = ", ")
))
needs_installed <- unique(deps[!is_inst])
cli::cli_abort(
"{cli::qty(needs_installed)} Package install{?s} {?is/are} \\
required for {.pkg {needs_installed}}.",
call = call
)
}
}
}
Expand Down Expand Up @@ -271,7 +274,7 @@ check_param_objects <- function(pset) {
#' @keywords internal
#' @rdname empty_ellipses
#' @param check_dials A logical for check for a NULL parameter object.
check_workflow <- function(x, pset = NULL, check_dials = FALSE) {
check_workflow <- function(x, pset = NULL, check_dials = FALSE, call = caller_env()) {
if (!inherits(x, "workflow")) {
rlang::abort("The `object` argument should be a 'workflow' object.")
}
Expand Down Expand Up @@ -303,7 +306,7 @@ check_workflow <- function(x, pset = NULL, check_dials = FALSE) {

check_extra_tune_parameters(x)

check_installs(hardhat::extract_spec_parsnip(x))
check_installs(hardhat::extract_spec_parsnip(x), call = call)

invisible(NULL)
}
Expand Down
5 changes: 3 additions & 2 deletions R/resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ fit_resamples.workflow <- function(object,
# ------------------------------------------------------------------------------

resample_workflow <- function(workflow, resamples, metrics, control,
eval_time = NULL, rng) {
eval_time = NULL, rng, call = caller_env()) {
check_no_tuning(workflow)

# `NULL` is the signal that we have no grid to tune with
Expand All @@ -150,7 +150,8 @@ resample_workflow <- function(workflow, resamples, metrics, control,
pset = pset,
control = control,
eval_time = eval_time,
rng = rng
rng = rng,
call = call
)

attributes <- attributes(out)
Expand Down
6 changes: 4 additions & 2 deletions R/tune_bayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ tune_bayes.workflow <-
tune_bayes_workflow <-
function(object, resamples, iter = 10, param_info = NULL, metrics = NULL,
objective = exp_improve(),
initial = 5, control = control_bayes(), eval_time = NULL, ...) {
initial = 5, control = control_bayes(), eval_time = NULL, ...,
call = caller_env()) {
start_time <- proc.time()[3]

initialize_catalog(control = control)
Expand All @@ -278,7 +279,8 @@ tune_bayes_workflow <-
if (is.null(param_info)) {
param_info <- hardhat::extract_parameter_set_dials(object)
}
check_workflow(object, check_dials = is.null(param_info), pset = param_info)
check_workflow(object, check_dials = is.null(param_info), pset = param_info,
call = call)
check_backend_options(control$backend_options)

unsummarized <- check_initial(
Expand Down
5 changes: 3 additions & 2 deletions R/tune_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,8 @@ tune_grid_workflow <- function(workflow,
pset = NULL,
control = control_grid(),
eval_time = NULL,
rng = TRUE) {
rng = TRUE,
call = caller_env()) {
check_rset(resamples)


Expand All @@ -344,7 +345,7 @@ tune_grid_workflow <- function(workflow,
grid_names = names(grid)
)

check_workflow(workflow, pset = pset)
check_workflow(workflow, pset = pset, call = call)
check_backend_options(control$backend_options)

grid <- check_grid(
Expand Down
2 changes: 1 addition & 1 deletion man/empty_ellipses.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions tests/testthat/_snaps/checks.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,22 @@
Error in `tune:::check_workflow()`:
! A parsnip model is required.

# errors informatively when needed package isn't installed

Code
check_workflow(stan_wflow)
Condition
Error:
! Package install is required for rstanarm.

---

Code
fit_resamples(stan_wflow, rsample::bootstraps(mtcars))
Condition
Error in `fit_resamples()`:
! Package install is required for rstanarm.

# workflow objects (will not tune, tidymodels/tune#548)

Code
Expand Down
17 changes: 17 additions & 0 deletions tests/testthat/test-checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,23 @@ test_that("workflow objects", {
})
})

test_that("errors informatively when needed package isn't installed", {
# rstanarm is not installed during CI runs
# in contexts where it _is_ installed, skip the test.
skip_if(rlang::is_installed("rstanarm"))
stan_wflow <- workflow(mpg ~ ., parsnip::linear_reg(engine = "stan"))

expect_snapshot(
check_workflow(stan_wflow),
error = TRUE
)

expect_snapshot(
fit_resamples(stan_wflow, rsample::bootstraps(mtcars)),
error = TRUE
)
})

test_that("workflow objects (will not tune, tidymodels/tune#548)", {
skip_if_not_installed("glmnet")

Expand Down