diff --git a/DESCRIPTION b/DESCRIPTION
index c7c333099..5215d6dc5 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
Package: parsnip
-Version: 0.0.0.9004
-Title: A Common API to Modeling and analysis Functions
+Version: 0.0.0.9005
+Title: A Common API to Modeling and Analysis Functions
Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. R, spark, stan, etc).
Authors@R: c(
person("Max", "Kuhn", , "max@rstudio.com", c("aut", "cre")),
diff --git a/NAMESPACE b/NAMESPACE
index 88b02f030..7d4305157 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -102,6 +102,7 @@ export(predict_raw)
export(predict_raw.model_fit)
export(rand_forest)
export(set_args)
+export(set_engine)
export(set_mode)
export(show_call)
export(surv_reg)
diff --git a/NEWS.md b/NEWS.md
index b8bfad6f6..2aacae55e 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,3 +1,8 @@
+# parsnip 0.0.0.9005
+
+* The engine, and any associated arguments, are now specified using `set_engine()`. There is no longer an `engine` argument to `fit()`.
+
+
# parsnip 0.0.0.9004
* Arguments to modeling functions are now captured as quosures.
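As a rough illustration of the change described above (not part of this patch; it assumes this development version of parsnip plus the ranger package, and `num.threads` is just one example of an engine-side option):

library(parsnip)

# the engine, and any engine-specific options, now live on the specification
rf_spec <-
  rand_forest(mode = "regression", trees = 500) %>%
  set_engine("ranger", num.threads = 2)

# fit() no longer takes an `engine` argument; it uses the one set above
fit(rf_spec, mpg ~ ., data = mtcars)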
diff --git a/R/arguments.R b/R/arguments.R
index 5c3f7d8f0..4db44be42 100644
--- a/R/arguments.R
+++ b/R/arguments.R
@@ -50,7 +50,7 @@ prune_arg_list <- function(x, whitelist = NULL, modified = character(0)) {
x
}
-check_others <- function(args, obj, core_args) {
+check_eng_args <- function(args, obj, core_args) {
# Make sure that we are not trying to modify an argument that
# is explicitly protected in the method metadata or arg_key
protected_args <- unique(c(obj$protect, core_args))
@@ -95,10 +95,17 @@ set_args <- function(object, ...) {
if (any(main_args == i)) {
object$args[[i]] <- the_dots[[i]]
} else {
- object$others[[i]] <- the_dots[[i]]
+ object$eng_args[[i]] <- the_dots[[i]]
}
}
- object
+ new_model_spec(
+ cls = class(object)[1],
+ args = object$args,
+ eng_args = object$eng_args,
+ mode = object$mode,
+ method = NULL,
+ engine = object$engine
+ )
}
#' @rdname set_args
@@ -130,6 +137,6 @@ maybe_eval <- function(x) {
eval_args <- function(spec, ...) {
spec$args <- purrr::map(spec$args, maybe_eval)
- spec$others <- purrr::map(spec$others, maybe_eval)
+ spec$eng_args <- purrr::map(spec$eng_args, maybe_eval)
spec
}
diff --git a/R/boost_tree.R b/R/boost_tree.R
index 61f2d0f0a..f196d4e8e 100644
--- a/R/boost_tree.R
+++ b/R/boost_tree.R
@@ -22,7 +22,7 @@
#' }
#' These arguments are converted to their specific names at the
#' time that the model is fit. Other options and arguments can be
-#' set using the `...` slot. If left to their defaults
+#' set using the `set_engine` function. If left to their defaults
#' here (`NULL`), the values are taken from the underlying model
#' functions. If parameters need to be modified, `update` can be used
#' in lieu of recreating the object from scratch.
@@ -46,11 +46,6 @@
#' @param sample_size A number for the number (or proportion) of data that is
#' exposed to the fitting routine. For `xgboost`, the sampling is done at
#' each iteration while `C5.0` samples once during training.
-#' @param ... Other arguments to pass to the specific engine's
-#' model fit function (see the Engine Details section below). This
-#' should not include arguments defined by the main parameters to
-#' this function. For the `update` function, the ellipses can
-#' contain the primary arguments or any others.
#' @details
#' The data given to the function are not saved and are only used
#' to determine the _mode_ of the model. For `boost_tree`, the
@@ -63,17 +58,12 @@
#' \item \pkg{Spark}: `"spark"`
#' }
#'
-#' Main parameter arguments (and those in `...`) can avoid
-#' evaluation until the underlying function is executed by wrapping the
-#' argument in [rlang::expr()] (e.g. `mtry = expr(floor(sqrt(p)))`).
-#'
#'
#' @section Engine Details:
#'
#' Engines may have pre-set default arguments when executing the
-#' model fit call. These can be changed by using the `...`
-#' argument to pass in the preferred values. For this type of
-#' model, the template of the fit calls are:
+#' model fit call. For this type of model, the template of the
+#' fit calls are:
#'
#' \pkg{xgboost} classification
#'
@@ -109,7 +99,7 @@
#' reloaded and reattached to the `parsnip` object.
#'
#' @importFrom purrr map_lgl
-#' @seealso [varying()], [fit()]
+#' @seealso [varying()], [fit()], [set_engine()]
#' @examples
#' boost_tree(mode = "classification", trees = 20)
#' # Parameters can be represented by a placeholder:
@@ -121,11 +111,7 @@ boost_tree <-
mtry = NULL, trees = NULL, min_n = NULL,
tree_depth = NULL, learn_rate = NULL,
loss_reduction = NULL,
- sample_size = NULL,
- ...) {
-
- others <- enquos(...)
-
+ sample_size = NULL) {
args <- list(
mtry = enquo(mtry),
trees = enquo(trees),
@@ -136,18 +122,14 @@ boost_tree <-
sample_size = enquo(sample_size)
)
- if (!(mode %in% boost_tree_modes))
- stop("`mode` should be one of: ",
- paste0("'", boost_tree_modes, "'", collapse = ", "),
- call. = FALSE)
-
- no_value <- !vapply(others, null_value, logical(1))
- others <- others[no_value]
-
- out <- list(args = args, others = others,
- mode = mode, method = NULL, engine = NULL)
- class(out) <- make_classes("boost_tree")
- out
+ new_model_spec(
+ "boost_tree",
+ args,
+ eng_args = NULL,
+ mode,
+ method = NULL,
+ engine = NULL
+ )
}
#' @export
@@ -167,6 +149,7 @@ print.boost_tree <- function(x, ...) {
#' @export
#' @inheritParams boost_tree
#' @param object A boosted tree model specification.
+#' @param ... Not used for `update`.
#' @param fresh A logical for whether the arguments should be
#' modified in-place or replaced wholesale.
#' @return An updated model specification.
@@ -183,10 +166,8 @@ update.boost_tree <-
mtry = NULL, trees = NULL, min_n = NULL,
tree_depth = NULL, learn_rate = NULL,
loss_reduction = NULL, sample_size = NULL,
- fresh = FALSE,
- ...) {
-
- others <- enquos(...)
+ fresh = FALSE, ...) {
+ update_dot_check(...)
args <- list(
mtry = enquo(mtry),
@@ -209,23 +190,27 @@ update.boost_tree <-
object$args[names(args)] <- args
}
- if (length(others) > 0) {
- if (fresh)
- object$others <- others
- else
- object$others[names(others)] <- others
- }
-
- object
+ new_model_spec(
+ "boost_tree",
+ args = object$args,
+ eng_args = object$eng_args,
+ mode = object$mode,
+ method = NULL,
+ engine = object$engine
+ )
}
# ------------------------------------------------------------------------------
#' @export
-translate.boost_tree <- function(x, engine, ...) {
+translate.boost_tree <- function(x, engine = x$engine, ...) {
+ if (is.null(engine)) {
+ message("Used `engine = 'xgboost'` for translation.")
+ engine <- "xgboost"
+ }
x <- translate.default(x, engine, ...)
- if (x$engine == "spark") {
+ if (engine == "spark") {
if (x$mode == "unknown")
stop(
"For spark boosted trees models, the mode cannot be 'unknown' ",
diff --git a/R/convert_data.R b/R/convert_data.R
index 50398db26..dbd6603cf 100644
--- a/R/convert_data.R
+++ b/R/convert_data.R
@@ -76,7 +76,7 @@ convert_form_to_xy_fit <-function(
if (indicators) {
x <- model.matrix(mod_terms, mod_frame, contrasts)
} else {
- # this still ignores -vars in formula ¯\_(ツ)_/¯
+ # this still ignores -vars in formula
x <- model.frame(mod_terms, data)
y_cols <- attr(mod_terms, "response")
if (length(y_cols) > 0)
diff --git a/R/descriptors.R b/R/descriptors.R
index 9ff68f0df..52b17aa69 100644
--- a/R/descriptors.R
+++ b/R/descriptors.R
@@ -318,11 +318,11 @@ make_descr <- function(object) {
expr_main <- map_lgl(object$args, has_exprs)
else
expr_main <- FALSE
- if (length(object$others) > 0)
- expr_others <- map_lgl(object$others, has_exprs)
+ if (length(object$eng_args) > 0)
+ expr_eng_args <- map_lgl(object$eng_args, has_exprs)
else
- expr_others <- FALSE
- any(expr_main) | any(expr_others)
+ expr_eng_args <- FALSE
+ any(expr_main) | any(expr_eng_args)
}
# Locate descriptors -----------------------------------------------------------
@@ -331,7 +331,7 @@ make_descr <- function(object) {
requires_descrs <- function(object) {
any(c(
map_lgl(object$args, has_any_descrs),
- map_lgl(object$others, has_any_descrs)
+ map_lgl(object$eng_args, has_any_descrs)
))
}
diff --git a/R/engines.R b/R/engines.R
index 013a1ba55..28a15ac5e 100644
--- a/R/engines.R
+++ b/R/engines.R
@@ -52,3 +52,43 @@ check_installs <- function(x) {
}
}
}
+
+#' Declare a computational engine and specific arguments
+#'
+#' `set_engine` is used to specify which package or system will be used
+#' to fit the model, along with any arguments specific to that software.
+#'
+#' @param object A model specification.
+#' @param engine A character string for the software that should
+#' be used to fit the model. This is highly dependent on the type
+#' of model (e.g. linear regression, random forest, etc.).
+#' @param ... Any optional arguments associated with the chosen computational
+#' engine. These are captured as quosures and can be `varying()`.
+#' @return An updated model specification.
+#' @examples
+#' # First, set general arguments using the standardized names
+#' mod <-
+#' logistic_reg(mixture = 1/3) %>%
+#' # now say how you want to fit the model and any other options
+#' set_engine("glmnet", nlambda = 10)
+#' translate(mod, engine = "glmnet")
+#' @export
+set_engine <- function(object, engine, ...) {
+ if (!inherits(object, "model_spec")) {
+ stop("`object` should have class 'model_spec'.", call. = FALSE)
+ }
+ if (!is.character(engine) | length(engine) != 1)
+ stop("`engine` should be a single character value.", call. = FALSE)
+
+ object$engine <- engine
+ object <- check_engine(object)
+
+ new_model_spec(
+ cls = class(object)[1],
+ args = object$args,
+ eng_args = enquos(...),
+ mode = object$mode,
+ method = NULL,
+ engine = object$engine
+ )
+}
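A minimal sketch of what the new set_engine() stores and checks (illustration only):

library(parsnip)

spec <- linear_reg(penalty = 0.1) %>% set_engine("glmnet", nlambda = 10)
spec$engine           # "glmnet"
names(spec$eng_args)  # "nlambda" -- held as a quosure until fit time

# the input checks added above are also user-facing:
try(set_engine(linear_reg(), engine = 2))
#> Error : `engine` should be a single character value.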
diff --git a/R/fit.R b/R/fit.R
index 4f240545a..ec201da36 100644
--- a/R/fit.R
+++ b/R/fit.R
@@ -9,7 +9,8 @@
#' code by substituting arguments, and execute the model fit
#' routine.
#'
-#' @param object An object of class `model_spec`
+#' @param object An object of class `model_spec` that has a chosen engine
+#' (via [set_engine()]).
#' @param formula An object of class "formula" (or one that can
#' be coerced to that class): a symbolic description of the model
#' to be fitted.
@@ -17,15 +18,11 @@
#' below). A data frame containing all relevant variables (e.g.
#' outcome(s), predictors, case weights, etc). Note: when needed, a
#' \emph{named argument} should be used.
-#' @param engine A character string for the software that should
-#' be used to fit the model. This is highly dependent on the type
-#' of model (e.g. linear regression, random forest, etc.).
#' @param control A named list with elements `verbosity` and
#' `catch`. See [fit_control()].
#' @param ... Not currently used; values passed here will be
#' ignored. Other options required to fit the model should be
-#' passed using the `others` argument in the original model
-#' specification.
+#' passed using `set_engine`.
#' @details `fit` and `fit_xy` substitute the current arguments in the model
#' specification into the computational engine's code, checks them
#' for validity, then fits the model using the data and the
@@ -49,21 +46,20 @@
#' library(dplyr)
#' data("lending_club")
#'
-#' lm_mod <- logistic_reg()
+#' lr_mod <- logistic_reg()
#'
-#' lm_mod <- logistic_reg()
+#' lr_mod <- logistic_reg()
#'
#' using_formula <-
-#' lm_mod %>%
-#' fit(Class ~ funded_amnt + int_rate,
-#' data = lending_club,
-#' engine = "glm")
+#' lr_mod %>%
+#' set_engine("glm") %>%
+#' fit(Class ~ funded_amnt + int_rate, data = lending_club)
#'
#' using_xy <-
-#' lm_mod %>%
+#' lr_mod %>%
+#' set_engine("glm") %>%
#' fit_xy(x = lending_club[, c("funded_amnt", "int_rate")],
-#' y = lending_club$Class,
-#' engine = "glm")
+#' y = lending_club$Class)
#'
#' using_formula
#' using_xy
@@ -83,6 +79,7 @@
#' The return value will also have a class related to the fitted model (e.g.
#' `"_glm"`) before the base class of `"model_fit"`.
#'
+#' @seealso [set_engine()], [fit_control()], `model_spec`, `model_fit`
#' @param x A matrix or data frame of predictors.
#' @param y A vector, matrix or data frame of outcome data.
#' @rdname fit
@@ -92,11 +89,13 @@ fit.model_spec <-
function(object,
formula = NULL,
data = NULL,
- engine = object$engine,
control = fit_control(),
...
) {
dots <- quos(...)
+ if (any(names(dots) == "engine"))
+ stop("Use `set_engine` to supply the engine.", call. = FALSE)
+
if (all(c("x", "y") %in% names(dots)))
stop("`fit.model_spec` is for the formula methods. Use `fit_xy` instead.",
call. = FALSE)
@@ -109,10 +108,8 @@ fit.model_spec <-
eval_env$formula <- formula
fit_interface <-
check_interface(eval_env$formula, eval_env$data, cl, object)
- object$engine <- engine
- object <- check_engine(object)
- if (engine == "spark" && !inherits(eval_env$data, "tbl_spark"))
+ if (object$engine == "spark" && !inherits(eval_env$data, "tbl_spark"))
stop(
"spark objects can only be used with the formula interface to `fit` ",
"with a spark data object.", call. = FALSE
@@ -122,7 +119,7 @@ fit.model_spec <-
object <- get_method(object, engine = object$engine)
check_installs(object) # TODO rewrite with pkgman
- # TODO Should probably just load the namespace
+
load_libs(object, control$verbosity < 2)
interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_")
@@ -178,20 +175,20 @@ fit_xy.model_spec <-
function(object,
x = NULL,
y = NULL,
- engine = object$engine,
control = fit_control(),
...
) {
+ dots <- quos(...)
+ if (any(names(dots) == "engine"))
+ stop("Use `set_engine` to supply the engine.", call. = FALSE)
cl <- match.call(expand.dots = TRUE)
eval_env <- rlang::env()
eval_env$x <- x
eval_env$y <- y
fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object)
- object$engine <- engine
- object <- check_engine(object)
- if (engine == "spark")
+ if (object$engine == "spark")
stop(
"spark objects can only be used with the formula interface to `fit` ",
"with a spark data object.", call. = FALSE
@@ -201,7 +198,7 @@ fit_xy.model_spec <-
object <- get_method(object, engine = object$engine)
check_installs(object) # TODO rewrite with pkgman
- # TODO Should probably just load the namespace
+
load_libs(object, control$verbosity < 2)
interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_")
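A sketch of the new workflow and the failure mode trapped above (mirrors the roxygen example in this file; assumes parsnip and dplyr are installed):

library(parsnip)
library(dplyr)
data("lending_club")

lr_mod <- logistic_reg() %>% set_engine("glm")
fit(lr_mod, Class ~ funded_amnt + int_rate, data = lending_club)

# passing the engine to fit() is now an error:
#   fit(lr_mod, Class ~ funded_amnt, data = lending_club, engine = "glm")
#   Error: Use `set_engine` to supply the engine.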
diff --git a/R/linear_reg.R b/R/linear_reg.R
index f2e37817f..e0805d288 100644
--- a/R/linear_reg.R
+++ b/R/linear_reg.R
@@ -12,7 +12,7 @@
#' }
#' These arguments are converted to their specific names at the
#' time that the model is fit. Other options and arguments can be
-#' set using the `...` slot. If left to their defaults
+#' set using `set_engine`. If left to their defaults
#' here (`NULL`), the values are taken from the underlying model
#' functions. If parameters need to be modified, `update` can be used
#' in lieu of recreating the object from scratch.
@@ -25,7 +25,6 @@
#' represents the proportion of regularization that is used for the
#' L2 penalty (i.e. weight decay, or ridge regression) versus L1
#' (the lasso) (`glmnet` and `spark` only).
-#'
#' @details
#' The data given to the function are not saved and are only used
#' to determine the _mode_ of the model. For `linear_reg`, the
@@ -42,8 +41,7 @@
#' @section Engine Details:
#'
#' Engines may have pre-set default arguments when executing the
-#' model fit call. These can be changed by using the `...`
-#' argument to pass in the preferred values. For this type of
+#' model fit call. For this type of
#' model, the template of the fit calls are:
#'
#' \pkg{lm}
@@ -92,7 +90,7 @@
#' separately saved to disk. In a new session, the object can be
#' reloaded and reattached to the `parsnip` object.
#'
-#' @seealso [varying()], [fit()]
+#' @seealso [varying()], [fit()], [set_engine()]
#' @examples
#' linear_reg()
#' # Parameters can be represented by a placeholder:
@@ -102,36 +100,21 @@
linear_reg <-
function(mode = "regression",
penalty = NULL,
- mixture = NULL,
- ...) {
-
- others <- enquos(...)
+ mixture = NULL) {
args <- list(
penalty = enquo(penalty),
mixture = enquo(mixture)
)
- if (!(mode %in% linear_reg_modes))
- stop(
- "`mode` should be one of: ",
- paste0("'", linear_reg_modes, "'", collapse = ", "),
- call. = FALSE
- )
-
- no_value <- !vapply(others, is.null, logical(1))
- others <- others[no_value]
-
- # write a constructor function
- out <- list(
+ new_model_spec(
+ "linear_reg",
args = args,
- others = others,
+ eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
)
- class(out) <- make_classes("linear_reg")
- out
}
#' @export
@@ -162,11 +145,8 @@ print.linear_reg <- function(x, ...) {
update.linear_reg <-
function(object,
penalty = NULL, mixture = NULL,
- fresh = FALSE,
- ...) {
-
- others <- enquos(...)
-
+ fresh = FALSE, ...) {
+ update_dot_check(...)
args <- list(
penalty = enquo(penalty),
mixture = enquo(mixture)
@@ -182,14 +162,14 @@ update.linear_reg <-
object$args[names(args)] <- args
}
- if (length(others) > 0) {
- if (fresh)
- object$others <- others
- else
- object$others[names(others)] <- others
- }
-
- object
+ new_model_spec(
+ "linear_reg",
+ args = object$args,
+ eng_args = object$eng_args,
+ mode = object$mode,
+ method = NULL,
+ engine = object$engine
+ )
}
# ------------------------------------------------------------------------------
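A small sketch of how update() behaves with the new constructor (illustration only):

library(parsnip)

spec <- linear_reg(penalty = 0.1, mixture = 1) %>% set_engine("glmnet")
update(spec, penalty = 0.01)                 # replaces only the penalty
update(spec, penalty = 0.01, fresh = TRUE)   # wipes the other main arguments

# engine arguments no longer go through update(); they belong in set_engine()
try(update(spec, nlambda = 10))
#> Error : Extra arguments will be ignored: `nlambda`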
diff --git a/R/logistic_reg.R b/R/logistic_reg.R
index 29fb60bf3..a0d67f0c1 100644
--- a/R/logistic_reg.R
+++ b/R/logistic_reg.R
@@ -12,7 +12,7 @@
#' }
#' These arguments are converted to their specific names at the
#' time that the model is fit. Other options and arguments can be
-#' set using the `...` slot. If left to their defaults
+#' set using `set_engine`. If left to their defaults
#' here (`NULL`), the values are taken from the underlying model
#' functions. If parameters need to be modified, `update` can be used
#' in lieu of recreating the object from scratch.
@@ -39,8 +39,7 @@
#' @section Engine Details:
#'
#' Engines may have pre-set default arguments when executing the
-#' model fit call. These can be changed by using the `...`
-#' argument to pass in the preferred values. For this type of
+#' model fit call. For this type of
#' model, the template of the fit calls are:
#'
#' \pkg{glm}
@@ -100,36 +99,21 @@
logistic_reg <-
function(mode = "classification",
penalty = NULL,
- mixture = NULL,
- ...) {
-
- others <- enquos(...)
+ mixture = NULL) {
args <- list(
penalty = enquo(penalty),
mixture = enquo(mixture)
)
- if (!(mode %in% logistic_reg_modes))
- stop(
- "`mode` should be one of: ",
- paste0("'", logistic_reg_modes, "'", collapse = ", "),
- call. = FALSE
- )
-
- no_value <- !vapply(others, is.null, logical(1))
- others <- others[no_value]
-
- # write a constructor function
- out <- list(
+ new_model_spec(
+ "logistic_reg",
args = args,
- others = others,
+ eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
)
- class(out) <- make_classes("logistic_reg")
- out
}
#' @export
@@ -160,11 +144,8 @@ print.logistic_reg <- function(x, ...) {
update.logistic_reg <-
function(object,
penalty = NULL, mixture = NULL,
- fresh = FALSE,
- ...) {
-
- others <- enquos(...)
-
+ fresh = FALSE, ...) {
+ update_dot_check(...)
args <- list(
penalty = enquo(penalty),
mixture = enquo(mixture)
@@ -180,14 +161,14 @@ update.logistic_reg <-
object$args[names(args)] <- args
}
- if (length(others) > 0) {
- if (fresh)
- object$others <- others
- else
- object$others[names(others)] <- others
- }
-
- object
+ new_model_spec(
+ "logistic_reg",
+ args = object$args,
+ eng_args = object$eng_args,
+ mode = object$mode,
+ method = NULL,
+ engine = object$engine
+ )
}
# ------------------------------------------------------------------------------
diff --git a/R/mars.R b/R/mars.R
index 6bc57b482..7835eb05b 100644
--- a/R/mars.R
+++ b/R/mars.R
@@ -17,7 +17,7 @@
#' }
#' These arguments are converted to their specific names at the
#' time that the model is fit. Other options and arguments can be
-#' set using the `...` slot. If left to their defaults
+#' set using `set_engine`. If left to their defaults
#' here (`NULL`), the values are taken from the underlying model
#' functions. If parameters need to be modified, `update` can be used
#' in lieu of recreating the object from scratch.
@@ -30,11 +30,7 @@
#' final model, including the intercept.
#' @param prod_degree The highest possible interaction degree.
#' @param prune_method The pruning method.
-#' @details Main parameter arguments (and those in `...`) can avoid
-#' evaluation until the underlying function is executed by wrapping the
-#' argument in [rlang::expr()].
-#'
-#' The model can be created using the `fit()` function using the
+#' @details The model can be created using the `fit()` function using the
#' following _engines_:
#' \itemize{
#' \item \pkg{R}: `"earth"`
@@ -43,8 +39,7 @@
#' @section Engine Details:
#'
#' Engines may have pre-set default arguments when executing the
-#' model fit call. These can be changed by using the `...`
-#' argument to pass in the preferred values. For this type of
+#' model fit call. For this type of
#' model, the template of the fit calls are:
#'
#' \pkg{earth} classification
@@ -67,10 +62,7 @@
mars <-
function(mode = "unknown",
- num_terms = NULL, prod_degree = NULL, prune_method = NULL,
- ...) {
-
- others <- enquos(...)
+ num_terms = NULL, prod_degree = NULL, prune_method = NULL) {
args <- list(
num_terms = enquo(num_terms),
@@ -78,18 +70,14 @@ mars <-
prune_method = enquo(prune_method)
)
- if (!(mode %in% mars_modes))
- stop("`mode` should be one of: ",
- paste0("'", mars_modes, "'", collapse = ", "),
- call. = FALSE)
-
- no_value <- !vapply(others, is.null, logical(1))
- others <- others[no_value]
-
- out <- list(args = args, others = others,
- mode = mode, method = NULL, engine = NULL)
- class(out) <- make_classes("mars")
- out
+ new_model_spec(
+ "mars",
+ args = args,
+ eng_args = NULL,
+ mode = mode,
+ method = NULL,
+ engine = NULL
+ )
}
#' @export
@@ -120,11 +108,8 @@ print.mars <- function(x, ...) {
update.mars <-
function(object,
num_terms = NULL, prod_degree = NULL, prune_method = NULL,
- fresh = FALSE,
- ...) {
-
- others <- enquos(...)
-
+ fresh = FALSE, ...) {
+ update_dot_check(...)
args <- list(
num_terms = enquo(num_terms),
prod_degree = enquo(prod_degree),
@@ -141,26 +126,29 @@ update.mars <-
object$args[names(args)] <- args
}
- if (length(others) > 0) {
- if (fresh)
- object$others <- others
- else
- object$others[names(others)] <- others
- }
-
- object
+ new_model_spec(
+ "mars",
+ args = object$args,
+ eng_args = object$eng_args,
+ mode = object$mode,
+ method = NULL,
+ engine = object$engine
+ )
}
# ------------------------------------------------------------------------------
#' @export
-translate.mars <- function(x, engine, ...) {
-
+translate.mars <- function(x, engine = x$engine, ...) {
+ if (is.null(engine)) {
+ message("Used `engine = 'earth'` for translation.")
+ engine <- "earth"
+ }
# If classification is being done, the `glm` options should be used. Check to
# see if it is there and, if not, add the default value.
if (x$mode == "classification") {
- if (!("glm" %in% names(x$others))) {
- x$others$glm <- quote(list(family = stats::binomial))
+ if (!("glm" %in% names(x$eng_args))) {
+ x$eng_args$glm <- quote(list(family = stats::binomial))
}
}
@@ -182,7 +170,7 @@ check_args.mars <- function(object) {
if (!is_varying(args$prune_method) &&
!is.null(args$prune_method) &&
- is.character(args$prune_method))
+ !is.character(args$prune_method))
stop("`prune_method` should be a single string value", call. = FALSE)
invisible(object)
@@ -223,11 +211,20 @@ multi_predict._earth <-
num_terms <- sort(num_terms)
+ # update.earth uses the values in the call so evaluate them if
+ # they are quosures
+ call_names <- names(object$fit$call)
+ call_names <- call_names[!(call_names %in% c("", "x", "y"))]
+ for (i in call_names) {
+ if (is_quosure(object$fit$call[[i]]))
+ object$fit$call[[i]] <- eval_tidy(object$fit$call[[i]])
+ }
+
msg <-
paste("Please use `keepxy = TRUE` as an option to enable submodel",
"predictions with `earth`.")
- if (any(names(object$spec$others) == "keepxy")) {
- if(!object$spec$others$keepxy)
+ if (any(names(object$fit$call) == "keepxy")) {
+ if(!isTRUE(object$fit$call$keepxy))
stop (msg, call. = FALSE)
} else
stop (msg, call. = FALSE)
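A sketch of the workflow the keepxy check above expects (not part of the patch; assumes the earth package is installed, and the multi_predict() argument names follow this version of the code):

library(parsnip)

mars_fit <-
  mars(mode = "regression", num_terms = 10) %>%
  set_engine("earth", keepxy = TRUE) %>%   # required for submodel predictions
  fit(mpg ~ ., data = mtcars)

multi_predict(mars_fit, new_data = head(mtcars), num_terms = 2:5)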
diff --git a/R/misc.R b/R/misc.R
index 5748cae92..5c80cca64 100644
--- a/R/misc.R
+++ b/R/misc.R
@@ -18,7 +18,7 @@ make_classes <- function(prefix) {
check_empty_ellipse <- function (...) {
terms <- quos(...)
if (!is_empty(terms))
- stop("Please pass other arguments to the model function via `others`", call. = FALSE)
+ stop("Please pass other arguments to the model function via `set_engine`", call. = FALSE)
terms
}
@@ -35,7 +35,6 @@ deparserizer <- function(x, limit = options()$width - 10) {
}
print_arg_list <- function(x, ...) {
- others <- c("name", "call", "expression")
atomic <- vapply(x, is.atomic, logical(1))
x2 <- x
x2[!atomic] <- lapply(x2[!atomic], deparserizer, ...)
@@ -59,10 +58,10 @@ model_printer <- function(x, ...) {
non_null_args <- map(non_null_args, convert_arg)
cat(print_arg_list(non_null_args), "\n", sep = "")
}
- if (length(x$others) > 0) {
+ if (length(x$eng_args) > 0) {
cat("Engine-Specific Arguments:\n")
- x$others <- map(x$others, convert_arg)
- cat(print_arg_list(x$others), "\n", sep = "")
+ x$eng_args <- map(x$eng_args, convert_arg)
+ cat(print_arg_list(x$eng_args), "\n", sep = "")
}
if (!is.null(x$engine)) {
cat("Computational engine:", x$engine, "\n\n")
@@ -190,3 +189,30 @@ names0 <- function (num, prefix = "x") {
ind <- gsub(" ", "0", ind)
paste0(prefix, ind)
}
+
+
+# ------------------------------------------------------------------------------
+
+update_dot_check <- function(...) {
+ dots <- enquos(...)
+ if (length(dots) > 0)
+ stop("Extra arguments will be ignored: ",
+ paste0("`", names(dots), "`", collapse = ", "),
+ call. = FALSE)
+ invisible(NULL)
+}
+
+# ------------------------------------------------------------------------------
+
+new_model_spec <- function(cls, args, eng_args, mode, method, engine) {
+ spec_modes <- get(paste0(cls, "_modes"))
+ if (!(mode %in% spec_modes))
+ stop("`mode` should be one of: ",
+ paste0("'", spec_modes, "'", collapse = ", "),
+ call. = FALSE)
+
+ out <- list(args = args, eng_args = eng_args,
+ mode = mode, method = method, engine = engine)
+ class(out) <- make_classes(cls)
+ out
+}
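As an illustration of the centralized mode check in new_model_spec() (not part of the patch), every specification function now fails the same way for an unsupported mode:

library(parsnip)

try(linear_reg(mode = "classification"))
#> Error : `mode` should be one of: ...   (the modes listed in linear_reg_modes)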
diff --git a/R/mlp.R b/R/mlp.R
index a323b89c2..8706a46b6 100644
--- a/R/mlp.R
+++ b/R/mlp.R
@@ -18,7 +18,7 @@
#'
#' These arguments are converted to their specific names at the
#' time that the model is fit. Other options and arguments can be
-#' set using the `...` slot. If left to their defaults
+#' set using `set_engine`. If left to their defaults
#' here (see above), the values are taken from the underlying model
#' functions. One exception is `hidden_units` when `nnet::nnet` is used; that
#' function's `size` argument has no default so a value of 5 units will be
@@ -51,18 +51,13 @@
#' \item \pkg{keras}: `"keras"`
#' }
#'
-#' Main parameter arguments (and those in `...`) can avoid
-#' evaluation until the underlying function is executed by wrapping the
-#' argument in [rlang::expr()] (e.g. `hidden_units = expr(num_preds * 2)`).
-#'
#' An error is thrown if both `penalty` and `dropout` are specified for
#' `keras` models.
#'
#' @section Engine Details:
#'
#' Engines may have pre-set default arguments when executing the
-#' model fit call. These can be changed by using the `...`
-#' argument to pass in the preferred values. For this type of
+#' model fit call. For this type of
#' model, the template of the fit calls are:
#'
#' \pkg{keras} classification
@@ -92,10 +87,7 @@
mlp <-
function(mode = "unknown",
hidden_units = NULL, penalty = NULL, dropout = NULL, epochs = NULL,
- activation = NULL,
- ...) {
-
- others <- enquos(...)
+ activation = NULL) {
args <- list(
hidden_units = enquo(hidden_units),
@@ -105,20 +97,14 @@ mlp <-
activation = enquo(activation)
)
- if (!(mode %in% mlp_modes))
- stop("`mode` should be one of: ",
- paste0("'", mlp_modes, "'", collapse = ", "),
- call. = FALSE)
-
- no_value <- !vapply(others, is.null, logical(1))
- others <- others[no_value]
-
- # write a constructor function
- out <- list(args = args, others = others,
- mode = mode, method = NULL, engine = NULL)
- # TODO: make_classes has wrong order; go from specific to general
- class(out) <- make_classes("mlp")
- out
+ new_model_spec(
+ "mlp",
+ args = args,
+ eng_args = NULL,
+ mode = mode,
+ method = NULL,
+ engine = NULL
+ )
}
#' @export
@@ -155,10 +141,8 @@ update.mlp <-
function(object,
hidden_units = NULL, penalty = NULL, dropout = NULL,
epochs = NULL, activation = NULL,
- fresh = FALSE,
- ...) {
- others <- enquos(...)
-
+ fresh = FALSE, ...) {
+ update_dot_check(...)
args <- list(
hidden_units = enquo(hidden_units),
penalty = enquo(penalty),
@@ -178,20 +162,24 @@ update.mlp <-
object$args[names(args)] <- args
}
- if (length(others) > 0) {
- if (fresh)
- object$others <- others
- else
- object$others[names(others)] <- others
- }
-
- object
+ new_model_spec(
+ "mlp",
+ args = object$args,
+ eng_args = object$eng_args,
+ mode = object$mode,
+ method = NULL,
+ engine = object$engine
+ )
}
# ------------------------------------------------------------------------------
#' @export
-translate.mlp <- function(x, engine, ...) {
+translate.mlp <- function(x, engine = x$engine, ...) {
+ if (is.null(engine)) {
+ message("Used `engine = 'keras'` for translation.")
+ engine <- "keras"
+ }
if (engine == "nnet") {
if(isTRUE(is.null(quo_get_expr(x$args$hidden_units)))) {
@@ -203,10 +191,10 @@ translate.mlp <- function(x, engine, ...) {
if (engine == "nnet") {
if (x$mode == "classification") {
- if (length(x$others) == 0 || !any(names(x$others) == "linout"))
+ if (length(x$eng_args) == 0 || !any(names(x$eng_args) == "linout"))
x$method$fit$args$linout <- FALSE
} else {
- if (length(x$others) == 0 || !any(names(x$others) == "linout"))
+ if (length(x$eng_args) == 0 || !any(names(x$eng_args) == "linout"))
x$method$fit$args$linout <- TRUE
}
}
@@ -219,10 +207,6 @@ check_args.mlp <- function(object) {
args <- lapply(object$args, rlang::eval_tidy)
- if (is.numeric(args$hidden_units))
- if (args$hidden_units < 2)
- stop("There must be at least two hidden units", call. = FALSE)
-
if (is.numeric(args$penalty))
if (args$penalty < 0)
stop("The amount of weight decay must be >= 0.", call. = FALSE)
diff --git a/R/model_object_docs.R b/R/model_object_docs.R
index ed563f788..af46bc0e8 100644
--- a/R/model_object_docs.R
+++ b/R/model_object_docs.R
@@ -175,10 +175,12 @@ NULL
#' @examples
#'
#' # Keep the `x` matrix if the data are not too big.
-#' spec_obj <- linear_reg(x = ifelse(.obs() < 500, TRUE, FALSE))
+#' spec_obj <-
+#' linear_reg() %>%
+#' set_engine("lm", x = ifelse(.obs() < 500, TRUE, FALSE))
#' spec_obj
#'
-#' fit_obj <- fit(spec_obj, mpg ~ ., data = mtcars, engine = "lm")
+#' fit_obj <- fit(spec_obj, mpg ~ ., data = mtcars)
#' fit_obj
#'
#' nrow(fit_obj$fit$x)
diff --git a/R/multinom_reg.R b/R/multinom_reg.R
index d9505cf57..6f6a41b43 100644
--- a/R/multinom_reg.R
+++ b/R/multinom_reg.R
@@ -12,7 +12,7 @@
#' }
#' These arguments are converted to their specific names at the
#' time that the model is fit. Other options and arguments can be
-#' set using the `...` slot. If left to their defaults
+#' set using `set_engine`. If left to their defaults
#' here (`NULL`), the values are taken from the underlying model
#' functions. If parameters need to be modified, `update` can be used
#' in lieu of recreating the object from scratch.
@@ -38,8 +38,7 @@
#' @section Engine Details:
#'
#' Engines may have pre-set default arguments when executing the
-#' model fit call. These can be changed by using the `...`
-#' argument to pass in the preferred values. For this type of
+#' model fit call. For this type of
#' model, the template of the fit calls are:
#'
#' \pkg{glmnet}
@@ -83,35 +82,21 @@
multinom_reg <-
function(mode = "classification",
penalty = NULL,
- mixture = NULL,
- ...) {
- others <- enquos(...)
+ mixture = NULL) {
args <- list(
penalty = enquo(penalty),
mixture = enquo(mixture)
)
- if (!(mode %in% multinom_reg_modes))
- stop(
- "`mode` should be one of: ",
- paste0("'", multinom_reg_modes, "'", collapse = ", "),
- call. = FALSE
- )
-
- no_value <- !vapply(others, is.null, logical(1))
- others <- others[no_value]
-
- # write a constructor function
- out <- list(
+ new_model_spec(
+ "multinom_reg",
args = args,
- others = others,
+ eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
)
- class(out) <- make_classes("multinom_reg")
- out
}
#' @export
@@ -142,10 +127,8 @@ print.multinom_reg <- function(x, ...) {
update.multinom_reg <-
function(object,
penalty = NULL, mixture = NULL,
- fresh = FALSE,
- ...) {
- others <- enquos(...)
-
+ fresh = FALSE, ...) {
+ update_dot_check(...)
args <- list(
penalty = enquo(penalty),
mixture = enquo(mixture)
@@ -161,14 +144,14 @@ update.multinom_reg <-
object$args[names(args)] <- args
}
- if (length(others) > 0) {
- if (fresh)
- object$others <- others
- else
- object$others[names(others)] <- others
- }
-
- object
+ new_model_spec(
+ "multinom_reg",
+ args = object$args,
+ eng_args = object$eng_args,
+ mode = object$mode,
+ method = NULL,
+ engine = object$engine
+ )
}
# ------------------------------------------------------------------------------
diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R
index 8b374b7f6..b85c16a9c 100644
--- a/R/nearest_neighbor.R
+++ b/R/nearest_neighbor.R
@@ -19,7 +19,7 @@
#' }
#' These arguments are converted to their specific names at the
#' time that the model is fit. Other options and arguments can be
-#' set using the `...` slot. If left to their defaults
+#' set using `set_engine`. If left to their defaults
#' here (`NULL`), the values are taken from the underlying model
#' functions. If parameters need to be modified, `update()` can be used
#' in lieu of recreating the object from scratch.
@@ -49,8 +49,7 @@
#' @section Engine Details:
#'
#' Engines may have pre-set default arguments when executing the
-#' model fit call. These can be changed by using the `...`
-#' argument to pass in the preferred values. For this type of
+#' model fit call. For this type of
#' model, the template of the fit calls are:
#'
#' \pkg{kknn} (classification or regression)
@@ -67,38 +66,27 @@
#' @seealso [varying()], [fit()]
#'
#' @examples
-#' nearest_neighbor()
+#' nearest_neighbor(neighbors = 11)
#'
#' @export
nearest_neighbor <- function(mode = "unknown",
neighbors = NULL,
weight_func = NULL,
- dist_power = NULL,
- ...) {
- others <- enquos(...)
-
+ dist_power = NULL) {
args <- list(
neighbors = enquo(neighbors),
weight_func = enquo(weight_func),
dist_power = enquo(dist_power)
)
- ## TODO: make a utility function here
- if (!(mode %in% nearest_neighbor_modes)) {
- stop("`mode` should be one of: ",
- paste0("'", nearest_neighbor_modes, "'", collapse = ", "),
- call. = FALSE)
- }
-
- no_value <- !vapply(others, is.null, logical(1))
- others <- others[no_value]
-
- # write a constructor function
- out <- list(args = args, others = others,
- mode = mode, method = NULL, engine = NULL)
- # TODO: make_classes has wrong order; go from specific to general
- class(out) <- make_classes("nearest_neighbor")
- out
+ new_model_spec(
+ "nearest_neighbor",
+ args = args,
+ eng_args = NULL,
+ mode = mode,
+ method = NULL,
+ engine = NULL
+ )
}
#' @export
@@ -121,11 +109,8 @@ update.nearest_neighbor <- function(object,
neighbors = NULL,
weight_func = NULL,
dist_power = NULL,
- fresh = FALSE,
- ...) {
-
- others <- enquos(...)
-
+ fresh = FALSE, ...) {
+ update_dot_check(...)
args <- list(
neighbors = enquo(neighbors),
weight_func = enquo(weight_func),
@@ -142,14 +127,14 @@ update.nearest_neighbor <- function(object,
object$args[names(args)] <- args
}
- if (length(others) > 0) {
- if (fresh)
- object$others <- others
- else
- object$others[names(others)] <- others
- }
-
- object
+ new_model_spec(
+ "nearest_neighbor",
+ args = object$args,
+ eng_args = object$eng_args,
+ mode = object$mode,
+ method = NULL,
+ engine = object$engine
+ )
}
diff --git a/R/predict.R b/R/predict.R
index 5dfd42823..ea7ea7149 100644
--- a/R/predict.R
+++ b/R/predict.R
@@ -47,7 +47,7 @@
#'
#' Quantile predictions return a tibble with a column `.pred`, which is
#' a list-column. Each list element contains a tibble with columns
-#' `.pred` and `.quantile` (and perhaps others).
+#' `.pred` and `.quantile` (and perhaps other columns).
#'
#' Using `type = "raw"` with `predict.model_fit` (or using
#' `predict_raw`) will return the unadulterated results of the
@@ -63,7 +63,8 @@
#'
#' lm_model <-
#' linear_reg() %>%
-#' fit(mpg ~ ., data = mtcars %>% slice(11:32), engine = "lm")
+#' set_engine("lm") %>%
+#' fit(mpg ~ ., data = mtcars %>% slice(11:32))
#'
#' pred_cars <-
#' mtcars %>%
diff --git a/R/rand_forest.R b/R/rand_forest.R
index 3d81e897b..4dc26ea5d 100644
--- a/R/rand_forest.R
+++ b/R/rand_forest.R
@@ -15,7 +15,7 @@
#' }
#' These arguments are converted to their specific names at the
#' time that the model is fit. Other options and arguments can be
-#' set using the `...` slot. If left to their defaults
+#' set using `set_engine`. If left to their defaults
#' here (`NULL`), the values are taken from the underlying model
#' functions. If parameters need to be modified, `update` can be used
#' in lieu of recreating the object from scratch.
@@ -38,15 +38,10 @@
#' \item \pkg{Spark}: `"spark"`
#' }
#'
-#' Main parameter arguments (and those in `...`) can avoid
-#' evaluation until the underlying function is executed by wrapping the
-#' argument in [rlang::expr()] (e.g. `mtry = expr(floor(sqrt(p)))`).
-#'
#' @section Engine Details:
#'
#' Engines may have pre-set default arguments when executing the
-#' model fit call. These can be changed by using the `...`
-#' argument to pass in the preferred values. For this type of
+#' model fit call. For this type of
#' model, the template of the fit calls are:
#'
#' \pkg{ranger} classification
@@ -100,10 +95,7 @@
#' @export
rand_forest <-
- function(mode = "unknown",
- mtry = NULL, trees = NULL, min_n = NULL, ...) {
-
- others <- enquos(...)
+ function(mode = "unknown", mtry = NULL, trees = NULL, min_n = NULL) {
args <- list(
mtry = enquo(mtry),
@@ -111,21 +103,14 @@ rand_forest <-
min_n = enquo(min_n)
)
- ## TODO: make a utility function here
- if (!(mode %in% rand_forest_modes))
- stop("`mode` should be one of: ",
- paste0("'", rand_forest_modes, "'", collapse = ", "),
- call. = FALSE)
-
- no_value <- !vapply(others, null_value, logical(1))
- others <- others[no_value]
-
- # write a constructor function
- out <- list(args = args, others = others,
- mode = mode, method = NULL, engine = NULL)
- # TODO: make_classes has wrong order; go from specific to general
- class(out) <- make_classes("rand_forest")
- out
+ new_model_spec(
+ "rand_forest",
+ args = args,
+ eng_args = NULL,
+ mode = mode,
+ method = NULL,
+ engine = NULL
+ )
}
#' @export
@@ -156,10 +141,8 @@ print.rand_forest <- function(x, ...) {
update.rand_forest <-
function(object,
mtry = NULL, trees = NULL, min_n = NULL,
- fresh = FALSE,
- ...) {
- others <- enquos(...)
-
+ fresh = FALSE, ...) {
+ update_dot_check(...)
args <- list(
mtry = enquo(mtry),
trees = enquo(trees),
@@ -177,20 +160,25 @@ update.rand_forest <-
object$args[names(args)] <- args
}
- if (length(others) > 0) {
- if (fresh)
- object$others <- others
- else
- object$others[names(others)] <- others
- }
-
- object
+ new_model_spec(
+ "rand_forest",
+ args = object$args,
+ eng_args = object$eng_args,
+ mode = object$mode,
+ method = NULL,
+ engine = object$engine
+ )
}
# ------------------------------------------------------------------------------
#' @export
-translate.rand_forest <- function(x, engine, ...) {
+translate.rand_forest <- function(x, engine = x$engine, ...) {
+ if (is.null(engine)) {
+ message("Used `engine = 'ranger'` for translation.")
+ engine <- "ranger"
+ }
+
x <- translate.default(x, engine, ...)
# slightly cleaner code using
@@ -217,7 +205,7 @@ translate.rand_forest <- function(x, engine, ...) {
}
# add checks to error trap or change things for this method
- if (x$engine == "ranger") {
+ if (engine == "ranger") {
if (any(names(arg_vals) == "importance"))
if (isTRUE(is.logical(quo_get_expr(arg_vals$importance))))
stop("`importance` should be a character value. See ?ranger::ranger.",
diff --git a/R/surv_reg.R b/R/surv_reg.R
index 29c3489ab..65c86b416 100644
--- a/R/surv_reg.R
+++ b/R/surv_reg.R
@@ -9,7 +9,7 @@
#' }
#' This argument is converted to its specific names at the
#' time that the model is fit. Other options and arguments can be
-#' set using the `...` slot. If left to its default
+#' set using `set_engine`. If left to its default
#' here (`NULL`), the value is taken from the underlying model
#' functions.
#'
@@ -42,8 +42,7 @@
#' @section Engine Details:
#'
#' Engines may have pre-set default arguments when executing the
-#' model fit call. These can be changed by using the `...`
-#' argument to pass in the preferred values. For this type of
+#' model fit call. For this type of
#' model, the template of the fit calls are:
#'
#' \pkg{flexsurv}
@@ -67,36 +66,20 @@
#' surv_reg(dist = varying())
#'
#' @export
-surv_reg <-
- function(mode = "regression",
- dist = NULL,
- ...) {
- others <- enquos(...)
+surv_reg <- function(mode = "regression", dist = NULL) {
args <- list(
dist = enquo(dist)
)
- if (!(mode %in% surv_reg_modes))
- stop(
- "`mode` should be one of: ",
- paste0("'", surv_reg_modes, "'", collapse = ", "),
- call. = FALSE
- )
-
- no_value <- !vapply(others, is.null, logical(1))
- others <- others[no_value]
-
- # write a constructor function
- out <- list(
+ new_model_spec(
+ "surv_reg",
args = args,
- others = others,
+ eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
)
- class(out) <- make_classes("surv_reg")
- out
}
#' @export
@@ -128,42 +111,41 @@ print.surv_reg <- function(x, ...) {
#' @method update surv_reg
#' @rdname surv_reg
#' @export
-update.surv_reg <-
- function(object,
- dist = NULL,
- fresh = FALSE,
- ...) {
- others <- enquos(...)
-
- args <- list(
- dist = enquo(dist)
- )
-
- if (fresh) {
- object$args <- args
- } else {
- null_args <- map_lgl(args, null_value)
- if (any(null_args))
- args <- args[!null_args]
- if (length(args) > 0)
- object$args[names(args)] <- args
- }
-
- if (length(others) > 0) {
- if (fresh)
- object$others <- others
- else
- object$others[names(others)] <- others
- }
-
- object
+update.surv_reg <- function(object, dist = NULL, fresh = FALSE, ...) {
+ update_dot_check(...)
+ args <- list(
+ dist = enquo(dist)
+ )
+
+ if (fresh) {
+ object$args <- args
+ } else {
+ null_args <- map_lgl(args, null_value)
+ if (any(null_args))
+ args <- args[!null_args]
+ if (length(args) > 0)
+ object$args[names(args)] <- args
}
+ new_model_spec(
+ "surv_reg",
+ args = object$args,
+ eng_args = object$eng_args,
+ mode = object$mode,
+ method = NULL,
+ engine = object$engine
+ )
+}
+
# ------------------------------------------------------------------------------
#' @export
-translate.surv_reg <- function(x, engine, ...) {
+translate.surv_reg <- function(x, engine = x$engine, ...) {
+ if (is.null(engine)) {
+ message("Used `engine = 'survreg'` for translation.")
+ engine <- "survreg"
+ }
x <- translate.default(x, engine, ...)
x
}
diff --git a/R/translate.R b/R/translate.R
index 7f1102c57..7c88e8c18 100644
--- a/R/translate.R
+++ b/R/translate.R
@@ -42,8 +42,10 @@ translate <- function (x, ...)
#' @importFrom utils getFromNamespace
#' @importFrom purrr list_modify
#' @export
-translate.default <- function(x, engine, ...) {
+translate.default <- function(x, engine = x$engine, ...) {
check_empty_ellipse(...)
+ if (is.null(engine))
+ stop("Please set an engine.", call. = FALSE)
x$engine <- engine
x <- check_engine(x)
@@ -60,7 +62,7 @@ translate.default <- function(x, engine, ...) {
# expression unless there are dots, warn if protected args are
# being altered
eng_arg_key <- arg_key[[x$engine]]
- x$others <- check_others(x$others, x$method$fit, eng_arg_key)
+ x$eng_args <- check_eng_args(x$eng_args, x$method$fit, eng_arg_key)
# keep only modified args
modifed_args <- !vapply(actual_args, null_value, lgl(1))
@@ -68,21 +70,20 @@ translate.default <- function(x, engine, ...) {
# look for defaults if not modified in other
if(length(x$method$fit$defaults) > 0) {
- in_other <- names(x$method$fit$defaults) %in% names(x$others)
+ in_other <- names(x$method$fit$defaults) %in% names(x$eng_args)
x$defaults <- x$method$fit$defaults[!in_other]
}
- # combine primary, others, and defaults
+ # combine primary, eng_args, and defaults
protected <- lapply(x$method$fit$protect, function(x) expr(missing_arg()))
names(protected) <- x$method$fit$protect
- x$method$fit$args <- c(protected, actual_args, x$others, x$defaults)
+ x$method$fit$args <- c(protected, actual_args, x$eng_args, x$defaults)
- # put in correct order
x
}
-get_method <- function(x, engine, ...) {
+get_method <- function(x, engine = x$engine, ...) {
check_empty_ellipse(...)
x$engine <- engine
x <- check_engine(x)
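A sketch of the new default-engine behavior (illustration only):

library(parsnip)

translate(boost_tree(mode = "regression"))
#> Used `engine = 'xgboost'` for translation.

# translate.default() itself, with no engine set anywhere, now stops with
# "Please set an engine."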
diff --git a/R/varying.R b/R/varying.R
index 49f50eb55..501faa444 100644
--- a/R/varying.R
+++ b/R/varying.R
@@ -23,18 +23,20 @@ varying <- function()
#'
#' rand_forest(mtry = varying()) %>% varying_args(id = "one arg")
#'
-#' rand_forest(others = list(sample.fraction = varying())) %>%
-#' varying_args(id = "only others")
+#' rand_forest() %>%
+#' set_engine("ranger", sample.fraction = varying()) %>%
+#' varying_args(id = "only eng_args")
#'
-#' rand_forest(
-#' others = list(
-#' strata = expr(Class),
+#' rand_forest() %>%
+#' set_engine(
+#' "ranger",
+#' strata = expr(Class),
#' sampsize = c(varying(), varying())
-#' )
-#' ) %>%
-#' varying_args(id = "add an expr")
+#' ) %>%
+#' varying_args(id = "add an expr")
#'
-#' rand_forest(others = list(classwt = c(class1 = 1, class2 = varying()))) %>%
+#' rand_forest() %>%
+#' set_engine("ranger", classwt = c(class1 = 1, class2 = varying())) %>%
#' varying_args(id = "list of values")
#'
#' @export
@@ -55,8 +57,8 @@ varying_args.model_spec <- function(x, id = NULL, ...) {
if (is.null(id))
id <- deparse(cl$x)
varying_args <- map(x$args, find_varying)
- varying_others <- map(x$others, find_varying)
- res <- c(varying_args, varying_others)
+ varying_eng_args <- map(x$eng_args, find_varying)
+ res <- c(varying_args, varying_eng_args)
res <- map_lgl(res, any)
tibble(
name = names(res),
diff --git a/_pkgdown.yml b/_pkgdown.yml
index 971b84395..cf5e1e4e3 100644
--- a/_pkgdown.yml
+++ b/_pkgdown.yml
@@ -36,6 +36,7 @@ reference:
- model_spec
- predict.model_fit
- set_args
+ - set_engine
- set_mode
- translate
- varying
diff --git a/docs/articles/articles/Classification.html b/docs/articles/articles/Classification.html
index ed83c3d28..6d3e67216 100644
--- a/docs/articles/articles/Classification.html
+++ b/docs/articles/articles/Classification.html
@@ -40,7 +40,7 @@
A single hidden layer neural network will be used to predict a person’s credit status. To do so, the columns of the predictor matrix should be numeric and on a common scale. recipes will be used to do so.
keras will be used to fit a model with 5 hidden units and a 10% dropout rate to regularize the model. At each training iteration (aka epoch), a random 20% of the data will be used to measure the cross-entropy of the model.
In parsnip, the predict function is only appropriate for numeric outcomes while predict_class and predict_classprob can be used for categorical outcomes.
The non-formula interface doesn't do anything to the predictors before giving them to the underlying model function. This particular model does not require indicator variables to be created prior to the model (note that the output shows "Number of independent variables: 5").
For regression models, the basic predict method can be used and returns a tibble with a column named .pred:
-Suppose that there was some feature in the randomForest package that we'd like to evaluate. To do so, the only part of the syntax that needs to change is the engine argument:
+Suppose that there was some feature in the randomForest package that we'd like to evaluate. To do so, the only part of the syntax that needs to change is the set_engine call:
Looking at the formula code that was printed out, one function uses the argument name ntree and the other uses num.trees. parsnip doesn't require you to know the specific names of the main arguments.
Now suppose that we want to modify the value of mtry based on the number of predictors in the data. Usually, the default value would be floor(sqrt(num_predictors)). To use a pure bagging model would require an mtry value equal to the total number of parameters. There may be cases where you may not know how many predictors are going to be present (perhaps due to the generation of indicator variables or a variable filter) so that might be difficult to know exactly.
-
-When the model is being fit by parsnip, data descriptors are made available. These attempt to let you know what you will have available when the model is fit. When a model object is created (say using rand_forest), the values of the arguments that you give it are immediately evaluated… unless you delay them. To delay the evaluation of any argument, you can use rlang::expr to make an expression.
+
+When the model is being fit by parsnip, data descriptors are made available. These attempt to let you know what you will have available when the model is fit. When a model object is created (say using rand_forest), the values of the arguments that you give it are immediately evaluated… unless you delay them. To delay the evaluation of any argument, you can use rlang::expr to make an expression.
Two relevant descriptors for what we are about to do are:
@@ -257,10 +257,10 @@
Since ranger won’t create indicator values, .preds() would be appropriate for using mtry for a bagging model.
For example, let’s use an expression with the .preds() descriptor to fit a bagging model:
If penalty were not specified, all of the lambda values would be computed.
To get the predictions for this specific value of lambda (aka penalty):
# First, get the processed version of the test set predictors:
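The descriptor-based bagging model described above, rewritten for the set_engine() interface (a sketch; it assumes the ranger package is installed and uses mtcars only to keep the example self-contained):

library(parsnip)

bagged <-
  rand_forest(mode = "regression", mtry = .preds()) %>%
  set_engine("ranger") %>%
  fit(mpg ~ ., data = mtcars)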
diff --git a/docs/articles/articles/Scratch.html b/docs/articles/articles/Scratch.html
index 537074874..51aa7fa45 100644
--- a/docs/articles/articles/Scratch.html
+++ b/docs/articles/articles/Scratch.html
@@ -40,7 +40,7 @@
The mode. If the model can do more than one mode, you might default this to “unknown”. In our case, since it is only a classification model, it makes sense to default it to that mode.
The argument names (sub_classes here). These should be defaulted to NULL.
-... is used to pass in other arguments to the underlying model fit functions.
This is pretty simple since the data are not exposed to this function.
@@ -236,7 +232,7 @@
func is the prediction function (in the same format as above). In many cases, packages have a predict method for their model’s class but this is typically not exported. In this case (and the example below), it is simple enough to make a generic call to predict with no associated package.
-args is a list of arguments to pass to the prediction function. These will mostly likely be wrapped in rlang::expr so that they are not evaluated when defining the method. For mda, the code would be predict(object, newdata, type = "class"). What is actually given to the function is the parsnip model fit object, which includes a sub-object called fit and this houses the mda model object. If the data need to be a matrix or data frame, you could also use new_data = quote(as.data.frame(new_data)) and so on.
+args is a list of arguments to pass to the prediction function. These will most likely be wrapped in rlang::expr so that they are not evaluated when defining the method. For mda, the code would be predict(object, newdata, type = "class"). What is actually given to the function is the parsnip model fit object, which includes a sub-object called fit and this houses the mda model object. If the data need to be a matrix or data frame, you could also use new_data = quote(as.data.frame(new_data)) and so on.
and so on. These can be accommodated via predict.model_fit using different type arguments.
However, there are some models (e.g. glmnet, plsr, Cubist, etc.) that can make predictions for different models from the same fitted model object. The regular predict method requires prediction from a single model but the multi_predict method can. The guideline is to always return the same number of rows as in new_data. This means that the .pred column is a list-column of tibbles.
For example, for a multinomial glmnet model, we leave penalty unspecified when fitting and get predictions on a sequence of values:
Note that I wrapped binomial inside of expr. If I didn’t, it would substitute the results of executing binomial inside of the expression (and that’s a mess). Using namespaces is a good idea here.
@@ -460,7 +460,7 @@
The translate function can be used to check values or set defaults once the model’s mode is known. To do this, you can create a model-specific S3 method that first calls the general method (translate.model_spec) and then makes modifications or conducts error traps.
For example, the ranger and randomForest package functions have arguments for calculating importance. One is a logical and the other is a string. Since this is likely to lead to a bunch of frustration and GH issues, we can put in a check:
However, there might be other arguments that you would like to change or allow to vary. These are accessible using the ... slot. This is a named list of arguments in the form of the underlying function being called. For example, ranger has an option to set the internal random number seed. To set this to a specific value:
However, there might be other arguments that you would like to change or allow to vary. These are accessible using set_engine. For example, ranger has an option to set the internal random number seed. To set this to a specific value:
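For instance (a sketch, not taken from the rendered page), the ranger seed mentioned above would now be supplied as:

library(parsnip)

rand_forest(mode = "regression", trees = 1000) %>%
  set_engine("ranger", seed = 63233) %>%
  translate()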
+Declare a computational engine and specific arguments (set_engine.Rd)
+
+set_engine is used to specify which package or system will be used
+to fit the model, along with any arguments specific to that software.
+
+set_engine(object, engine, ...)
+
+Arguments
+
+object: A model specification.
+
+engine: A character string for the software that should
+be used to fit the model. This is highly dependent on the type
+of model (e.g. linear regression, random forest, etc.).
+
+...: Any optional arguments associated with the chosen computational
+engine. These are captured as quosures and can be varying().
+
+Value
+
+An updated model specification.
+
+Examples
+
+# First, set general arguments using the standardized names
+mod <-
+  logistic_reg(mixture = 1/3) %>%
+  # now say how you want to fit the model and any other options
+  set_engine("glmnet", nlambda = 10)
+translate(mod, engine = "glmnet")
+#> Logistic Regression Model Specification (classification)
+#>
+#> Main Arguments:
+#> mixture = 1/3
+#>
+#> Engine-Specific Arguments:
+#> nlambda = 10
+#>
+#> Computational engine: glmnet
+#>
+#> Model fit template:
+#> glmnet::glmnet(x = missing_arg(), y = missing_arg(), weights = missing_arg(),
+#>     alpha = 1/3, nlambda = 10, family = "binomial")
diff --git a/docs/reference/surv_reg.html b/docs/reference/surv_reg.html
index 96ad45f21..f6760562f 100644
--- a/docs/reference/surv_reg.html
+++ b/docs/reference/surv_reg.html
@@ -38,7 +38,7 @@
dist: The probability distribution of the outcome.
This argument is converted to its specific names at the
time that the model is fit. Other options and arguments can be
-set using the ... slot. If left to its default
+set using set_engine. If left to its default
here (NULL), the value is taken from the underlying model
functions.
If parameters need to be modified, this function can be used
@@ -83,7 +83,7 @@
dist: The probability distribution of the outcome.
This argument is converted to its specific names at the
time that the model is fit. Other options and arguments can be
-set using the ... slot. If left to its default
+set using set_engine. If left to its default
here (NULL), the value is taken from the underlying model
functions.
If parameters need to be modified, this function can be used
@@ -154,7 +154,7 @@
General Interface for Parametric Survival Models
-surv_reg(mode = "regression", dist = NULL, ...)
+surv_reg(mode = "regression", dist = NULL)

# S3 method for surv_reg
update(object, dist = NULL, fresh = FALSE, ...)
@@ -171,14 +171,6 @@
dist
A character string for the outcome distribution. "weibull" is
the default.
-
-
-
...
-
Other arguments to pass to the specific engine's
-model fit function (see the Engine Details section below). This
-should not include arguments defined by the main parameters to
-this function. For the update function, the ellipses can
-contain the primary arguments or any others.
object
@@ -189,6 +181,10 @@
A logical for whether the arguments should be
modified in-place or replaced wholesale.
+
+
...
+
Not used for update.
+
Details
@@ -202,11 +198,32 @@
Details
Also, for the flexsurv::flexsurvfit engine, the typical
strata function cannot be used. To achieve the same effect,
the extra parameter roles can be used (as described above).
+
For surv_reg, the mode will always be "regression".
The model can be created using the fit() function using the
following engines:
R: "flexsurv", "survreg"
+
Engine Details
+
+
+
Engines may have pre-set default arguments when executing the
+model fit call. For this type of
+model, the template of the fit calls are:
-#> # A tibble: 4 x 4
-#>   name   varying id          type
-#>   <chr>  <lgl>   <chr>       <chr>
-#> 1 mtry   FALSE   only others model_spec
-#> 2 trees  FALSE   only others model_spec
-#> 3 min_n  FALSE   only others model_spec
-#> 4 others TRUE    only others model_spec
+#> # A tibble: 4 x 4
+#>   name            varying id            type
+#>   <chr>           <lgl>   <chr>         <chr>
+#> 1 mtry            FALSE   only eng_args model_spec
+#> 2 trees           FALSE   only eng_args model_spec
+#> 3 min_n           FALSE   only eng_args model_spec
+#> 4 sample.fraction TRUE    only eng_args model_spec
-#> # A tibble: 4 x 4
-#>   name   varying id          type
-#>   <chr>  <lgl>   <chr>       <chr>
-#> 1 mtry   FALSE   add an expr model_spec
-#> 2 trees  FALSE   add an expr model_spec
-#> 3 min_n  FALSE   add an expr model_spec
-#> 4 others FALSE   add an expr model_spec
+#> # A tibble: 5 x 4
+#>   name     varying id          type
+#>   <chr>    <lgl>   <chr>       <chr>
+#> 1 mtry     FALSE   add an expr model_spec
+#> 2 trees    FALSE   add an expr model_spec
+#> 3 min_n    FALSE   add an expr model_spec
+#> 4 strata   FALSE   add an expr model_spec
+#> 5 sampsize TRUE    add an expr model_spec
+rand_forest() %>%
+  set_engine("ranger", classwt = c(class1 = 1, class2 = varying())) %>%
+  varying_args(id = "list of values")
#> # A tibble: 4 x 4
-#>   name    varying id             type
-#>   <chr>   <lgl>   <chr>          <chr>
-#> 1 mtry    FALSE   list of values model_spec
-#> 2 trees   FALSE   list of values model_spec
-#> 3 min_n   FALSE   list of values model_spec
-#> 4 others  FALSE   list of values model_spec
+#>   name    varying id             type
+#>   <chr>   <lgl>   <chr>          <chr>
+#> 1 mtry    FALSE   list of values model_spec
+#> 2 trees   FALSE   list of values model_spec
+#> 3 min_n   FALSE   list of values model_spec
+#> 4 classwt TRUE    list of values model_spec