Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: parsnip
Version: 0.0.0.9004
Title: A Common API to Modeling and analysis Functions
Version: 0.0.0.9005
Title: A Common API to Modeling and Analysis Functions
Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. R, spark, stan, etc).
Authors@R: c(
person("Max", "Kuhn", , "max@rstudio.com", c("aut", "cre")),
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ export(predict_raw)
export(predict_raw.model_fit)
export(rand_forest)
export(set_args)
export(set_engine)
export(set_mode)
export(show_call)
export(surv_reg)
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# parsnip 0.0.0.9005

* The engine, and any associated arguments, are not specified using `set_engine`. There is no `engine` argument


# parsnip 0.0.0.9004

* Arguments to modeling functions are now captured as quosures.
Expand Down
15 changes: 11 additions & 4 deletions R/arguments.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ prune_arg_list <- function(x, whitelist = NULL, modified = character(0)) {
x
}

check_others <- function(args, obj, core_args) {
check_eng_args <- function(args, obj, core_args) {
# Make sure that we are not trying to modify an argument that
# is explicitly protected in the method metadata or arg_key
protected_args <- unique(c(obj$protect, core_args))
Expand Down Expand Up @@ -95,10 +95,17 @@ set_args <- function(object, ...) {
if (any(main_args == i)) {
object$args[[i]] <- the_dots[[i]]
} else {
object$others[[i]] <- the_dots[[i]]
object$eng_args[[i]] <- the_dots[[i]]
}
}
object
new_model_spec(
cls = class(object)[1],
args = object$args,
eng_args = object$eng_args,
mode = object$mode,
method = NULL,
engine = object$engine
)
}

#' @rdname set_args
Expand Down Expand Up @@ -130,6 +137,6 @@ maybe_eval <- function(x) {

eval_args <- function(spec, ...) {
spec$args <- purrr::map(spec$args, maybe_eval)
spec$others <- purrr::map(spec$others, maybe_eval)
spec$eng_args <- purrr::map(spec$eng_args, maybe_eval)
spec
}
75 changes: 30 additions & 45 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#' }
#' These arguments are converted to their specific names at the
#' time that the model is fit. Other options and argument can be
#' set using the `...` slot. If left to their defaults
#' set using the `set_engine` function. If left to their defaults
#' here (`NULL`), the values are taken from the underlying model
#' functions. If parameters need to be modified, `update` can be used
#' in lieu of recreating the object from scratch.
Expand All @@ -46,11 +46,6 @@
#' @param sample_size An number for the number (or proportion) of data that is
#' exposed to the fitting routine. For `xgboost`, the sampling is done at at
#' each iteration while `C5.0` samples once during traning.
#' @param ... Other arguments to pass to the specific engine's
#' model fit function (see the Engine Details section below). This
#' should not include arguments defined by the main parameters to
#' this function. For the `update` function, the ellipses can
#' contain the primary arguments or any others.
#' @details
#' The data given to the function are not saved and are only used
#' to determine the _mode_ of the model. For `boost_tree`, the
Expand All @@ -63,17 +58,12 @@
#' \item \pkg{Spark}: `"spark"`
#' }
#'
#' Main parameter arguments (and those in `...`) can avoid
#' evaluation until the underlying function is executed by wrapping the
#' argument in [rlang::expr()] (e.g. `mtry = expr(floor(sqrt(p)))`).
#'
#'
#' @section Engine Details:
#'
#' Engines may have pre-set default arguments when executing the
#' model fit call. These can be changed by using the `...`
#' argument to pass in the preferred values. For this type of
#' model, the template of the fit calls are:
#' model fit call. For this type of model, the template of the
#' fit calls are:
#'
#' \pkg{xgboost} classification
#'
Expand Down Expand Up @@ -109,7 +99,7 @@
#' reloaded and reattached to the `parsnip` object.
#'
#' @importFrom purrr map_lgl
#' @seealso [varying()], [fit()]
#' @seealso [varying()], [fit()], [set_engine()]
#' @examples
#' boost_tree(mode = "classification", trees = 20)
#' # Parameters can be represented by a placeholder:
Expand All @@ -121,11 +111,7 @@ boost_tree <-
mtry = NULL, trees = NULL, min_n = NULL,
tree_depth = NULL, learn_rate = NULL,
loss_reduction = NULL,
sample_size = NULL,
...) {

others <- enquos(...)

sample_size = NULL) {
args <- list(
mtry = enquo(mtry),
trees = enquo(trees),
Expand All @@ -136,18 +122,14 @@ boost_tree <-
sample_size = enquo(sample_size)
)

if (!(mode %in% boost_tree_modes))
stop("`mode` should be one of: ",
paste0("'", boost_tree_modes, "'", collapse = ", "),
call. = FALSE)

no_value <- !vapply(others, null_value, logical(1))
others <- others[no_value]

out <- list(args = args, others = others,
mode = mode, method = NULL, engine = NULL)
class(out) <- make_classes("boost_tree")
out
new_model_spec(
"boost_tree",
args,
eng_args = NULL,
mode,
method = NULL,
engine = NULL
)
}

#' @export
Expand All @@ -167,6 +149,7 @@ print.boost_tree <- function(x, ...) {
#' @export
#' @inheritParams boost_tree
#' @param object A boosted tree model specification.
#' @param ... Not used for `update`.
#' @param fresh A logical for whether the arguments should be
#' modified in-place of or replaced wholesale.
#' @return An updated model specification.
Expand All @@ -183,10 +166,8 @@ update.boost_tree <-
mtry = NULL, trees = NULL, min_n = NULL,
tree_depth = NULL, learn_rate = NULL,
loss_reduction = NULL, sample_size = NULL,
fresh = FALSE,
...) {

others <- enquos(...)
fresh = FALSE, ...) {
update_dot_check(...)

args <- list(
mtry = enquo(mtry),
Expand All @@ -209,23 +190,27 @@ update.boost_tree <-
object$args[names(args)] <- args
}

if (length(others) > 0) {
if (fresh)
object$others <- others
else
object$others[names(others)] <- others
}

object
new_model_spec(
"boost_tree",
args = object$args,
eng_args = object$eng_args,
mode = object$mode,
method = NULL,
engine = object$engine
)
}

# ------------------------------------------------------------------------------

#' @export
translate.boost_tree <- function(x, engine, ...) {
translate.boost_tree <- function(x, engine = x$engine, ...) {
if (is.null(engine)) {
message("Used `engine = 'xgboost'` for translation.")
engine <- "xgboost"
}
x <- translate.default(x, engine, ...)

if (x$engine == "spark") {
if (engine == "spark") {
if (x$mode == "unknown")
stop(
"For spark boosted trees models, the mode cannot be 'unknown' ",
Expand Down
2 changes: 1 addition & 1 deletion R/convert_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ convert_form_to_xy_fit <-function(
if (indicators) {
x <- model.matrix(mod_terms, mod_frame, contrasts)
} else {
# this still ignores -vars in formula ¯\_(ツ)_/¯
# this still ignores -vars in formula
x <- model.frame(mod_terms, data)
y_cols <- attr(mod_terms, "response")
if (length(y_cols) > 0)
Expand Down
10 changes: 5 additions & 5 deletions R/descriptors.R
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,11 @@ make_descr <- function(object) {
expr_main <- map_lgl(object$args, has_exprs)
else
expr_main <- FALSE
if (length(object$others) > 0)
expr_others <- map_lgl(object$others, has_exprs)
if (length(object$eng_args) > 0)
expr_eng_args <- map_lgl(object$eng_args, has_exprs)
else
expr_others <- FALSE
any(expr_main) | any(expr_others)
expr_eng_args <- FALSE
any(expr_main) | any(expr_eng_args)
}

# Locate descriptors -----------------------------------------------------------
Expand All @@ -331,7 +331,7 @@ make_descr <- function(object) {
requires_descrs <- function(object) {
any(c(
map_lgl(object$args, has_any_descrs),
map_lgl(object$others, has_any_descrs)
map_lgl(object$eng_args, has_any_descrs)
))
}

Expand Down
40 changes: 40 additions & 0 deletions R/engines.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,43 @@ check_installs <- function(x) {
}
}
}

#' Declare a computational engine and specific arguments
#'
#' `set_engine` is used to specify which package or system will be used
#' to fit the model, along with any arguments specific to that software.
#'
#' @param object A model specification.
#' @param engine A character string for the software that should
#' be used to fit the model. This is highly dependent on the type
#' of model (e.g. linear regression, random forest, etc.).
#' @param ... Any optional arguments associated with the chosen computational
#' engine. These are captured as quosures and can be `varying()`.
#' @return An updated model specification.
#' @examples
#' # First, set general arguments using the standardized names
#' mod <-
#' logistic_reg(mixture = 1/3) %>%
#' # now say how you want to fit the model and another other options
#' set_engine("glmnet", nlambda = 10)
#' translate(mod, engine = "glmnet")
#' @export
set_engine <- function(object, engine, ...) {
if (!inherits(object, "model_spec")) {
stop("`object` should have class 'model_spec'.", call. = FALSE)
}
if (!is.character(engine) | length(engine) != 1)
stop("`engine` should be a single character value.", call. = FALSE)

object$engine <- engine
object <- check_engine(object)

new_model_spec(
cls = class(object)[1],
args = object$args,
eng_args = enquos(...),
mode = object$mode,
method = NULL,
engine = object$engine
)
}
Loading