diff --git a/NAMESPACE b/NAMESPACE index 95291f6ef..b73110c9b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -57,6 +57,9 @@ S3method(req_pkgs,model_fit) S3method(req_pkgs,model_spec) S3method(required_pkgs,model_fit) S3method(required_pkgs,model_spec) +S3method(set_args,model_spec) +S3method(set_engine,model_spec) +S3method(set_mode,model_spec) S3method(tidy,"_LiblineaR") S3method(tidy,"_elnet") S3method(tidy,"_fishnet") diff --git a/R/arguments.R b/R/arguments.R index 5282e6909..5e6940942 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -49,6 +49,11 @@ check_eng_args <- function(args, obj, core_args) { #' #' @export set_args <- function(object, ...) { + UseMethod("set_args") +} + +#' @export +set_args.model_spec <- function(object, ...) { the_dots <- enquos(...) if (length(the_dots) == 0) rlang::abort("Please pass at least one named argument.") @@ -75,6 +80,11 @@ set_args <- function(object, ...) { #' @rdname set_args #' @export set_mode <- function(object, mode) { + UseMethod("set_mode") +} + +#' @export +set_mode.model_spec <- function(object, mode) { cls <- class(object)[1] if (rlang::is_missing(mode)) { spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes")) diff --git a/R/engines.R b/R/engines.R index ae16d2acf..8b299114e 100644 --- a/R/engines.R +++ b/R/engines.R @@ -105,10 +105,12 @@ load_libs <- function(x, quiet, attach = FALSE) { #' #' @export set_engine <- function(object, engine, ...) { + UseMethod("set_engine") +} + +#' @export +set_engine.model_spec <- function(object, engine, ...) { mod_type <- class(object)[1] - if (!inherits(object, "model_spec")) { - rlang::abort("`object` should have class 'model_spec'.") - } if (rlang::is_missing(engine)) { stop_missing_engine(mod_type) diff --git a/tests/testthat/_snaps/args_and_modes.md b/tests/testthat/_snaps/args_and_modes.md new file mode 100644 index 000000000..7f73f9342 --- /dev/null +++ b/tests/testthat/_snaps/args_and_modes.md @@ -0,0 +1,16 @@ +# set_* functions error when input isn't model_spec + + Code + set_mode(mtcars, "regression") + Condition + Error in `UseMethod()`: + ! no applicable method for 'set_mode' applied to an object of class "data.frame" + +--- + + Code + set_args(mtcars, blah = "blah") + Condition + Error in `UseMethod()`: + ! no applicable method for 'set_args' applied to an object of class "data.frame" + diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index 8854d515a..19234031d 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -14,3 +14,11 @@ Computational engine: rpart +# set_engine works as a generic + + Code + set_engine(mtcars, "rpart") + Condition + Error in `UseMethod()`: + ! no applicable method for 'set_engine' applied to an object of class "data.frame" + diff --git a/tests/testthat/test_args_and_modes.R b/tests/testthat/test_args_and_modes.R index f8a63f319..5d42558bc 100644 --- a/tests/testthat/test_args_and_modes.R +++ b/tests/testthat/test_args_and_modes.R @@ -101,4 +101,13 @@ test_that("unavailable modes for an engine and vice-versa", { ) }) +test_that("set_* functions error when input isn't model_spec", { + expect_snapshot(error = TRUE, + set_mode(mtcars, "regression") + ) + + expect_snapshot(error = TRUE, + set_args(mtcars, blah = "blah") + ) +}) diff --git a/tests/testthat/test_misc.R b/tests/testthat/test_misc.R index f8976096e..ed3a4745b 100644 --- a/tests/testthat/test_misc.R +++ b/tests/testthat/test_misc.R @@ -106,3 +106,9 @@ test_that('model type functions message informatively with unknown implementatio set_engine("rpart") ) }) + +test_that('set_engine works as a generic', { + expect_snapshot(error = TRUE, + set_engine(mtcars, "rpart") + ) +})