diff --git a/NAMESPACE b/NAMESPACE index b73110c9b..0cc003a17 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -57,8 +57,11 @@ S3method(req_pkgs,model_fit) S3method(req_pkgs,model_spec) S3method(required_pkgs,model_fit) S3method(required_pkgs,model_spec) +S3method(set_args,default) S3method(set_args,model_spec) +S3method(set_engine,default) S3method(set_engine,model_spec) +S3method(set_mode,default) S3method(set_mode,model_spec) S3method(tidy,"_LiblineaR") S3method(tidy,"_elnet") diff --git a/R/aaa_models.R b/R/aaa_models.R index 13bfedee0..613e7ac39 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -92,6 +92,22 @@ set_env_val <- function(name, value) { # ------------------------------------------------------------------------------ +error_set_object <- function(object, func) { + msg <- + "`{func}()` expected a model specification to be supplied to the \ + `object` argument, but received a(n) `{class(object)[1]}` object." + + if (inherits(object, "function") && + isTRUE(environment(object)$.packageName == "parsnip")) { + msg <- c( + msg, + "i" = "Did you mistakenly pass `model_function` rather than `model_function()`?" + ) + } + + cli::cli_abort(msg, call = call2(func)) +} + check_eng_val <- function(eng) { if (rlang::is_missing(eng) || length(eng) != 1 || !is.character(eng)) rlang::abort("Please supply a character string for an engine name (e.g. `'lm'`)") diff --git a/R/arguments.R b/R/arguments.R index 5e6940942..0d71eb20e 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -77,6 +77,13 @@ set_args.model_spec <- function(object, ...) { ) } +#' @export +set_args.default <- function(object,...) { + error_set_object(object, func = "set_args") + + invisible(FALSE) +} + #' @rdname set_args #' @export set_mode <- function(object, mode) { @@ -95,6 +102,13 @@ set_mode.model_spec <- function(object, mode) { object } +#' @export +set_mode.default <- function(object, mode) { + error_set_object(object, func = "set_mode") + + invisible(FALSE) +} + # ------------------------------------------------------------------------------ maybe_eval <- function(x) { diff --git a/R/engines.R b/R/engines.R index 8b299114e..b59704b6d 100644 --- a/R/engines.R +++ b/R/engines.R @@ -136,6 +136,13 @@ set_engine.model_spec <- function(object, engine, ...) { ) } +#' @export +set_engine.default <- function(object, engine, ...) { + error_set_object(object, func = "set_engine") + + invisible(FALSE) +} + #' Display currently available engines for a model #' #' The possible engines for a model can depend on what packages are loaded. diff --git a/tests/testthat/_snaps/args_and_modes.md b/tests/testthat/_snaps/args_and_modes.md index 7f73f9342..2d4f5dd4a 100644 --- a/tests/testthat/_snaps/args_and_modes.md +++ b/tests/testthat/_snaps/args_and_modes.md @@ -3,14 +3,58 @@ Code set_mode(mtcars, "regression") Condition - Error in `UseMethod()`: - ! no applicable method for 'set_mode' applied to an object of class "data.frame" + Error in `set_mode()`: + ! `set_mode()` expected a model specification to be supplied to the `object` argument, but received a(n) `data.frame` object. --- Code set_args(mtcars, blah = "blah") Condition - Error in `UseMethod()`: - ! no applicable method for 'set_args' applied to an object of class "data.frame" + Error in `set_args()`: + ! `set_args()` expected a model specification to be supplied to the `object` argument, but received a(n) `data.frame` object. + +--- + + Code + bag_tree %>% set_mode("classification") + Condition + Error in `set_mode()`: + ! `set_mode()` expected a model specification to be supplied to the `object` argument, but received a(n) `function` object. + i Did you mistakenly pass `model_function` rather than `model_function()`? + +--- + + Code + bag_tree %>% set_engine("rpart") + Condition + Error in `set_engine()`: + ! `set_engine()` expected a model specification to be supplied to the `object` argument, but received a(n) `function` object. + i Did you mistakenly pass `model_function` rather than `model_function()`? + +--- + + Code + bag_tree %>% set_args(boop = "bop") + Condition + Error in `set_args()`: + ! `set_args()` expected a model specification to be supplied to the `object` argument, but received a(n) `function` object. + i Did you mistakenly pass `model_function` rather than `model_function()`? + +--- + + Code + 1L %>% set_args(mode = "classification") + Condition + Error in `set_args()`: + ! `set_args()` expected a model specification to be supplied to the `object` argument, but received a(n) `integer` object. + +--- + + Code + bag_tree %>% set_mode("classification") + Condition + Error in `set_mode()`: + ! `set_mode()` expected a model specification to be supplied to the `object` argument, but received a(n) `function` object. + i Did you mistakenly pass `model_function` rather than `model_function()`? diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index 19234031d..c580ac0fb 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -19,6 +19,6 @@ Code set_engine(mtcars, "rpart") Condition - Error in `UseMethod()`: - ! no applicable method for 'set_engine' applied to an object of class "data.frame" + Error in `set_engine()`: + ! `set_engine()` expected a model specification to be supplied to the `object` argument, but received a(n) `data.frame` object. diff --git a/tests/testthat/test_args_and_modes.R b/tests/testthat/test_args_and_modes.R index 5d42558bc..4a8f51038 100644 --- a/tests/testthat/test_args_and_modes.R +++ b/tests/testthat/test_args_and_modes.R @@ -109,5 +109,28 @@ test_that("set_* functions error when input isn't model_spec", { expect_snapshot(error = TRUE, set_args(mtcars, blah = "blah") ) + + expect_snapshot(error = TRUE, + bag_tree %>% set_mode("classification") + ) + + expect_snapshot(error = TRUE, + bag_tree %>% set_engine("rpart") + ) + + expect_snapshot(error = TRUE, + bag_tree %>% set_args(boop = "bop") + ) + + # won't raise "info" part of error if not a parsnip-namespaced function + # not a function + expect_snapshot(error = TRUE, + 1L %>% set_args(mode = "classification") + ) + + # not from parsnip + expect_snapshot(error = TRUE, + bag_tree %>% set_mode("classification") + ) })