diff --git a/DESCRIPTION b/DESCRIPTION index c7c333099..5215d6dc5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip -Version: 0.0.0.9004 -Title: A Common API to Modeling and analysis Functions +Version: 0.0.0.9005 +Title: A Common API to Modeling and Analysis Functions Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. R, spark, stan, etc). Authors@R: c( person("Max", "Kuhn", , "max@rstudio.com", c("aut", "cre")), diff --git a/NAMESPACE b/NAMESPACE index 88b02f030..7d4305157 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -102,6 +102,7 @@ export(predict_raw) export(predict_raw.model_fit) export(rand_forest) export(set_args) +export(set_engine) export(set_mode) export(show_call) export(surv_reg) diff --git a/NEWS.md b/NEWS.md index b8bfad6f6..2aacae55e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,8 @@ +# parsnip 0.0.0.9005 + +* The engine, and any associated arguments, are now specified using `set_engine`. There is no longer an `engine` argument. + + # parsnip 0.0.0.9004 * Arguments to modeling functions are now captured as quosures. diff --git a/R/arguments.R b/R/arguments.R index 5c3f7d8f0..4db44be42 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -50,7 +50,7 @@ prune_arg_list <- function(x, whitelist = NULL, modified = character(0)) { x } -check_others <- function(args, obj, core_args) { +check_eng_args <- function(args, obj, core_args) { # Make sure that we are not trying to modify an argument that # is explicitly protected in the method metadata or arg_key protected_args <- unique(c(obj$protect, core_args)) @@ -95,10 +95,17 @@ set_args <- function(object, ...) { if (any(main_args == i)) { object$args[[i]] <- the_dots[[i]] } else { - object$others[[i]] <- the_dots[[i]] + object$eng_args[[i]] <- the_dots[[i]] } } - object + new_model_spec( + cls = class(object)[1], + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) } #' @rdname set_args @@ -130,6 +137,6 @@ maybe_eval <- function(x) { eval_args <- function(spec, ...) { spec$args <- purrr::map(spec$args, maybe_eval) - spec$others <- purrr::map(spec$others, maybe_eval) + spec$eng_args <- purrr::map(spec$eng_args, maybe_eval) spec } diff --git a/R/boost_tree.R b/R/boost_tree.R index 61f2d0f0a..f196d4e8e 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -22,7 +22,7 @@ #' } #' These arguments are converted to their specific names at the #' time that the model is fit. Other options and arguments can be -#' set using the `...` slot. If left to their defaults +#' set using the `set_engine` function. If left to their defaults #' here (`NULL`), the values are taken from the underlying model #' functions. If parameters need to be modified, `update` can be used #' in lieu of recreating the object from scratch. @@ -46,11 +46,6 @@ #' @param sample_size A number for the number (or proportion) of data that is #' exposed to the fitting routine. For `xgboost`, the sampling is done at #' each iteration while `C5.0` samples once during training. -#' @param ... Other arguments to pass to the specific engine's -#' model fit function (see the Engine Details section below). This -#' should not include arguments defined by the main parameters to -#' this function. For the `update` function, the ellipses can -#' contain the primary arguments or any others. 
#' @details #' The data given to the function are not saved and are only used #' to determine the _mode_ of the model. For `boost_tree`, the @@ -63,17 +58,12 @@ #' \item \pkg{Spark}: `"spark"` #' } #' -#' Main parameter arguments (and those in `...`) can avoid -#' evaluation until the underlying function is executed by wrapping the -#' argument in [rlang::expr()] (e.g. `mtry = expr(floor(sqrt(p)))`). -#' #' #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of -#' model, the template of the fit calls are: +#' model fit call. For this type of model, the template of the +#' fit calls are: #' #' \pkg{xgboost} classification #' @@ -109,7 +99,7 @@ #' reloaded and reattached to the `parsnip` object. #' #' @importFrom purrr map_lgl -#' @seealso [varying()], [fit()] +#' @seealso [varying()], [fit()], [set_engine()] #' @examples #' boost_tree(mode = "classification", trees = 20) #' # Parameters can be represented by a placeholder: @@ -121,11 +111,7 @@ boost_tree <- mtry = NULL, trees = NULL, min_n = NULL, tree_depth = NULL, learn_rate = NULL, loss_reduction = NULL, - sample_size = NULL, - ...) { - - others <- enquos(...) - + sample_size = NULL) { args <- list( mtry = enquo(mtry), trees = enquo(trees), @@ -136,18 +122,14 @@ boost_tree <- sample_size = enquo(sample_size) ) - if (!(mode %in% boost_tree_modes)) - stop("`mode` should be one of: ", - paste0("'", boost_tree_modes, "'", collapse = ", "), - call. = FALSE) - - no_value <- !vapply(others, null_value, logical(1)) - others <- others[no_value] - - out <- list(args = args, others = others, - mode = mode, method = NULL, engine = NULL) - class(out) <- make_classes("boost_tree") - out + new_model_spec( + "boost_tree", + args, + eng_args = NULL, + mode, + method = NULL, + engine = NULL + ) } #' @export @@ -167,6 +149,7 @@ print.boost_tree <- function(x, ...) { #' @export #' @inheritParams boost_tree #' @param object A boosted tree model specification. +#' @param ... Not used for `update`. #' @param fresh A logical for whether the arguments should be #' modified in-place of or replaced wholesale. #' @return An updated model specification. @@ -183,10 +166,8 @@ update.boost_tree <- mtry = NULL, trees = NULL, min_n = NULL, tree_depth = NULL, learn_rate = NULL, loss_reduction = NULL, sample_size = NULL, - fresh = FALSE, - ...) { - - others <- enquos(...) + fresh = FALSE, ...) { + update_dot_check(...) args <- list( mtry = enquo(mtry), @@ -209,23 +190,27 @@ update.boost_tree <- object$args[names(args)] <- args } - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - - object + new_model_spec( + "boost_tree", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) } # ------------------------------------------------------------------------------ #' @export -translate.boost_tree <- function(x, engine, ...) { +translate.boost_tree <- function(x, engine = x$engine, ...) { + if (is.null(engine)) { + message("Used `engine = 'xgboost'` for translation.") + engine <- "xgboost" + } x <- translate.default(x, engine, ...) 
- if (x$engine == "spark") { + if (engine == "spark") { if (x$mode == "unknown") stop( "For spark boosted trees models, the mode cannot be 'unknown' ", diff --git a/R/convert_data.R b/R/convert_data.R index 50398db26..dbd6603cf 100644 --- a/R/convert_data.R +++ b/R/convert_data.R @@ -76,7 +76,7 @@ convert_form_to_xy_fit <-function( if (indicators) { x <- model.matrix(mod_terms, mod_frame, contrasts) } else { - # this still ignores -vars in formula ¯\_(ツ)_/¯ + # this still ignores -vars in formula x <- model.frame(mod_terms, data) y_cols <- attr(mod_terms, "response") if (length(y_cols) > 0) diff --git a/R/descriptors.R b/R/descriptors.R index 9ff68f0df..52b17aa69 100644 --- a/R/descriptors.R +++ b/R/descriptors.R @@ -318,11 +318,11 @@ make_descr <- function(object) { expr_main <- map_lgl(object$args, has_exprs) else expr_main <- FALSE - if (length(object$others) > 0) - expr_others <- map_lgl(object$others, has_exprs) + if (length(object$eng_args) > 0) + expr_eng_args <- map_lgl(object$eng_args, has_exprs) else - expr_others <- FALSE - any(expr_main) | any(expr_others) + expr_eng_args <- FALSE + any(expr_main) | any(expr_eng_args) } # Locate descriptors ----------------------------------------------------------- @@ -331,7 +331,7 @@ make_descr <- function(object) { requires_descrs <- function(object) { any(c( map_lgl(object$args, has_any_descrs), - map_lgl(object$others, has_any_descrs) + map_lgl(object$eng_args, has_any_descrs) )) } diff --git a/R/engines.R b/R/engines.R index 013a1ba55..28a15ac5e 100644 --- a/R/engines.R +++ b/R/engines.R @@ -52,3 +52,43 @@ check_installs <- function(x) { } } } + +#' Declare a computational engine and specific arguments +#' +#' `set_engine` is used to specify which package or system will be used +#' to fit the model, along with any arguments specific to that software. +#' +#' @param object A model specification. +#' @param engine A character string for the software that should +#' be used to fit the model. This is highly dependent on the type +#' of model (e.g. linear regression, random forest, etc.). +#' @param ... Any optional arguments associated with the chosen computational +#' engine. These are captured as quosures and can be `varying()`. +#' @return An updated model specification. +#' @examples +#' # First, set general arguments using the standardized names +#' mod <- +#' logistic_reg(mixture = 1/3) %>% +#' # now say how you want to fit the model and any other options +#' set_engine("glmnet", nlambda = 10) +#' translate(mod, engine = "glmnet") +#' @export +set_engine <- function(object, engine, ...) { + if (!inherits(object, "model_spec")) { + stop("`object` should have class 'model_spec'.", call. = FALSE) + } + if (!is.character(engine) || length(engine) != 1) + stop("`engine` should be a single character value.", call. = FALSE) + + object$engine <- engine + object <- check_engine(object) + + new_model_spec( + cls = class(object)[1], + args = object$args, + eng_args = enquos(...), + mode = object$mode, + method = NULL, + engine = object$engine + ) +} diff --git a/R/fit.R b/R/fit.R index 4f240545a..ec201da36 100644 --- a/R/fit.R +++ b/R/fit.R @@ -9,7 +9,8 @@ #' code by substituting arguments, and execute the model fit #' routine. #' -#' @param object An object of class `model_spec` +#' @param object An object of class `model_spec` that has a chosen engine +#' (via [set_engine()]). #' @param formula An object of class "formula" (or one that can #' be coerced to that class): a symbolic description of the model #' to be fitted. 
@@ -17,15 +18,11 @@ #' below). A data frame containing all relevant variables (e.g. #' outcome(s), predictors, case weights, etc). Note: when needed, a #' \emph{named argument} should be used. -#' @param engine A character string for the software that should -#' be used to fit the model. This is highly dependent on the type -#' of model (e.g. linear regression, random forest, etc.). #' @param control A named list with elements `verbosity` and #' `catch`. See [fit_control()]. #' @param ... Not currently used; values passed here will be #' ignored. Other options required to fit the model should be -#' passed using the `others` argument in the original model -#' specification. +#' passed using `set_engine`. #' @details `fit` and `fit_xy` substitute the current arguments in the model #' specification into the computational engine's code, check them #' for validity, then fit the model using the data and the @@ -49,21 +46,20 @@ #' library(dplyr) #' data("lending_club") #' -#' lm_mod <- logistic_reg() +#' lr_mod <- logistic_reg() #' -#' lm_mod <- logistic_reg() #' #' using_formula <- -#' lm_mod %>% -#' fit(Class ~ funded_amnt + int_rate, -#' data = lending_club, -#' engine = "glm") +#' lr_mod %>% +#' set_engine("glm") %>% +#' fit(Class ~ funded_amnt + int_rate, data = lending_club) #' #' using_xy <- -#' lm_mod %>% +#' lr_mod %>% +#' set_engine("glm") %>% #' fit_xy(x = lending_club[, c("funded_amnt", "int_rate")], -#' y = lending_club$Class, -#' engine = "glm") +#' y = lending_club$Class) #' #' using_formula #' using_xy @@ -83,6 +79,7 @@ #' The return value will also have a class related to the fitted model (e.g. #' `"_glm"`) before the base class of `"model_fit"`. #' +#' @seealso [set_engine()], [fit_control()], `model_spec`, `model_fit` #' @param x A matrix or data frame of predictors. #' @param y A vector, matrix or data frame of outcome data. #' @rdname fit @@ -92,11 +89,13 @@ fit.model_spec <- function(object, formula = NULL, data = NULL, - engine = object$engine, control = fit_control(), ... ) { dots <- quos(...) + if (any(names(dots) == "engine")) + stop("Use `set_engine` to supply the engine.", call. = FALSE) + if (all(c("x", "y") %in% names(dots))) stop("`fit.model_spec` is for the formula methods. Use `fit_xy` instead.", call. = FALSE) @@ -109,10 +108,8 @@ fit.model_spec <- eval_env$formula <- formula fit_interface <- check_interface(eval_env$formula, eval_env$data, cl, object) - object$engine <- engine - object <- check_engine(object) - if (engine == "spark" && !inherits(eval_env$data, "tbl_spark")) + if (object$engine == "spark" && !inherits(eval_env$data, "tbl_spark")) stop( "spark objects can only be used with the formula interface to `fit` ", "with a spark data object.", call. = FALSE ) @@ -122,7 +119,7 @@ fit.model_spec <- object <- get_method(object, engine = object$engine) check_installs(object) # TODO rewrite with pkgman - # TODO Should probably just load the namespace + load_libs(object, control$verbosity < 2) interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_") @@ -178,20 +175,20 @@ fit_xy.model_spec <- function(object, x = NULL, y = NULL, - engine = object$engine, control = fit_control(), ... ) { + dots <- quos(...) + if (any(names(dots) == "engine")) + stop("Use `set_engine` to supply the engine.", call. 
= FALSE) cl <- match.call(expand.dots = TRUE) eval_env <- rlang::env() eval_env$x <- x eval_env$y <- y fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object) - object$engine <- engine - object <- check_engine(object) - if (engine == "spark") + if (object$engine == "spark") stop( "spark objects can only be used with the formula interface to `fit` ", "with a spark data object.", call. = FALSE @@ -201,7 +198,7 @@ fit_xy.model_spec <- object <- get_method(object, engine = object$engine) check_installs(object) # TODO rewrite with pkgman - # TODO Should probably just load the namespace + load_libs(object, control$verbosity < 2) interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_") diff --git a/R/linear_reg.R b/R/linear_reg.R index f2e37817f..e0805d288 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -12,7 +12,7 @@ #' } #' These arguments are converted to their specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to their defaults +#' set using `set_engine`. If left to their defaults #' here (`NULL`), the values are taken from the underlying model #' functions. If parameters need to be modified, `update` can be used #' in lieu of recreating the object from scratch. @@ -25,7 +25,6 @@ #' represents the proportion of regularization that is used for the #' L2 penalty (i.e. weight decay, or ridge regression) versus L1 #' (the lasso) (`glmnet` and `spark` only). -#' #' @details #' The data given to the function are not saved and are only used #' to determine the _mode_ of the model. For `linear_reg`, the @@ -42,8 +41,7 @@ #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of +#' model fit call. For this type of #' model, the template of the fit calls are: #' #' \pkg{lm} @@ -92,7 +90,7 @@ #' separately saved to disk. In a new session, the object can be #' reloaded and reattached to the `parsnip` object. #' -#' @seealso [varying()], [fit()] +#' @seealso [varying()], [fit()], [set_engine()] #' @examples #' linear_reg() #' # Parameters can be represented by a placeholder: @@ -102,36 +100,21 @@ linear_reg <- function(mode = "regression", penalty = NULL, - mixture = NULL, - ...) { - - others <- enquos(...) + mixture = NULL) { args <- list( penalty = enquo(penalty), mixture = enquo(mixture) ) - if (!(mode %in% linear_reg_modes)) - stop( - "`mode` should be one of: ", - paste0("'", linear_reg_modes, "'", collapse = ", "), - call. = FALSE - ) - - no_value <- !vapply(others, is.null, logical(1)) - others <- others[no_value] - - # write a constructor function - out <- list( + new_model_spec( + "linear_reg", args = args, - others = others, + eng_args = NULL, mode = mode, method = NULL, engine = NULL ) - class(out) <- make_classes("linear_reg") - out } #' @export @@ -162,11 +145,8 @@ print.linear_reg <- function(x, ...) { update.linear_reg <- function(object, penalty = NULL, mixture = NULL, - fresh = FALSE, - ...) { - - others <- enquos(...) - + fresh = FALSE, ...) { + update_dot_check(...) 
args <- list( penalty = enquo(penalty), mixture = enquo(mixture) @@ -182,14 +162,14 @@ update.linear_reg <- object$args[names(args)] <- args } - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - - object + new_model_spec( + "linear_reg", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) } # ------------------------------------------------------------------------------ diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 29fb60bf3..a0d67f0c1 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -12,7 +12,7 @@ #' } #' These arguments are converted to their specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to their defaults +#' set using `set_engine`. If left to their defaults #' here (`NULL`), the values are taken from the underlying model #' functions. If parameters need to be modified, `update` can be used #' in lieu of recreating the object from scratch. @@ -39,8 +39,7 @@ #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of +#' model fit call. For this type of #' model, the template of the fit calls are: #' #' \pkg{glm} @@ -100,36 +99,21 @@ logistic_reg <- function(mode = "classification", penalty = NULL, - mixture = NULL, - ...) { - - others <- enquos(...) + mixture = NULL) { args <- list( penalty = enquo(penalty), mixture = enquo(mixture) ) - if (!(mode %in% logistic_reg_modes)) - stop( - "`mode` should be one of: ", - paste0("'", logistic_reg_modes, "'", collapse = ", "), - call. = FALSE - ) - - no_value <- !vapply(others, is.null, logical(1)) - others <- others[no_value] - - # write a constructor function - out <- list( + new_model_spec( + "logistic_reg", args = args, - others = others, + eng_args = NULL, mode = mode, method = NULL, engine = NULL ) - class(out) <- make_classes("logistic_reg") - out } #' @export @@ -160,11 +144,8 @@ print.logistic_reg <- function(x, ...) { update.logistic_reg <- function(object, penalty = NULL, mixture = NULL, - fresh = FALSE, - ...) { - - others <- enquos(...) - + fresh = FALSE, ...) { + update_dot_check(...) args <- list( penalty = enquo(penalty), mixture = enquo(mixture) @@ -180,14 +161,14 @@ update.logistic_reg <- object$args[names(args)] <- args } - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - - object + new_model_spec( + "logistic_reg", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) } # ------------------------------------------------------------------------------ diff --git a/R/mars.R b/R/mars.R index 6bc57b482..7835eb05b 100644 --- a/R/mars.R +++ b/R/mars.R @@ -17,7 +17,7 @@ #' } #' These arguments are converted to their specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to their defaults +#' set using `set_engine`. If left to their defaults #' here (`NULL`), the values are taken from the underlying model #' functions. If parameters need to be modified, `update` can be used #' in lieu of recreating the object from scratch. @@ -30,11 +30,7 @@ #' final model, including the intercept. #' @param prod_degree The highest possible interaction degree. 
#' @param prune_method The pruning method. -#' @details Main parameter arguments (and those in `...`) can avoid -#' evaluation until the underlying function is executed by wrapping the -#' argument in [rlang::expr()]. -#' -#' The model can be created using the `fit()` function using the +#' @details The model can be created using the `fit()` function using the #' following _engines_: #' \itemize{ #' \item \pkg{R}: `"earth"` @@ -43,8 +39,7 @@ #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of +#' model fit call. For this type of #' model, the template of the fit calls are: #' #' \pkg{earth} classification @@ -67,10 +62,7 @@ mars <- function(mode = "unknown", - num_terms = NULL, prod_degree = NULL, prune_method = NULL, - ...) { - - others <- enquos(...) + num_terms = NULL, prod_degree = NULL, prune_method = NULL) { args <- list( num_terms = enquo(num_terms), @@ -78,18 +70,14 @@ mars <- prune_method = enquo(prune_method) ) - if (!(mode %in% mars_modes)) - stop("`mode` should be one of: ", - paste0("'", mars_modes, "'", collapse = ", "), - call. = FALSE) - - no_value <- !vapply(others, is.null, logical(1)) - others <- others[no_value] - - out <- list(args = args, others = others, - mode = mode, method = NULL, engine = NULL) - class(out) <- make_classes("mars") - out + new_model_spec( + "mars", + args = args, + eng_args = NULL, + mode = mode, + method = NULL, + engine = NULL + ) } #' @export @@ -120,11 +108,8 @@ print.mars <- function(x, ...) { update.mars <- function(object, num_terms = NULL, prod_degree = NULL, prune_method = NULL, - fresh = FALSE, - ...) { - - others <- enquos(...) - + fresh = FALSE, ...) { + update_dot_check(...) args <- list( num_terms = enquo(num_terms), prod_degree = enquo(prod_degree), @@ -141,26 +126,29 @@ update.mars <- object$args[names(args)] <- args } - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - - object + new_model_spec( + "mars", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) } # ------------------------------------------------------------------------------ #' @export -translate.mars <- function(x, engine, ...) { - +translate.mars <- function(x, engine = x$engine, ...) { + if (is.null(engine)) { + message("Used `engine = 'earth'` for translation.") + engine <- "earth" + } # If classification is being done, the `glm` options should be used. Check to # see if it is there and, if not, add the default value. if (x$mode == "classification") { - if (!("glm" %in% names(x$others))) { - x$others$glm <- quote(list(family = stats::binomial)) + if (!("glm" %in% names(x$eng_args))) { + x$eng_args$glm <- quote(list(family = stats::binomial)) } } @@ -182,7 +170,7 @@ check_args.mars <- function(object) { if (!is_varying(args$prune_method) && !is.null(args$prune_method) && - is.character(args$prune_method)) + !is.character(args$prune_method)) stop("`prune_method` should be a single string value", call. 
= FALSE) invisible(object) @@ -223,11 +211,20 @@ multi_predict._earth <- num_terms <- sort(num_terms) + # update.earth uses the values in the call so evaluate them if + # they are quosures + call_names <- names(object$fit$call) + call_names <- call_names[!(call_names %in% c("", "x", "y"))] + for (i in call_names) { + if (is_quosure(object$fit$call[[i]])) + object$fit$call[[i]] <- eval_tidy(object$fit$call[[i]]) + } + msg <- paste("Please use `keepxy = TRUE` as an option to enable submodel", "predictions with `earth`.") - if (any(names(object$spec$others) == "keepxy")) { - if(!object$spec$others$keepxy) + if (any(names(object$fit$call) == "keepxy")) { + if(!isTRUE(object$fit$call$keepxy)) stop (msg, call. = FALSE) } else stop (msg, call. = FALSE) diff --git a/R/misc.R b/R/misc.R index 5748cae92..5c80cca64 100644 --- a/R/misc.R +++ b/R/misc.R @@ -18,7 +18,7 @@ make_classes <- function(prefix) { check_empty_ellipse <- function (...) { terms <- quos(...) if (!is_empty(terms)) - stop("Please pass other arguments to the model function via `others`", call. = FALSE) + stop("Please pass other arguments to the model function via `set_engine`", call. = FALSE) terms } @@ -35,7 +35,6 @@ deparserizer <- function(x, limit = options()$width - 10) { } print_arg_list <- function(x, ...) { - others <- c("name", "call", "expression") atomic <- vapply(x, is.atomic, logical(1)) x2 <- x x2[!atomic] <- lapply(x2[!atomic], deparserizer, ...) @@ -59,10 +58,10 @@ model_printer <- function(x, ...) { non_null_args <- map(non_null_args, convert_arg) cat(print_arg_list(non_null_args), "\n", sep = "") } - if (length(x$others) > 0) { + if (length(x$eng_args) > 0) { cat("Engine-Specific Arguments:\n") - x$others <- map(x$others, convert_arg) - cat(print_arg_list(x$others), "\n", sep = "") + x$eng_args <- map(x$eng_args, convert_arg) + cat(print_arg_list(x$eng_args), "\n", sep = "") } if (!is.null(x$engine)) { cat("Computational engine:", x$engine, "\n\n") @@ -190,3 +189,30 @@ names0 <- function (num, prefix = "x") { ind <- gsub(" ", "0", ind) paste0(prefix, ind) } + + +# ------------------------------------------------------------------------------ + +update_dot_check <- function(...) { + dots <- enquos(...) + if (length(dots) > 0) + stop("Extra arguments will be ignored: ", + paste0("`", names(dots), "`", collapse = ", "), + call. = FALSE) + invisible(NULL) +} + +# ------------------------------------------------------------------------------ + +new_model_spec <- function(cls, args, eng_args, mode, method, engine) { + spec_modes <- get(paste0(cls, "_modes")) + if (!(mode %in% spec_modes)) + stop("`mode` should be one of: ", + paste0("'", spec_modes, "'", collapse = ", "), + call. = FALSE) + + out <- list(args = args, eng_args = eng_args, + mode = mode, method = method, engine = engine) + class(out) <- make_classes(cls) + out +} diff --git a/R/mlp.R b/R/mlp.R index a323b89c2..8706a46b6 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -18,7 +18,7 @@ #' #' These arguments are converted to their specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to their defaults +#' set using `set_engine`. If left to their defaults #' here (see above), the values are taken from the underlying model #' functions. 
One exception is `hidden_units` when `nnet::nnet` is used; that #' function's `size` argument has no default so a value of 5 units will be @@ -51,18 +51,13 @@ #' \item \pkg{keras}: `"keras"` #' } #' -#' Main parameter arguments (and those in `...`) can avoid -#' evaluation until the underlying function is executed by wrapping the -#' argument in [rlang::expr()] (e.g. `hidden_units = expr(num_preds * 2)`). -#' #' An error is thrown if both `penalty` and `dropout` are specified for #' `keras` models. #' #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of +#' model fit call. For this type of #' model, the template of the fit calls are: #' #' \pkg{keras} classification @@ -92,10 +87,7 @@ mlp <- function(mode = "unknown", hidden_units = NULL, penalty = NULL, dropout = NULL, epochs = NULL, - activation = NULL, - ...) { - - others <- enquos(...) + activation = NULL) { args <- list( hidden_units = enquo(hidden_units), @@ -105,20 +97,14 @@ mlp <- activation = enquo(activation) ) - if (!(mode %in% mlp_modes)) - stop("`mode` should be one of: ", - paste0("'", mlp_modes, "'", collapse = ", "), - call. = FALSE) - - no_value <- !vapply(others, is.null, logical(1)) - others <- others[no_value] - - # write a constructor function - out <- list(args = args, others = others, - mode = mode, method = NULL, engine = NULL) - # TODO: make_classes has wrong order; go from specific to general - class(out) <- make_classes("mlp") - out + new_model_spec( + "mlp", + args = args, + eng_args = NULL, + mode = mode, + method = NULL, + engine = NULL + ) } #' @export @@ -155,10 +141,8 @@ update.mlp <- function(object, hidden_units = NULL, penalty = NULL, dropout = NULL, epochs = NULL, activation = NULL, - fresh = FALSE, - ...) { - others <- enquos(...) - + fresh = FALSE, ...) { + update_dot_check(...) args <- list( hidden_units = enquo(hidden_units), penalty = enquo(penalty), @@ -178,20 +162,24 @@ update.mlp <- object$args[names(args)] <- args } - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - - object + new_model_spec( + "mlp", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) } # ------------------------------------------------------------------------------ #' @export -translate.mlp <- function(x, engine, ...) { +translate.mlp <- function(x, engine = x$engine, ...) { + if (is.null(engine)) { + message("Used `engine = 'keras'` for translation.") + engine <- "keras" + } if (engine == "nnet") { if(isTRUE(is.null(quo_get_expr(x$args$hidden_units)))) { @@ -203,10 +191,10 @@ translate.mlp <- function(x, engine, ...) { if (engine == "nnet") { if (x$mode == "classification") { - if (length(x$others) == 0 || !any(names(x$others) == "linout")) + if (length(x$eng_args) == 0 || !any(names(x$eng_args) == "linout")) x$method$fit$args$linout <- FALSE } else { - if (length(x$others) == 0 || !any(names(x$others) == "linout")) + if (length(x$eng_args) == 0 || !any(names(x$eng_args) == "linout")) x$method$fit$args$linout <- TRUE } } @@ -219,10 +207,6 @@ check_args.mlp <- function(object) { args <- lapply(object$args, rlang::eval_tidy) - if (is.numeric(args$hidden_units)) - if (args$hidden_units < 2) - stop("There must be at least two hidden units", call. 
= FALSE) - if (is.numeric(args$penalty)) if (args$penalty < 0) stop("The amount of weight decay must be >= 0.", call. = FALSE) diff --git a/R/model_object_docs.R b/R/model_object_docs.R index ed563f788..af46bc0e8 100644 --- a/R/model_object_docs.R +++ b/R/model_object_docs.R @@ -175,10 +175,12 @@ NULL #' @examples #' #' # Keep the `x` matrix if the data are not too big. -#' spec_obj <- linear_reg(x = ifelse(.obs() < 500, TRUE, FALSE)) +#' spec_obj <- +#' linear_reg() %>% +#' set_engine("lm", x = ifelse(.obs() < 500, TRUE, FALSE)) #' spec_obj #' -#' fit_obj <- fit(spec_obj, mpg ~ ., data = mtcars, engine = "lm") +#' fit_obj <- fit(spec_obj, mpg ~ ., data = mtcars) #' fit_obj #' #' nrow(fit_obj$fit$x) diff --git a/R/multinom_reg.R b/R/multinom_reg.R index d9505cf57..6f6a41b43 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -12,7 +12,7 @@ #' } #' These arguments are converted to their specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to their defaults +#' set using `set_engine`. If left to their defaults #' here (`NULL`), the values are taken from the underlying model #' functions. If parameters need to be modified, `update` can be used #' in lieu of recreating the object from scratch. @@ -38,8 +38,7 @@ #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of +#' model fit call. For this type of #' model, the template of the fit calls are: #' #' \pkg{glmnet} @@ -83,35 +82,21 @@ multinom_reg <- function(mode = "classification", penalty = NULL, - mixture = NULL, - ...) { - others <- enquos(...) + mixture = NULL) { args <- list( penalty = enquo(penalty), mixture = enquo(mixture) ) - if (!(mode %in% multinom_reg_modes)) - stop( - "`mode` should be one of: ", - paste0("'", multinom_reg_modes, "'", collapse = ", "), - call. = FALSE - ) - - no_value <- !vapply(others, is.null, logical(1)) - others <- others[no_value] - - # write a constructor function - out <- list( + new_model_spec( + "multinom_reg", args = args, - others = others, + eng_args = NULL, mode = mode, method = NULL, engine = NULL ) - class(out) <- make_classes("multinom_reg") - out } #' @export @@ -142,10 +127,8 @@ print.multinom_reg <- function(x, ...) { update.multinom_reg <- function(object, penalty = NULL, mixture = NULL, - fresh = FALSE, - ...) { - others <- enquos(...) - + fresh = FALSE, ...) { + update_dot_check(...) args <- list( penalty = enquo(penalty), mixture = enquo(mixture) @@ -161,14 +144,14 @@ update.multinom_reg <- object$args[names(args)] <- args } - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - - object + new_model_spec( + "multinom_reg", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) } # ------------------------------------------------------------------------------ diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index 8b374b7f6..b85c16a9c 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -19,7 +19,7 @@ #' } #' These arguments are converted to their specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to their defaults +#' set using `set_engine`. 
If left to their defaults #' here (`NULL`), the values are taken from the underlying model #' functions. If parameters need to be modified, `update()` can be used #' in lieu of recreating the object from scratch. @@ -49,8 +49,7 @@ #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of +#' model fit call. For this type of #' model, the template of the fit calls are: #' #' \pkg{kknn} (classification or regression) @@ -67,38 +66,27 @@ #' @seealso [varying()], [fit()] #' #' @examples -#' nearest_neighbor() +#' nearest_neighbor(neighbors = 11) #' #' @export nearest_neighbor <- function(mode = "unknown", neighbors = NULL, weight_func = NULL, - dist_power = NULL, - ...) { - others <- enquos(...) - + dist_power = NULL) { args <- list( neighbors = enquo(neighbors), weight_func = enquo(weight_func), dist_power = enquo(dist_power) ) - ## TODO: make a utility function here - if (!(mode %in% nearest_neighbor_modes)) { - stop("`mode` should be one of: ", - paste0("'", nearest_neighbor_modes, "'", collapse = ", "), - call. = FALSE) - } - - no_value <- !vapply(others, is.null, logical(1)) - others <- others[no_value] - - # write a constructor function - out <- list(args = args, others = others, - mode = mode, method = NULL, engine = NULL) - # TODO: make_classes has wrong order; go from specific to general - class(out) <- make_classes("nearest_neighbor") - out + new_model_spec( + "nearest_neighbor", + args = args, + eng_args = NULL, + mode = mode, + method = NULL, + engine = NULL + ) } #' @export @@ -121,11 +109,8 @@ update.nearest_neighbor <- function(object, neighbors = NULL, weight_func = NULL, dist_power = NULL, - fresh = FALSE, - ...) { - - others <- enquos(...) - + fresh = FALSE, ...) { + update_dot_check(...) args <- list( neighbors = enquo(neighbors), weight_func = enquo(weight_func), @@ -142,14 +127,14 @@ update.nearest_neighbor <- function(object, object$args[names(args)] <- args } - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - - object + new_model_spec( + "nearest_neighbor", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) } diff --git a/R/predict.R b/R/predict.R index 5dfd42823..ea7ea7149 100644 --- a/R/predict.R +++ b/R/predict.R @@ -47,7 +47,7 @@ #' #' Quantile predictions return a tibble with a column `.pred`, which is #' a list-column. Each list element contains a tibble with columns -#' `.pred` and `.quantile` (and perhaps others). +#' `.pred` and `.quantile` (and perhaps other columns). #' #' Using `type = "raw"` with `predict.model_fit` (or using #' `predict_raw`) will return the unadulterated results of the @@ -63,7 +63,8 @@ #' #' lm_model <- #' linear_reg() %>% -#' fit(mpg ~ ., data = mtcars %>% slice(11:32), engine = "lm") +#' set_engine("lm") %>% +#' fit(mpg ~ ., data = mtcars %>% slice(11:32)) #' #' pred_cars <- #' mtcars %>% diff --git a/R/rand_forest.R b/R/rand_forest.R index 3d81e897b..4dc26ea5d 100644 --- a/R/rand_forest.R +++ b/R/rand_forest.R @@ -15,7 +15,7 @@ #' } #' These arguments are converted to their specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to their defaults +#' set using `set_engine`. 
If left to their defaults #' here (`NULL`), the values are taken from the underlying model #' functions. If parameters need to be modified, `update` can be used #' in lieu of recreating the object from scratch. @@ -38,15 +38,10 @@ #' \item \pkg{Spark}: `"spark"` #' } #' -#' Main parameter arguments (and those in `...`) can avoid -#' evaluation until the underlying function is executed by wrapping the -#' argument in [rlang::expr()] (e.g. `mtry = expr(floor(sqrt(p)))`). -#' #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of +#' model fit call. For this type of #' model, the template of the fit calls are:: #' #' \pkg{ranger} classification @@ -100,10 +95,7 @@ #' @export rand_forest <- - function(mode = "unknown", - mtry = NULL, trees = NULL, min_n = NULL, ...) { - - others <- enquos(...) + function(mode = "unknown", mtry = NULL, trees = NULL, min_n = NULL) { args <- list( mtry = enquo(mtry), @@ -111,21 +103,14 @@ rand_forest <- min_n = enquo(min_n) ) - ## TODO: make a utility function here - if (!(mode %in% rand_forest_modes)) - stop("`mode` should be one of: ", - paste0("'", rand_forest_modes, "'", collapse = ", "), - call. = FALSE) - - no_value <- !vapply(others, null_value, logical(1)) - others <- others[no_value] - - # write a constructor function - out <- list(args = args, others = others, - mode = mode, method = NULL, engine = NULL) - # TODO: make_classes has wrong order; go from specific to general - class(out) <- make_classes("rand_forest") - out + new_model_spec( + "rand_forest", + args = args, + eng_args = NULL, + mode = mode, + method = NULL, + engine = NULL + ) } #' @export @@ -156,10 +141,8 @@ print.rand_forest <- function(x, ...) { update.rand_forest <- function(object, mtry = NULL, trees = NULL, min_n = NULL, - fresh = FALSE, - ...) { - others <- enquos(...) - + fresh = FALSE, ...) { + update_dot_check(...) args <- list( mtry = enquo(mtry), trees = enquo(trees), @@ -177,20 +160,25 @@ update.rand_forest <- object$args[names(args)] <- args } - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - - object + new_model_spec( + "rand_forest", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) } # ------------------------------------------------------------------------------ #' @export -translate.rand_forest <- function(x, engine, ...) { +translate.rand_forest <- function(x, engine = x$engine, ...) { + if (is.null(engine)) { + message("Used `engine = 'ranger'` for translation.") + engine <- "ranger" + } + x <- translate.default(x, engine, ...) # slightly cleaner code using @@ -217,7 +205,7 @@ translate.rand_forest <- function(x, engine, ...) { } # add checks to error trap or change things for this method - if (x$engine == "ranger") { + if (engine == "ranger") { if (any(names(arg_vals) == "importance")) if (isTRUE(is.logical(quo_get_expr(arg_vals$importance)))) stop("`importance` should be a character value. See ?ranger::ranger.", diff --git a/R/surv_reg.R b/R/surv_reg.R index 29c3489ab..65c86b416 100644 --- a/R/surv_reg.R +++ b/R/surv_reg.R @@ -9,7 +9,7 @@ #' } #' This argument is converted to its specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to its default +#' set using `set_engine`. 
If left to its default #' here (`NULL`), the value is taken from the underlying model #' functions. #' @@ -42,8 +42,7 @@ #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of +#' model fit call. For this type of #' model, the template of the fit calls are: #' #' \pkg{flexsurv} @@ -67,36 +66,20 @@ #' surv_reg(dist = varying()) #' #' @export -surv_reg <- - function(mode = "regression", - dist = NULL, - ...) { - others <- enquos(...) +surv_reg <- function(mode = "regression", dist = NULL) { args <- list( dist = enquo(dist) ) - if (!(mode %in% surv_reg_modes)) - stop( - "`mode` should be one of: ", - paste0("'", surv_reg_modes, "'", collapse = ", "), - call. = FALSE - ) - - no_value <- !vapply(others, is.null, logical(1)) - others <- others[no_value] - - # write a constructor function - out <- list( + new_model_spec( + "surv_reg", args = args, - others = others, + eng_args = NULL, mode = mode, method = NULL, engine = NULL ) - class(out) <- make_classes("surv_reg") - out } #' @export @@ -128,42 +111,41 @@ print.surv_reg <- function(x, ...) { #' @method update surv_reg #' @rdname surv_reg #' @export -update.surv_reg <- - function(object, - dist = NULL, - fresh = FALSE, - ...) { - others <- enquos(...) - - args <- list( - dist = enquo(dist) - ) - - if (fresh) { - object$args <- args - } else { - null_args <- map_lgl(args, null_value) - if (any(null_args)) - args <- args[!null_args] - if (length(args) > 0) - object$args[names(args)] <- args - } - - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - - object +update.surv_reg <- function(object, dist = NULL, fresh = FALSE, ...) { + update_dot_check(...) + args <- list( + dist = enquo(dist) + ) + + if (fresh) { + object$args <- args + } else { + null_args <- map_lgl(args, null_value) + if (any(null_args)) + args <- args[!null_args] + if (length(args) > 0) + object$args[names(args)] <- args } + new_model_spec( + "surv_reg", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) +} + # ------------------------------------------------------------------------------ #' @export -translate.surv_reg <- function(x, engine, ...) { +translate.surv_reg <- function(x, engine = x$engine, ...) { + if (is.null(engine)) { + message("Used `engine = 'survreg'` for translation.") + engine <- "survreg" + } x <- translate.default(x, engine, ...) x } diff --git a/R/translate.R b/R/translate.R index 7f1102c57..7c88e8c18 100644 --- a/R/translate.R +++ b/R/translate.R @@ -42,8 +42,10 @@ translate <- function (x, ...) #' @importFrom utils getFromNamespace #' @importFrom purrr list_modify #' @export -translate.default <- function(x, engine, ...) { +translate.default <- function(x, engine = x$engine, ...) { check_empty_ellipse(...) + if (is.null(engine)) + stop("Please set an engine.", call. = FALSE) x$engine <- engine x <- check_engine(x) @@ -60,7 +62,7 @@ translate.default <- function(x, engine, ...) 
{ # expression unless there are dots, warn if protected args are # being altered eng_arg_key <- arg_key[[x$engine]] - x$others <- check_others(x$others, x$method$fit, eng_arg_key) + x$eng_args <- check_eng_args(x$eng_args, x$method$fit, eng_arg_key) # keep only modified args modifed_args <- !vapply(actual_args, null_value, lgl(1)) @@ -68,21 +70,20 @@ translate.default <- function(x, engine, ...) { # look for defaults if not modified in other if(length(x$method$fit$defaults) > 0) { - in_other <- names(x$method$fit$defaults) %in% names(x$others) + in_other <- names(x$method$fit$defaults) %in% names(x$eng_args) x$defaults <- x$method$fit$defaults[!in_other] } - # combine primary, others, and defaults + # combine primary, eng_args, and defaults protected <- lapply(x$method$fit$protect, function(x) expr(missing_arg())) names(protected) <- x$method$fit$protect - x$method$fit$args <- c(protected, actual_args, x$others, x$defaults) + x$method$fit$args <- c(protected, actual_args, x$eng_args, x$defaults) - # put in correct order x } -get_method <- function(x, engine, ...) { +get_method <- function(x, engine = x$engine, ...) { check_empty_ellipse(...) x$engine <- engine x <- check_engine(x) diff --git a/R/varying.R b/R/varying.R index 49f50eb55..501faa444 100644 --- a/R/varying.R +++ b/R/varying.R @@ -23,18 +23,20 @@ varying <- function() #' #' rand_forest(mtry = varying()) %>% varying_args(id = "one arg") #' -#' rand_forest(others = list(sample.fraction = varying())) %>% -#' varying_args(id = "only others") +#' rand_forest() %>% +#' set_engine("ranger", sample.fraction = varying()) %>% +#' varying_args(id = "only eng_args") #' -#' rand_forest( -#' others = list( -#' strata = expr(Class), +#' rand_forest() %>% +#' set_engine( +#' "ranger", +#' strata = expr(Class), #' sampsize = c(varying(), varying()) -#' ) -#' ) %>% -#' varying_args(id = "add an expr") +#' ) %>% +#' varying_args(id = "add an expr") #' -#' rand_forest(others = list(classwt = c(class1 = 1, class2 = varying()))) %>% +#' rand_forest() %>% +#' set_engine("ranger", classwt = c(class1 = 1, class2 = varying())) %>% #' varying_args(id = "list of values") #' #' @export @@ -55,8 +57,8 @@ varying_args.model_spec <- function(x, id = NULL, ...) { if (is.null(id)) id <- deparse(cl$x) varying_args <- map(x$args, find_varying) - varying_others <- map(x$others, find_varying) - res <- c(varying_args, varying_others) + varying_eng_args <- map(x$eng_args, find_varying) + res <- c(varying_args, varying_eng_args) res <- map_lgl(res, any) tibble( name = names(res), diff --git a/_pkgdown.yml b/_pkgdown.yml index 971b84395..cf5e1e4e3 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -36,6 +36,7 @@ reference: - model_spec - predict.model_fit - set_args + - set_engine - set_mode - translate - varying diff --git a/docs/articles/articles/Classification.html b/docs/articles/articles/Classification.html index ed83c3d28..6d3e67216 100644 --- a/docs/articles/articles/Classification.html +++ b/docs/articles/articles/Classification.html @@ -40,7 +40,7 @@ parsnip
part of tidymodels - 0.0.0.9004 + 0.0.0.9005
@@ -104,27 +104,26 @@

Classification Example

#> ── Attaching packages ──────────────────────────── tidymodels 0.0.1.9000 ── #> ✔ broom 0.5.0.9001 ✔ purrr 0.2.5 #> ✔ dials 0.0.1.9000 ✔ recipes 0.1.3.9002 -#> ✔ dplyr 0.7.6 ✔ rsample 0.0.2 +#> ✔ dplyr 0.7.7 ✔ rsample 0.0.2 #> ✔ infer 0.3.1 ✔ yardstick 0.0.1.9000 #> ✔ probably 0.0.0.9000 -#> Warning: package 'dplyr' was built under R version 3.5.1 -#> ── Conflicts ──────────────────────────────────── tidymodels_conflicts() ── -#> ✖ probably::as.factor() masks base::as.factor() -#> ✖ probably::as.ordered() masks base::as.ordered() -#> ✖ purrr::discard() masks scales::discard() -#> ✖ rsample::fill() masks tidyr::fill() -#> ✖ dplyr::filter() masks stats::filter() -#> ✖ dplyr::lag() masks stats::lag() -#> ✖ rsample::prepper() masks recipes::prepper() -#> ✖ recipes::step() masks stats::step() - -data(credit_data) - -set.seed(7075) -data_split <- initial_split(credit_data, strata = "Status", p = 0.75) - -credit_train <- training(data_split) -credit_test <- testing(data_split) +#> ── Conflicts ──────────────────────────────────── tidymodels_conflicts() ── +#> ✖ probably::as.factor() masks base::as.factor() +#> ✖ probably::as.ordered() masks base::as.ordered() +#> ✖ purrr::discard() masks scales::discard() +#> ✖ rsample::fill() masks tidyr::fill() +#> ✖ dplyr::filter() masks stats::filter() +#> ✖ dplyr::lag() masks stats::lag() +#> ✖ rsample::prepper() masks recipes::prepper() +#> ✖ recipes::step() masks stats::step() + +data(credit_data) + +set.seed(7075) +data_split <- initial_split(credit_data, strata = "Status", p = 0.75) + +credit_train <- training(data_split) +credit_test <- testing(data_split)

A single hidden layer neural network will be used to predict a person’s credit status. To do so, the columns of the predictor matrix should be numeric and on a common scale; recipes will be used to do this preprocessing.

credit_rec <- recipe(Status ~ ., data = credit_train) %>%
   step_knnimpute(Home, Job, Marital, Income, Assets, Debt) %>%
@@ -137,34 +136,32 @@ 

Classification Example

test_normalized <- bake(credit_rec, newdata = credit_test, all_predictors())

keras will be used to fit a model with 5 hidden units and a 10% dropout rate to regularize the model. At each training iteration (aka epoch), a random 20% of the data will be used to measure the cross-entropy of the model.

-
library(keras)
-
-set.seed(57974)
-nnet_fit <- mlp(
-  epochs = 100, hidden_units = 5, dropout = 0.1, 
-  others = list(verbose = 0, validation_split = .20)
-) %>%
-  parsnip::fit.model_spec(Status ~ ., data = juice(credit_rec), engine = "keras")
-
-nnet_fit
-#> parsnip model object
-#> 
-#> Model
-#> ___________________________________________________________________________
-#> Layer (type)                     Output Shape                  Param #     
-#> ===========================================================================
-#> dense_1 (Dense)                  (None, 5)                     115         
+
+#> dense_3 (Dense)                  (None, 2)                     12          
+#> ===========================================================================
+#> Total params: 157
+#> Trainable params: 157
+#> Non-trainable params: 0
+#> ___________________________________________________________________________

In parsnip, the predict function is only appropriate for numeric outcomes while predict_class and predict_classprob can be used for categorical outcomes.

test_results <- credit_test %>%
   select(Status) %>%
@@ -178,17 +175,17 @@ 

Classification Example

#> # A tibble: 1 x 2 #> .metric .estimate #> <chr> <dbl> -#> 1 roc_auc 0.829 +#> 1 roc_auc 0.822 test_results %>% accuracy(truth = Status, estimate = `nnet class`) #> # A tibble: 1 x 2 #> .metric .estimate #> <chr> <dbl> -#> 1 accuracy 0.809 +#> 1 accuracy 0.803 test_results %>% conf_mat(truth = Status, estimate = `nnet class`) #> Truth #> Prediction bad good -#> bad 182 82 -#> good 131 718
+#> bad 176 82 +#> good 137 718 @@ -680,7 +680,7 @@

List of Models

- + surv_reg() @@ -710,32 +710,6 @@

List of Models

-stan - - -× - - -× - - -✔ - - -× - - -× - - -× - - -× - - - - survreg diff --git a/docs/articles/articles/Regression.html b/docs/articles/articles/Regression.html index b6ce2bdaf..1b76a047e 100644 --- a/docs/articles/articles/Regression.html +++ b/docs/articles/articles/Regression.html @@ -40,7 +40,7 @@ parsnip
part of tidymodels - 0.0.0.9004 + 0.0.0.9005
@@ -107,27 +107,26 @@

Regression Example

#> ── Attaching packages ──────────────────────────── tidymodels 0.0.1.9000 ── #> ✔ broom 0.5.0.9001 ✔ purrr 0.2.5 #> ✔ dials 0.0.1.9000 ✔ recipes 0.1.3.9002 -#> ✔ dplyr 0.7.6 ✔ rsample 0.0.2 +#> ✔ dplyr 0.7.7 ✔ rsample 0.0.2 #> ✔ infer 0.3.1 ✔ tibble 1.4.2 #> ✔ probably 0.0.0.9000 ✔ yardstick 0.0.1.9000 -#> Warning: package 'dplyr' was built under R version 3.5.1 -#> ── Conflicts ──────────────────────────────────── tidymodels_conflicts() ── -#> ✖ probably::as.factor() masks base::as.factor() -#> ✖ probably::as.ordered() masks base::as.ordered() -#> ✖ dplyr::combine() masks randomForest::combine() -#> ✖ purrr::discard() masks scales::discard() -#> ✖ rsample::fill() masks tidyr::fill() -#> ✖ dplyr::filter() masks stats::filter() -#> ✖ dplyr::lag() masks stats::lag() -#> ✖ ggplot2::margin() masks randomForest::margin() -#> ✖ rsample::prepper() masks recipes::prepper() -#> ✖ recipes::step() masks stats::step() - -set.seed(4595) -data_split <- initial_split(ames, strata = "Sale_Price", p = 0.75) - -ames_train <- training(data_split) -ames_test <- testing(data_split) +#> ── Conflicts ──────────────────────────────────── tidymodels_conflicts() ── +#> ✖ probably::as.factor() masks base::as.factor() +#> ✖ probably::as.ordered() masks base::as.ordered() +#> ✖ dplyr::combine() masks randomForest::combine() +#> ✖ purrr::discard() masks scales::discard() +#> ✖ rsample::fill() masks tidyr::fill() +#> ✖ dplyr::filter() masks stats::filter() +#> ✖ dplyr::lag() masks stats::lag() +#> ✖ ggplot2::margin() masks randomForest::margin() +#> ✖ rsample::prepper() masks recipes::prepper() +#> ✖ recipes::step() masks stats::step() + +set.seed(4595) +data_split <- initial_split(ames, strata = "Sale_Price", p = 0.75) + +ames_train <- training(data_split) +ames_test <- testing(data_split)

Random Forests

@@ -141,30 +140,31 @@

parsnip gives two different interfaces to the models: the formula and non-formula interfaces. Let’s start with the non-formula interface:

+rf_xy_fit <- + rf_defaults %>% + set_engine("ranger") %>% + fit_xy( + x = ames_train[, preds], + y = log10(ames_train$Sale_Price) + ) +rf_xy_fit +#> parsnip model object +#> +#> Ranger result +#> +#> Call: +#> ranger::ranger(formula = formula, data = data, num.threads = 1, verbose = FALSE, seed = sample.int(10^5, 1)) +#> +#> Type: Regression +#> Number of trees: 500 +#> Sample size: 2199 +#> Number of independent variables: 5 +#> Mtry: 2 +#> Target node size: 5 +#> Variable importance mode: none +#> Splitrule: variance +#> OOB prediction error (MSE): 0.00866 +#> R squared (OOB): 0.727

The non-formula interface doesn’t do anything to the predictors before passing them to the underlying model function. This particular model does not require indicator variables to be created prior to fitting the model (note that the output shows “Number of independent variables: 5”).

For regression models, the basic predict method can be used and returns a tibble with a column named .pred:

test_results <- ames_test %>%
@@ -204,10 +204,10 @@ 

Now, for illustration, let’s use the formula method using some new parameter values:

-

Suppose that there was some feature in the randomForest package that we’d like to evaluate. To do so, the only part of the syntax that needs to change is the engine argument:

+

Suppose that there was some feature in the randomForest package that we’d like to evaluate. To do so, the only part of the syntax that needs to change is the set_engine call:

Looking at the formula code that was printed out, one function uses the argument name ntree and the other uses num.trees; parsnip doesn’t require you to know the specific names of the main arguments.
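To see the two translations side by side, a sketch (the parameter values here are illustrative, not from this page):

rand_forest(mode = "regression", mtry = 3, trees = 1000) %>%
  set_engine("ranger") %>%
  translate()

rand_forest(mode = "regression", mtry = 3, trees = 1000) %>%
  set_engine("randomForest") %>%
  translate()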

Now suppose that we want to modify the value of mtry based on the number of predictors in the data. Usually, the default value would be floor(sqrt(num_predictors)). A pure bagging model would require an mtry value equal to the total number of predictors. There may be cases where you won’t know how many predictors are going to be present (perhaps due to the generation of indicator variables or a variable filter), so the exact value might be difficult to know ahead of time.

-

When the model is being fit by parsnip, data descriptors are made available. These attempt to let you know what you will have available when the model is fit. When a model object is created (say using rand_forest), the values of the arguments that you give it are immediately evaluated… unless you delay them. To delay the evaluation of any argument, you can use rlang::expr to make an expression.

+

When the model is being fit by parsnip, data descriptors are made available. These attempt to let you know what you will have available when the model is fit. When a model object is created (say using rand_forest), the values of the arguments that you give it are immediately evaluated… unless you delay them. To delay the evaluation of any argument, you can use rlang::expr to make an expression.
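For example, a sketch of the delayed-evaluation pattern (this assumes a .preds() descriptor for the number of predictors, analogous to the .obs() descriptor shown in the package documentation):

rand_forest(mode = "regression",
            mtry = rlang::expr(floor(sqrt(.preds())))) %>%
  set_engine("ranger")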

Two relevant descriptors for what we are about to do are:

If penalty were not specified, all of the lambda values would be computed.
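The fit itself is elided on this page; it presumably resembled the following sketch, where glmn_fit and norm_recipe are hypothetical object names and the penalty value is illustrative:

glmn_fit <-
  linear_reg(penalty = 0.001, mixture = 0.5) %>%
  set_engine("glmnet") %>%
  fit(Sale_Price ~ ., data = juice(norm_recipe))  # outcome may be pre-transformed by the recipe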

To get the predictions for this specific value of lambda (aka penalty):

# First, get the processed version of the test set predictors:
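A hypothetical completion of this chunk, following the bake pattern used in the classification example (object names assumed as in the sketch above):

test_normalized <- bake(norm_recipe, newdata = ames_test, all_predictors())

predict(glmn_fit, new_data = test_normalized)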
diff --git a/docs/articles/articles/Scratch.html b/docs/articles/articles/Scratch.html
index 537074874..51aa7fa45 100644
--- a/docs/articles/articles/Scratch.html
+++ b/docs/articles/articles/Scratch.html
@@ -40,7 +40,7 @@
         parsnip
         
part of tidymodels - 0.0.0.9004 + 0.0.0.9005
@@ -153,11 +153,11 @@

  • The mode. If the model can do more than one mode, you might default this to “unknown”. In our case, since it is only a classification model, it makes sense to default it to that mode.
  • The argument names (sub_classes here). These should be defaulted to NULL.
  • -... is used to pass in other arguments to the underlying model fit functions.
  • +... are not used in the main model function.

    A basic version of the function is:

+    args <- list(sub_classes = rlang::enquo(sub_classes))
+
+    # Save some empty slots for future parts of the specification
+    out <- list(args = args, eng_args = NULL,
+                mode = mode, method = NULL, engine = NULL)
+
+    # set classes in the correct order
+    class(out) <- make_classes("mixture_da")
+    out
+  }

    This is pretty simple since the data are not exposed to this function.
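For reference, a complete version of the constructor might look like the sketch below (make_classes() is internal to parsnip, hence the ::: access):

mixture_da <- function(mode = "classification", sub_classes = NULL) {
  # Capture the main argument as a quosure without evaluating it
  args <- list(sub_classes = rlang::enquo(sub_classes))

  # Save some empty slots for future parts of the specification
  out <- list(args = args, eng_args = NULL,
              mode = mode, method = NULL, engine = NULL)

  # set classes in the correct order
  class(out) <- parsnip:::make_classes("mixture_da")
  out
}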

    @@ -236,7 +232,7 @@

  • func is the prediction function (in the same format as above). In many cases, packages have a predict method for their model’s class but this is typically not exported. In this case (and the example below), it is simple enough to make a generic call to predict with no associated package.
  • args is a list of arguments to pass to the prediction function. These will most likely be wrapped in rlang::expr so that they are not evaluated when defining the method. For mda, the code would be predict(object, newdata, type = "class"). What is actually given to the function is the parsnip model fit object, which includes a sub-object called fit that houses the mda model object. If the data need to be a matrix or data frame, you could also use new_data = quote(as.data.frame(new_data)) and so on.

    Let’s try it on the iris data:

    set.seed(4622)
     iris_split <- initial_split(iris, prop = 0.90)
    @@ -307,64 +304,65 @@ 

    mda_spec <- mixture_da(sub_classes = 2)

    mda_fit <- mda_spec %>%
    -  fit(Species ~ ., data = iris_train, engine = "mda")
    -mda_fit
    -#> parsnip model object
    -#>
    -#> Call:
    -#> mda::mda(formula = formula, data = data, sub_classes = ~2)
    -#>
    -#> Dimension: 4
    -#>
    -#> Percent Between-Group Variance Explained:
    -#>    v1    v2    v3    v4
    -#>  94.9  97.9  99.8 100.0
    -#>
    -#> Degrees of Freedom (per dimension): 5
    -#>
    -#> Training Misclassification Error: 0.0221 ( N = 136 )
    -#>
    -#> Deviance: 12.3
    -
    -predict(mda_fit, new_data = iris_test) %>%
    -  bind_cols(iris_test %>% select(Species))
    -#> # A tibble: 14 x 2
    -#>    .pred_class Species
    -#>    <fct>       <fct>
    -#>  1 setosa      setosa
    -#>  2 setosa      setosa
    -#>  3 setosa      setosa
    -#>  4 versicolor  versicolor
    -#>  5 versicolor  versicolor
    -#>  6 versicolor  versicolor
    -#>  7 versicolor  versicolor
    -#>  8 versicolor  versicolor
    -#>  9 versicolor  versicolor
    -#> 10 versicolor  versicolor
    -#> 11 versicolor  versicolor
    -#> 12 virginica   virginica
    -#> 13 virginica   virginica
    -#> 14 virginica   virginica
    -
    -predict(mda_fit, new_data = iris_test, type = "prob") %>%
    -  bind_cols(iris_test %>% select(Species))
    -#> # A tibble: 14 x 4
    -#>    .pred_setosa .pred_versicolor .pred_virginica Species
    -#>           <dbl>            <dbl>           <dbl> <fct>
    -#>  1     1.00e+ 0         2.62e-32        7.10e-65 setosa
    -#>  2     1.00e+ 0         1.36e-25        2.36e-56 setosa
    -#>  3     1.00e+ 0         9.11e-29        1.33e-60 setosa
    -#>  4     1.76e-38        10.00e- 1        1.97e- 7 versicolor
    -#>  5     5.64e-36         9.95e- 1        5.03e- 3 versicolor
    -#>  6     6.84e-22        10.00e- 1        9.83e- 9 versicolor
    -#>  7     2.54e-37         9.22e- 1        7.83e- 2 versicolor
    -#>  8     2.70e-37         9.99e- 1        1.34e- 3 versicolor
    -#>  9     1.81e-37         8.06e- 1        1.94e- 1 versicolor
    -#> 10     9.83e-35         9.93e- 1        7.27e- 3 versicolor
    -#> 11     4.04e-37         9.97e- 1        3.00e- 3 versicolor
    -#> 12     1.93e-55         1.44e- 1        8.56e- 1 virginica
    -#> 13     1.21e-50         4.19e- 1        5.81e- 1 virginica
    -#> 14     2.08e-50         2.07e- 1        7.93e- 1 virginica

    +  set_engine("mda") %>%
    +  fit(Species ~ ., data = iris_train)
    +mda_fit
    +#> parsnip model object
    +#>
    +#> Call:
    +#> mda::mda(formula = formula, data = data, sub_classes = ~2)
    +#>
    +#> Dimension: 4
    +#>
    +#> Percent Between-Group Variance Explained:
    +#>    v1    v2    v3    v4
    +#>  94.9  97.9  99.8 100.0
    +#>
    +#> Degrees of Freedom (per dimension): 5
    +#>
    +#> Training Misclassification Error: 0.0221 ( N = 136 )
    +#>
    +#> Deviance: 12.3
    +
    +predict(mda_fit, new_data = iris_test) %>%
    +  bind_cols(iris_test %>% select(Species))
    +#> # A tibble: 14 x 2
    +#>    .pred_class Species
    +#>    <fct>       <fct>
    +#>  1 setosa      setosa
    +#>  2 setosa      setosa
    +#>  3 setosa      setosa
    +#>  4 versicolor  versicolor
    +#>  5 versicolor  versicolor
    +#>  6 versicolor  versicolor
    +#>  7 versicolor  versicolor
    +#>  8 versicolor  versicolor
    +#>  9 versicolor  versicolor
    +#> 10 versicolor  versicolor
    +#> 11 versicolor  versicolor
    +#> 12 virginica   virginica
    +#> 13 virginica   virginica
    +#> 14 virginica   virginica
    +
    +predict(mda_fit, new_data = iris_test, type = "prob") %>%
    +  bind_cols(iris_test %>% select(Species))
    +#> # A tibble: 14 x 4
    +#>    .pred_setosa .pred_versicolor .pred_virginica Species
    +#>           <dbl>            <dbl>           <dbl> <fct>
    +#>  1     1.00e+ 0         2.62e-32        7.10e-65 setosa
    +#>  2     1.00e+ 0         1.36e-25        2.36e-56 setosa
    +#>  3     1.00e+ 0         9.11e-29        1.33e-60 setosa
    +#>  4     1.76e-38        10.00e- 1        1.97e- 7 versicolor
    +#>  5     5.64e-36         9.95e- 1        5.03e- 3 versicolor
    +#>  6     6.84e-22        10.00e- 1        9.83e- 9 versicolor
    +#>  7     2.54e-37         9.22e- 1        7.83e- 2 versicolor
    +#>  8     2.70e-37         9.99e- 1        1.34e- 3 versicolor
    +#>  9     1.81e-37         8.06e- 1        1.94e- 1 versicolor
    +#> 10     9.83e-35         9.93e- 1        7.27e- 3 versicolor
    +#> 11     4.04e-37         9.97e- 1        3.00e- 3 versicolor
    +#> 12     1.93e-55         1.44e- 1        8.56e- 1 virginica
    +#> 13     1.21e-50         4.19e- 1        5.81e- 1 virginica
    +#> 14     2.08e-50         2.07e- 1        7.93e- 1 virginica

    @@ -385,25 +383,26 @@

    and so on. These can be accommodated via predict.model_fit using different type arguments.

    However, there are some models (e.g. glmnet, plsr, Cubist, etc.) that can make predictions for different submodels from the same fitted model object. The regular predict method only returns predictions from a single model, but multi_predict can return results for several submodels at once. The guideline is to always return the same number of rows as in new_data. This means that the .pred column is a list-column of tibbles.

    For example, for a multinomial glmnet model, we leave penalty unspecified when fitting and get predictions on a sequence of values:

    - +

    This can be easily expanded to remove the list columns:
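    The expansion code did not survive rendering; one hedged way to flatten the predictions, assuming preds holds the multi_predict() output with its .pred list-column:

    preds %>%
      dplyr::mutate(.row = dplyr::row_number()) %>%  # remember the source row
      tidyr::unnest(.pred)                           # one row per submodel prediction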

    +logistic_reg() %>%
    +  set_engine("glm", family = stats::binomial(link = "probit")) %>%
    +  translate()
    +#> Logistic Regression Model Specification (classification)
    +#>
    +#> Engine-Specific Arguments:
    +#>   family = stats::binomial(link = "probit")
    +#>
    +#> Computational engine: glm
    +#>
    +#> Model fit template:
    +#> stats::glm(formula = missing_arg(), data = missing_arg(), weights = missing_arg(),
    +#>     family = stats::binomial(link = "probit"))

    That’s what defaults are for.

    Note that I wrapped binomial inside of expr. If I didn’t, it would substitute the results of executing binomial inside of the expression (and that’s a mess). Using namespaces is a good idea here.

    @@ -460,7 +460,7 @@

    The translate function can be used to check values or set defaults once the model’s mode is known. To do this, you can create a model-specific S3 method that first calls the general method (translate.model_spec) and then makes modifications or conducts error traps.

    For example, the ranger and randomForest package functions have arguments for calculating importance. One is a logical and the other is a string. Since this is likely to lead to a bunch of frustration and GH issues, we can put in a check:

    # Simplified version
    -translate.rand_forest <- function (x, engine, ...){
    +translate.rand_forest <- function (x, engine = x$engine, ...){
       # Run the general method to get the real arguments in place
       x <- translate.default(x, engine, ...)
       
    @@ -468,7 +468,7 @@ 

      arg_vals <- x$method$fit$args

      # Check and see if they make sense for the engine and/or mode:
    -  if (x$engine == "ranger") {
    +  if (engine == "ranger") {
        if (any(names(arg_vals) == "importance"))
          # We want to check the type of `importance` but it is a quosure. We first
          # get the expression. If it is logical, the value of `quo_get_expr` will

diff --git a/docs/articles/index.html b/docs/articles/index.html
index cc27c2b37..4a511d2a9 100644
--- a/docs/articles/index.html
+++ b/docs/articles/index.html
@@ -69,7 +69,7 @@
parsnip
    part of tidymodels - 0.0.0.9004 + 0.0.0.9005

diff --git a/docs/articles/parsnip_Intro.html b/docs/articles/parsnip_Intro.html
index 9cf95f251..c7b8fefef 100644
--- a/docs/articles/parsnip_Intro.html
+++ b/docs/articles/parsnip_Intro.html
@@ -40,7 +40,7 @@
parsnip
    part of tidymodels - 0.0.0.9004 + 0.0.0.9005
    @@ -152,25 +152,23 @@

    The arguments to the default function are:

    -

    However, there might be other arguments that you would like to change or allow to vary. These are accessible using the ... slot. This is a named list of arguments in the form of the underlying function being called. For example, ranger has an option to set the internal random number seed. To set this to a specific value:

    - +

    However, there might be other arguments that you would like to change or allow to vary. These are accessible using set_engine. For example, ranger has an option to set the internal random number seed. To set this to a specific value:

    +
    +#> Computational engine: ranger
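    The chunk that created rf_with_seed did not survive rendering; judging from how it is used below, it plausibly looked like this (the seed value is illustrative):

    rf_with_seed <-
      rand_forest(trees = 2000, mtry = varying(), mode = "regression") %>%
      set_engine("ranger", seed = 63233)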

    @@ -184,7 +182,8 @@

    For example, rf_with_seed above is not ready for fitting due to the varying() parameter. We can set that parameter’s value and then create the model fit:

    rf_with_seed %>% 
       set_args(mtry = 4) %>% 
    -  fit(mpg ~ ., data = mtcars, engine = "ranger")
    +  set_engine("ranger") %>%
    +  fit(mpg ~ ., data = mtcars)

    #> parsnip model object
     #> 
     #> Ranger result
    @@ -206,7 +205,8 @@ 

    set.seed(56982)
     rf_with_seed %>% 
       set_args(mtry = 4) %>% 
    -  fit(mpg ~ ., data = mtcars, engine = "randomForest")
    +  set_engine("randomForest") %>%
    +  fit(mpg ~ ., data = mtcars)

    #> parsnip model object
     #> 
     #> 
    diff --git a/docs/authors.html b/docs/authors.html
    index ef26bfe35..5a274292f 100644
    --- a/docs/authors.html
    +++ b/docs/authors.html
    @@ -69,7 +69,7 @@
             parsnip
             
    part of tidymodels - 0.0.0.9004 + 0.0.0.9005
diff --git a/docs/index.html b/docs/index.html
index dc0d8f38c..8c13f7d13 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -5,12 +5,12 @@
-A Common API to Modeling and analysis Functions • parsnip
+A Common API to Modeling and Analysis Functions • parsnip

diff --git a/docs/reference/set_engine.html b/docs/reference/set_engine.html
new file mode 100644

+Declare a computational engine and specific arguments — set_engine • parsnip

+set_engine is used to specify which package or system will be used
+to fit the model, along with any arguments specific to that software.

+set_engine(object, engine, ...)

Arguments

+object
+A model specification.

+engine
+A character string for the software that should
+be used to fit the model. This is highly dependent on the type
+of model (e.g. linear regression, random forest, etc.).

+...
+Any optional arguments associated with the chosen computational
+engine. These are captured as quosures and can be varying().

Value

+An updated model specification.

Examples

+# First, set general arguments using the standardized names
+mod <-
+  logistic_reg(mixture = 1/3) %>%
+  # now say how you want to fit the model and any other options
+  set_engine("glmnet", nlambda = 10)
+translate(mod, engine = "glmnet")
+#> Logistic Regression Model Specification (classification)
+#>
+#> Main Arguments:
+#>   mixture = 1/3
+#>
+#> Engine-Specific Arguments:
+#>   nlambda = 10
+#>
+#> Computational engine: glmnet
+#>
+#> Model fit template:
+#> glmnet::glmnet(x = missing_arg(), y = missing_arg(), weights = missing_arg(),
+#>     alpha = 1/3, nlambda = 10, family = "binomial")

+parsnip is a part of the tidymodels ecosystem, a collection of modeling packages designed with common APIs and a shared philosophy.

diff --git a/docs/reference/show_call.html b/docs/reference/show_call.html
index 38a10af23..42993be17 100644
--- a/docs/reference/show_call.html
+++ b/docs/reference/show_call.html
@@ -72,7 +72,7 @@
parsnip
    part of tidymodels - 0.0.0.9004 + 0.0.0.9005
diff --git a/docs/reference/surv_reg.html b/docs/reference/surv_reg.html
index 96ad45f21..f6760562f 100644
--- a/docs/reference/surv_reg.html
+++ b/docs/reference/surv_reg.html
@@ -38,7 +38,7 @@
dist: The probability distribution of the outcome.
This argument is converted to its specific names at the
time that the model is fit. Other options and argument can be
-set using the ... slot. If left to its default
+set using set_engine. If left to its default
here (NULL), the value is taken from the underlying
model functions. If parameters need to be modified, this function
can be used
@@ -83,7 +83,7 @@
parsnip
    part of tidymodels - 0.0.0.9004 + 0.0.0.9005
    @@ -146,7 +146,7 @@

    General Interface for Parametric Survival Models

  • dist: The probability distribution of the outcome.

  • This argument is converted to its specific names at the time that the model is fit. Other options and argument can be -set using the ... slot. If left to its default +set using set_engine. If left to its default here (NULL), the value is taken from the underlying model functions.

    If parameters need to be modified, this function can be used
    @@ -154,7 +154,7 @@

    General Interface for Parametric Survival Models

    -
    surv_reg(mode = "regression", dist = NULL, ...)
    +    
    surv_reg(mode = "regression", dist = NULL)
     
     # S3 method for surv_reg
     update(object, dist = NULL, fresh = FALSE, ...)
    @@ -171,14 +171,6 @@

    Arguments

    dist

    A character string for the outcome distribution. "weibull" is the default.

    -...

    -Other arguments to pass to the specific engine's
    -model fit function (see the Engine Details section below). This
    -should not include arguments defined by the main parameters to
    -this function. For the update function, the ellipses can
    -contain the primary arguments or any others.

    object
    @@ -189,6 +181,10 @@

    fresh

    A logical for whether the arguments should be modified in-place of or replaced wholesale.

    +...

    Not used for update.

    +

    Details

    @@ -202,11 +198,32 @@

    Details

    Also, for the flexsurv::flexsurvfit engine, the typical strata function cannot be used. To achieve the same effect, the extra parameter roles can be used (as described above).

    +

    For surv_reg, the mode will always be "regression".

    The model can be created using the fit() function using the following engines:

    • R: "flexsurv", "survreg"

    +

    Engine Details

    + + +

    +Engines may have pre-set default arguments when executing the
    +model fit call. For this type of
    +model, the template of the fit calls are:

    +

    flexsurv

    +

    +flexsurv::flexsurvreg(formula = missing_arg(), data = missing_arg(), 
    +    weights = missing_arg())
    +

    +

    survreg

    +

    +survival::survreg(formula = missing_arg(), data = missing_arg(), 
    +    weights = missing_arg(), model = TRUE)
    +

    +

    +Note that model = TRUE is needed to produce quantile
    +predictions when there is a stratification variable and can be
    +overridden in other cases.

    +
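    As a sketch of where model = TRUE matters, a stratified fit could look like the following (the survival package's lung data is used purely for illustration):

    library(survival)  # Surv(), strata(), and the lung data
    library(parsnip)

    strata_fit <-
      surv_reg(dist = "weibull") %>%
      set_engine("survreg") %>%   # the template above already supplies model = TRUE
      fit(Surv(time, status) ~ age + strata(sex), data = lung)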

    References

    Jackson, C. (2016). flexsurv: A Platform for Parametric Survival
    Modeling in R. Journal of Statistical Software.
    @@ -243,6 +260,8 @@

    Contents

  • Details
  • +
  • Engine Details
  • +
  • References
  • See also
  • See also

diff --git a/docs/reference/translate.html b/docs/reference/translate.html
index 5e5ebf014..3c3abc423 100644
--- a/docs/reference/translate.html
+++ b/docs/reference/translate.html
@@ -74,7 +74,7 @@
parsnip
    part of tidymodels - 0.0.0.9004 + 0.0.0.9005
diff --git a/docs/reference/type_sum.model_spec.html b/docs/reference/type_sum.model_spec.html
index 32fd9ff1b..ed0e8e924 100644
--- a/docs/reference/type_sum.model_spec.html
+++ b/docs/reference/type_sum.model_spec.html
@@ -73,7 +73,7 @@
parsnip
    part of tidymodels - 0.0.0.9004 + 0.0.0.9005
diff --git a/docs/reference/varying.html b/docs/reference/varying.html
index a0e72bbcd..c87756cda 100644
--- a/docs/reference/varying.html
+++ b/docs/reference/varying.html
@@ -72,7 +72,7 @@
parsnip
    part of tidymodels - 0.0.0.9004 + 0.0.0.9005
diff --git a/docs/reference/varying_args.html b/docs/reference/varying_args.html
index 5dd84b50e..d9fe20371 100644
--- a/docs/reference/varying_args.html
+++ b/docs/reference/varying_args.html
@@ -73,7 +73,7 @@
parsnip
    part of tidymodels - 0.0.0.9004 + 0.0.0.9005
    @@ -187,35 +187,38 @@

Examples

#> 1 mtry  TRUE    one arg model_spec
#> 2 trees FALSE   one arg model_spec
#> 3 min_n FALSE   one arg model_spec

-rand_forest(others = list(sample.fraction = varying())) %>%
-  varying_args(id = "only others")
#> # A tibble: 4 x 4
-#>   name   varying id          type
-#>   <chr>  <lgl>   <chr>       <chr>
-#> 1 mtry   FALSE   only others model_spec
-#> 2 trees  FALSE   only others model_spec
-#> 3 min_n  FALSE   only others model_spec
-#> 4 others TRUE    only others model_spec
-rand_forest(
-  others = list(
-    strata = expr(Class),
+rand_forest() %>%
+  set_engine("ranger", sample.fraction = varying()) %>%
+  varying_args(id = "only eng_args")
#> # A tibble: 4 x 4
+#>   name            varying id            type
+#>   <chr>           <lgl>   <chr>         <chr>
+#> 1 mtry            FALSE   only eng_args model_spec
+#> 2 trees           FALSE   only eng_args model_spec
+#> 3 min_n           FALSE   only eng_args model_spec
+#> 4 sample.fraction TRUE    only eng_args model_spec
+rand_forest() %>%
+  set_engine(
+    "ranger",
+    strata = expr(Class),
     sampsize = c(varying(), varying())
-  )
-) %>%
-  varying_args(id = "add an expr")
#> # A tibble: 4 x 4
-#>   name   varying id          type
-#>   <chr>  <lgl>   <chr>       <chr>
-#> 1 mtry   FALSE   add an expr model_spec
-#> 2 trees  FALSE   add an expr model_spec
-#> 3 min_n  FALSE   add an expr model_spec
-#> 4 others FALSE   add an expr model_spec
- rand_forest(others = list(classwt = c(class1 = 1, class2 = varying()))) %>%
+  ) %>%
+  varying_args(id = "add an expr")
#> # A tibble: 5 x 4
+#>   name     varying id          type
+#>   <chr>    <lgl>   <chr>       <chr>
+#> 1 mtry     FALSE   add an expr model_spec
+#> 2 trees    FALSE   add an expr model_spec
+#> 3 min_n    FALSE   add an expr model_spec
+#> 4 strata   FALSE   add an expr model_spec
+#> 5 sampsize TRUE    add an expr model_spec
+ rand_forest() %>%
+  set_engine("ranger", classwt = c(class1 = 1, class2 = varying())) %>%
   varying_args(id = "list of values")
#> # A tibble: 4 x 4
-#>   name   varying id             type
-#>   <chr>  <lgl>   <chr>          <chr>
-#> 1 mtry   FALSE   list of values model_spec
-#> 2 trees  FALSE   list of values model_spec
-#> 3 min_n  FALSE   list of values model_spec
-#> 4 others FALSE   list of values model_spec
+#>   name    varying id             type
+#>   <chr>   <lgl>   <chr>          <chr>
+#> 1 mtry    FALSE   list of values model_spec
+#> 2 trees   FALSE   list of values model_spec
+#> 3 min_n   FALSE   list of values model_spec
+#> 4 classwt TRUE    list of values model_spec

diff --git a/docs/reference/xgb_train.html b/docs/reference/xgb_train.html
index 379f5d112..0df0976f6 100644
--- a/docs/reference/xgb_train.html
+++ b/docs/reference/xgb_train.html
@@ -73,7 +73,7 @@
parsnip
    part of tidymodels - 0.0.0.9004 + 0.0.0.9005
    diff --git a/man/boost_tree.Rd b/man/boost_tree.Rd index 1e55def1d..6ad9a2bde 100644 --- a/man/boost_tree.Rd +++ b/man/boost_tree.Rd @@ -7,7 +7,7 @@ \usage{ boost_tree(mode = "unknown", mtry = NULL, trees = NULL, min_n = NULL, tree_depth = NULL, learn_rate = NULL, - loss_reduction = NULL, sample_size = NULL, ...) + loss_reduction = NULL, sample_size = NULL) \method{update}{boost_tree}(object, mtry = NULL, trees = NULL, min_n = NULL, tree_depth = NULL, learn_rate = NULL, @@ -41,16 +41,12 @@ to split further (\code{xgboost} only).} exposed to the fitting routine. For \code{xgboost}, the sampling is done at at each iteration while \code{C5.0} samples once during traning.} -\item{...}{Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. For the \code{update} function, the ellipses can -contain the primary arguments or any others.} - \item{object}{A boosted tree model specification.} \item{fresh}{A logical for whether the arguments should be modified in-place of or replaced wholesale.} + +\item{...}{Not used for \code{update}.} } \value{ An updated model specification. @@ -76,7 +72,7 @@ to split further. } These arguments are converted to their specific names at the time that the model is fit. Other options and argument can be -set using the \code{...} slot. If left to their defaults +set using the \code{set_engine} function. If left to their defaults here (\code{NULL}), the values are taken from the underlying model functions. If parameters need to be modified, \code{update} can be used in lieu of recreating the object from scratch. @@ -92,10 +88,6 @@ following \emph{engines}: \item \pkg{R}: \code{"xgboost"}, \code{"C5.0"} \item \pkg{Spark}: \code{"spark"} } - -Main parameter arguments (and those in \code{...}) can avoid -evaluation until the underlying function is executed by wrapping the -argument in \code{\link[rlang:expr]{rlang::expr()}} (e.g. \code{mtry = expr(floor(sqrt(p)))}). } \note{ For models created using the spark engine, there are @@ -115,9 +107,8 @@ reloaded and reattached to the \code{parsnip} object. Engines may have pre-set default arguments when executing the -model fit call. These can be changed by using the \code{...} -argument to pass in the preferred values. For this type of -model, the template of the fit calls are: +model fit call. For this type of model, the template of the +fit calls are: \pkg{xgboost} classification @@ -150,5 +141,5 @@ update(model, mtry = 1) update(model, mtry = 1, fresh = TRUE) } \seealso{ -\code{\link[=varying]{varying()}}, \code{\link[=fit]{fit()}} +\code{\link[=varying]{varying()}}, \code{\link[=fit]{fit()}}, \code{\link[=set_engine]{set_engine()}} } diff --git a/man/fit.Rd b/man/fit.Rd index 1a003b45d..54b1abcf9 100644 --- a/man/fit.Rd +++ b/man/fit.Rd @@ -6,13 +6,14 @@ \title{Fit a Model Specification to a Dataset} \usage{ \method{fit}{model_spec}(object, formula = NULL, data = NULL, - engine = object$engine, control = fit_control(), ...) + control = fit_control(), ...) \method{fit_xy}{model_spec}(object, x = NULL, y = NULL, - engine = object$engine, control = fit_control(), ...) + control = fit_control(), ...) 
} \arguments{ -\item{object}{An object of class \code{model_spec}} +\item{object}{An object of class \code{model_spec} that has a chosen engine +(via \code{\link[=set_engine]{set_engine()}}).} \item{formula}{An object of class "formula" (or one that can be coerced to that class): a symbolic description of the model @@ -23,17 +24,12 @@ below). A data frame containing all relevant variables (e.g. outcome(s), predictors, case weights, etc). Note: when needed, a \emph{named argument} should be used.} -\item{engine}{A character string for the software that should -be used to fit the model. This is highly dependent on the type -of model (e.g. linear regression, random forest, etc.).} - \item{control}{A named list with elements \code{verbosity} and \code{catch}. See \code{\link[=fit_control]{fit_control()}}.} \item{...}{Not currently used; values passed here will be ignored. Other options required to fit the model should be -passed using the \code{others} argument in the original model -specification.} +passed using \code{set_engine}.} \item{x}{A matrix or data frame of predictors.} @@ -86,22 +82,24 @@ objects used to fit the model. library(dplyr) data("lending_club") -lm_mod <- logistic_reg() +lr_mod <- logistic_reg() -lm_mod <- logistic_reg() +lr_mod <- logistic_reg() using_formula <- - lm_mod \%>\% - fit(Class ~ funded_amnt + int_rate, - data = lending_club, - engine = "glm") + lr_mod \%>\% + set_engine("glm") \%>\% + fit(Class ~ funded_amnt + int_rate, data = lending_club) using_xy <- - lm_mod \%>\% + lr_mod \%>\% + set_engine("glm") \%>\% fit_xy(x = lending_club[, c("funded_amnt", "int_rate")], - y = lending_club$Class, - engine = "glm") + y = lending_club$Class) using_formula using_xy } +\seealso{ +\code{\link[=set_engine]{set_engine()}}, \code{\link[=fit_control]{fit_control()}}, \code{model_spec}, \code{model_fit} +} diff --git a/man/linear_reg.Rd b/man/linear_reg.Rd index e227b9796..f395732d7 100644 --- a/man/linear_reg.Rd +++ b/man/linear_reg.Rd @@ -5,7 +5,7 @@ \alias{update.linear_reg} \title{General Interface for Linear Regression Models} \usage{ -linear_reg(mode = "regression", penalty = NULL, mixture = NULL, ...) +linear_reg(mode = "regression", penalty = NULL, mixture = NULL) \method{update}{linear_reg}(object, penalty = NULL, mixture = NULL, fresh = FALSE, ...) @@ -22,16 +22,12 @@ represents the proportion of regularization that is used for the L2 penalty (i.e. weight decay, or ridge regression) versus L1 (the lasso) (\code{glmnet} and \code{spark} only).} -\item{...}{Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. For the \code{update} function, the ellipses can -contain the primary arguments or any others.} - \item{object}{A linear regression model specification.} \item{fresh}{A logical for whether the arguments should be modified in-place of or replaced wholesale.} + +\item{...}{Not used for \code{update}.} } \description{ \code{linear_reg} is a way to generate a \emph{specification} of a model @@ -46,7 +42,7 @@ the model. Note that this will be ignored for some engines. } These arguments are converted to their specific names at the time that the model is fit. Other options and argument can be -set using the \code{...} slot. If left to their defaults +set using \code{set_engine}. If left to their defaults here (\code{NULL}), the values are taken from the underlying model functions. 
If parameters need to be modified, \code{update} can be used in lieu of recreating the object from scratch. @@ -82,8 +78,7 @@ reloaded and reattached to the \code{parsnip} object. Engines may have pre-set default arguments when executing the -model fit call. These can be changed by using the \code{...} -argument to pass in the preferred values. For this type of +model fit call. For this type of model, the template of the fit calls are: \pkg{lm} @@ -130,5 +125,5 @@ update(model, penalty = 1) update(model, penalty = 1, fresh = TRUE) } \seealso{ -\code{\link[=varying]{varying()}}, \code{\link[=fit]{fit()}} +\code{\link[=varying]{varying()}}, \code{\link[=fit]{fit()}}, \code{\link[=set_engine]{set_engine()}} } diff --git a/man/logistic_reg.Rd b/man/logistic_reg.Rd index d466ef684..0b2918a46 100644 --- a/man/logistic_reg.Rd +++ b/man/logistic_reg.Rd @@ -5,8 +5,7 @@ \alias{update.logistic_reg} \title{General Interface for Logistic Regression Models} \usage{ -logistic_reg(mode = "classification", penalty = NULL, mixture = NULL, - ...) +logistic_reg(mode = "classification", penalty = NULL, mixture = NULL) \method{update}{logistic_reg}(object, penalty = NULL, mixture = NULL, fresh = FALSE, ...) @@ -23,16 +22,12 @@ represents the proportion of regularization that is used for the L2 penalty (i.e. weight decay, or ridge regression) versus L1 (the lasso) (\code{glmnet} and \code{spark} only).} -\item{...}{Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. For the \code{update} function, the ellipses can -contain the primary arguments or any others.} - \item{object}{A logistic regression model specification.} \item{fresh}{A logical for whether the arguments should be modified in-place of or replaced wholesale.} + +\item{...}{Not used for \code{update}.} } \description{ \code{logistic_reg} is a way to generate a \emph{specification} of a model @@ -47,7 +42,7 @@ the model. Note that this will be ignored for some engines. } These arguments are converted to their specific names at the time that the model is fit. Other options and argument can be -set using the \code{...} slot. If left to their defaults +set using \code{set_engine}. If left to their defaults here (\code{NULL}), the values are taken from the underlying model functions. If parameters need to be modified, \code{update} can be used in lieu of recreating the object from scratch. @@ -81,8 +76,7 @@ reloaded and reattached to the \code{parsnip} object. Engines may have pre-set default arguments when executing the -model fit call. These can be changed by using the \code{...} -argument to pass in the preferred values. For this type of +model fit call. For this type of model, the template of the fit calls are: \pkg{glm} diff --git a/man/mars.Rd b/man/mars.Rd index 9f4d25e03..090e0b77f 100644 --- a/man/mars.Rd +++ b/man/mars.Rd @@ -6,7 +6,7 @@ \title{General Interface for MARS} \usage{ mars(mode = "unknown", num_terms = NULL, prod_degree = NULL, - prune_method = NULL, ...) + prune_method = NULL) \method{update}{mars}(object, num_terms = NULL, prod_degree = NULL, prune_method = NULL, fresh = FALSE, ...) @@ -23,16 +23,12 @@ final model, including the intercept.} \item{prune_method}{The pruning method.} -\item{...}{Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. 
For the \code{update} function, the ellipses can -contain the primary arguments or any others.} - \item{object}{A MARS model specification.} \item{fresh}{A logical for whether the arguments should be modified in-place of or replaced wholesale.} + +\item{...}{Not used for \code{update}.} } \description{ \code{mars} is a way to generate a \emph{specification} of a model before @@ -50,16 +46,12 @@ in \code{?earth}. } These arguments are converted to their specific names at the time that the model is fit. Other options and argument can be -set using the \code{...} slot. If left to their defaults +set using \code{set_engine}. If left to their defaults here (\code{NULL}), the values are taken from the underlying model functions. If parameters need to be modified, \code{update} can be used in lieu of recreating the object from scratch. } \details{ -Main parameter arguments (and those in \code{...}) can avoid -evaluation until the underlying function is executed by wrapping the -argument in \code{\link[rlang:expr]{rlang::expr()}}. - The model can be created using the \code{fit()} function using the following \emph{engines}: \itemize{ @@ -70,8 +62,7 @@ following \emph{engines}: Engines may have pre-set default arguments when executing the -model fit call. These can be changed by using the \code{...} -argument to pass in the preferred values. For this type of +model fit call. For this type of model, the template of the fit calls are: \pkg{earth} classification diff --git a/man/mlp.Rd b/man/mlp.Rd index 807fd04ae..11a32e343 100644 --- a/man/mlp.Rd +++ b/man/mlp.Rd @@ -6,7 +6,7 @@ \title{General Interface for Single Layer Neural Network} \usage{ mlp(mode = "unknown", hidden_units = NULL, penalty = NULL, - dropout = NULL, epochs = NULL, activation = NULL, ...) + dropout = NULL, epochs = NULL, activation = NULL) \method{update}{mlp}(object, hidden_units = NULL, penalty = NULL, dropout = NULL, epochs = NULL, activation = NULL, fresh = FALSE, @@ -33,16 +33,12 @@ function between the hidden and output layers is automatically set to either "linear" or "softmax" depending on the type of outcome. Possible values are: "linear", "softmax", "relu", and "elu"} -\item{...}{Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. For the \code{update} function, the ellipses can -contain the primary arguments or any others.} - \item{object}{A random forest model specification.} \item{fresh}{A logical for whether the arguments should be modified in-place of or replaced wholesale.} + +\item{...}{Not used for \code{update}.} } \description{ \code{mlp}, for multilayer perceptron, is a way to generate a \emph{specification} of @@ -67,7 +63,7 @@ in lieu of recreating the object from scratch. \details{ These arguments are converted to their specific names at the time that the model is fit. Other options and argument can be -set using the \code{...} slot. If left to their defaults +set using \code{set_engine}. If left to their defaults here (see above), the values are taken from the underlying model functions. 
One exception is \code{hidden_units} when \code{nnet::nnet} is used; that function's \code{size} argument has no default so a value of 5 units will be @@ -83,10 +79,6 @@ following \emph{engines}: \item \pkg{keras}: \code{"keras"} } -Main parameter arguments (and those in \code{...}) can avoid -evaluation until the underlying function is executed by wrapping the -argument in \code{\link[rlang:expr]{rlang::expr()}} (e.g. \code{hidden_units = expr(num_preds * 2)}). - An error is thrown if both \code{penalty} and \code{dropout} are specified for \code{keras} models. } @@ -94,8 +86,7 @@ An error is thrown if both \code{penalty} and \code{dropout} are specified for Engines may have pre-set default arguments when executing the -model fit call. These can be changed by using the \code{...} -argument to pass in the preferred values. For this type of +model fit call. For this type of model, the template of the fit calls are: \pkg{keras} classification diff --git a/man/model_fit.Rd b/man/model_fit.Rd index 6a80cee54..d65137008 100644 --- a/man/model_fit.Rd +++ b/man/model_fit.Rd @@ -37,10 +37,12 @@ stores model objects after to seeing the data and applying a model. \examples{ # Keep the `x` matrix if the data are not too big. -spec_obj <- linear_reg(x = ifelse(.obs() < 500, TRUE, FALSE)) +spec_obj <- + linear_reg() \%>\% + set_engine("lm", x = ifelse(.obs() < 500, TRUE, FALSE)) spec_obj -fit_obj <- fit(spec_obj, mpg ~ ., data = mtcars, engine = "lm") +fit_obj <- fit(spec_obj, mpg ~ ., data = mtcars) fit_obj nrow(fit_obj$fit$x) diff --git a/man/multinom_reg.Rd b/man/multinom_reg.Rd index db9ba3614..2bcc20b3e 100644 --- a/man/multinom_reg.Rd +++ b/man/multinom_reg.Rd @@ -5,8 +5,7 @@ \alias{update.multinom_reg} \title{General Interface for Multinomial Regression Models} \usage{ -multinom_reg(mode = "classification", penalty = NULL, mixture = NULL, - ...) +multinom_reg(mode = "classification", penalty = NULL, mixture = NULL) \method{update}{multinom_reg}(object, penalty = NULL, mixture = NULL, fresh = FALSE, ...) @@ -23,16 +22,12 @@ represents the proportion of regularization that is used for the L2 penalty (i.e. weight decay, or ridge regression) versus L1 (the lasso) (\code{glmnet} only).} -\item{...}{Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. For the \code{update} function, the ellipses can -contain the primary arguments or any others.} - \item{object}{A multinomial regression model specification.} \item{fresh}{A logical for whether the arguments should be modified in-place of or replaced wholesale.} + +\item{...}{Not used for \code{update}.} } \description{ \code{multinom_reg} is a way to generate a \emph{specification} of a model @@ -47,7 +42,7 @@ the model. Note that this will be ignored for some engines. } These arguments are converted to their specific names at the time that the model is fit. Other options and argument can be -set using the \code{...} slot. If left to their defaults +set using \code{set_engine}. If left to their defaults here (\code{NULL}), the values are taken from the underlying model functions. If parameters need to be modified, \code{update} can be used in lieu of recreating the object from scratch. @@ -80,8 +75,7 @@ reloaded and reattached to the \code{parsnip} object. Engines may have pre-set default arguments when executing the -model fit call. 
These can be changed by using the \code{...} -argument to pass in the preferred values. For this type of +model fit call. For this type of model, the template of the fit calls are: \pkg{glmnet} diff --git a/man/nearest_neighbor.Rd b/man/nearest_neighbor.Rd index 5851088c9..2caf7b251 100644 --- a/man/nearest_neighbor.Rd +++ b/man/nearest_neighbor.Rd @@ -5,7 +5,7 @@ \title{General Interface for K-Nearest Neighbor Models} \usage{ nearest_neighbor(mode = "unknown", neighbors = NULL, - weight_func = NULL, dist_power = NULL, ...) + weight_func = NULL, dist_power = NULL) } \arguments{ \item{mode}{A single character string for the type of model. @@ -22,12 +22,6 @@ to weight distances between samples. Valid choices are: \code{"rectangular"}, \item{dist_power}{A single number for the parameter used in calculating Minkowski distance.} - -\item{...}{Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. For the \code{update} function, the ellipses can -contain the primary arguments or any others.} } \description{ \code{nearest_neighbor()} is a way to generate a \emph{specification} of a model @@ -45,7 +39,7 @@ and the Euclidean distance with \code{dist_power = 2}. } These arguments are converted to their specific names at the time that the model is fit. Other options and argument can be -set using the \code{...} slot. If left to their defaults +set using \code{set_engine}. If left to their defaults here (\code{NULL}), the values are taken from the underlying model functions. If parameters need to be modified, \code{update()} can be used in lieu of recreating the object from scratch. @@ -68,8 +62,7 @@ on new data. This also means that a single value of that function's Engines may have pre-set default arguments when executing the -model fit call. These can be changed by using the \code{...} -argument to pass in the preferred values. For this type of +model fit call. For this type of model, the template of the fit calls are: \pkg{kknn} (classification or regression) @@ -78,7 +71,7 @@ model, the template of the fit calls are: } \examples{ -nearest_neighbor() +nearest_neighbor(neighbors = 11) } \seealso{ diff --git a/man/predict.model_fit.Rd b/man/predict.model_fit.Rd index eb4c41f90..76a9186b1 100644 --- a/man/predict.model_fit.Rd +++ b/man/predict.model_fit.Rd @@ -52,7 +52,7 @@ the columns will be named \code{.pred_lower_classlevel} and so on. Quantile predictions return a tibble with a column \code{.pred}, which is a list-column. Each list element contains a tibble with columns -\code{.pred} and \code{.quantile} (and perhaps others). +\code{.pred} and \code{.quantile} (and perhaps other columns). Using \code{type = "raw"} with \code{predict.model_fit} (or using \code{predict_raw}) will return the unadulterated results of the @@ -85,7 +85,8 @@ library(dplyr) lm_model <- linear_reg() \%>\% - fit(mpg ~ ., data = mtcars \%>\% slice(11:32), engine = "lm") + set_engine("lm") \%>\% + fit(mpg ~ ., data = mtcars \%>\% slice(11:32)) pred_cars <- mtcars \%>\% diff --git a/man/rand_forest.Rd b/man/rand_forest.Rd index 7f5e2e604..c7f5ab20e 100644 --- a/man/rand_forest.Rd +++ b/man/rand_forest.Rd @@ -6,7 +6,7 @@ \title{General Interface for Random Forest Models} \usage{ rand_forest(mode = "unknown", mtry = NULL, trees = NULL, - min_n = NULL, ...) + min_n = NULL) \method{update}{rand_forest}(object, mtry = NULL, trees = NULL, min_n = NULL, fresh = FALSE, ...) 
@@ -25,16 +25,12 @@ the ensemble.} \item{min_n}{An integer for the minimum number of data points in a node that are required for the node to be split further.} -\item{...}{Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. For the \code{update} function, the ellipses can -contain the primary arguments or any others.} - \item{object}{A random forest model specification.} \item{fresh}{A logical for whether the arguments should be modified in-place of or replaced wholesale.} + +\item{...}{Not used for \code{update}.} } \description{ \code{rand_forest} is a way to generate a \emph{specification} of a model @@ -50,7 +46,7 @@ that are required for the node to be split further. } These arguments are converted to their specific names at the time that the model is fit. Other options and argument can be -set using the \code{...} slot. If left to their defaults +set using \code{set_engine}. If left to their defaults here (\code{NULL}), the values are taken from the underlying model functions. If parameters need to be modified, \code{update} can be used in lieu of recreating the object from scratch. @@ -62,10 +58,6 @@ following \emph{engines}: \item \pkg{R}: \code{"ranger"} or \code{"randomForest"} \item \pkg{Spark}: \code{"spark"} } - -Main parameter arguments (and those in \code{...}) can avoid -evaluation until the underlying function is executed by wrapping the -argument in \code{\link[rlang:expr]{rlang::expr()}} (e.g. \code{mtry = expr(floor(sqrt(p)))}). } \note{ For models created using the spark engine, there are @@ -85,8 +77,7 @@ reloaded and reattached to the \code{parsnip} object. Engines may have pre-set default arguments when executing the -model fit call. These can be changed by using the \code{...} -argument to pass in the preferred values. For this type of +model fit call. For this type of model, the template of the fit calls are:: \pkg{ranger} classification diff --git a/man/set_engine.Rd b/man/set_engine.Rd new file mode 100644 index 000000000..754077732 --- /dev/null +++ b/man/set_engine.Rd @@ -0,0 +1,33 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/engines.R +\name{set_engine} +\alias{set_engine} +\title{Declare a computational engine and specific arguments} +\usage{ +set_engine(object, engine, ...) +} +\arguments{ +\item{object}{A model specification.} + +\item{engine}{A character string for the software that should +be used to fit the model. This is highly dependent on the type +of model (e.g. linear regression, random forest, etc.).} + +\item{...}{Any optional arguments associated with the chosen computational +engine. These are captured as quosures and can be \code{varying()}.} +} +\value{ +An updated model specification. +} +\description{ +\code{set_engine} is used to specify which package or system will be used +to fit the model, along with any arguments specific to that software. 
+}
+\examples{
+# First, set general arguments using the standardized names
+mod <-
+  logistic_reg(mixture = 1/3) \%>\%
+  # now say how you want to fit the model and any other options
+  set_engine("glmnet", nlambda = 10)
+translate(mod, engine = "glmnet")
+}
diff --git a/man/surv_reg.Rd b/man/surv_reg.Rd
index 5ef311de0..095d771a8 100644
--- a/man/surv_reg.Rd
+++ b/man/surv_reg.Rd
@@ -5,7 +5,7 @@
 \alias{update.surv_reg}
 \title{General Interface for Parametric Survival Models}
 \usage{
-surv_reg(mode = "regression", dist = NULL, ...)
+surv_reg(mode = "regression", dist = NULL)
 \method{update}{surv_reg}(object, dist = NULL, fresh = FALSE, ...)
 }
@@ -16,16 +16,12 @@
 The only possible value for this model is "regression".}
 \item{dist}{A character string for the outcome distribution.
 "weibull" is the default.}
-\item{...}{Other arguments to pass to the specific engine's
-model fit function (see the Engine Details section below). This
-should not include arguments defined by the main parameters to
-this function. For the \code{update} function, the ellipses can
-contain the primary arguments or any others.}
-
 \item{object}{A survival regression model specification.}
 \item{fresh}{A logical for whether the arguments should be
 modified in-place of or replaced wholesale.}
+
+\item{...}{Not used for \code{update}.}
 }
 \description{
 \code{surv_reg} is a way to generate a \emph{specification} of a model
@@ -37,7 +33,7 @@
 model is:
 }
 This argument is converted to its specific names at the
 time that the model is fit. Other options and argument can be
-set using the \code{...} slot. If left to its default
+set using \code{set_engine}. If left to its default
 here (\code{NULL}), the value is taken from the underlying
 model functions.
@@ -69,8 +65,7 @@
 following \emph{engines}:
 Engines may have pre-set default arguments when executing the
-model fit call. These can be changed by using the \code{...}
-argument to pass in the preferred values. For this type of
+model fit call.
For this type of model, the template of the fit calls are: \pkg{flexsurv} diff --git a/man/varying_args.Rd b/man/varying_args.Rd index af26f8886..ff3c73160 100644 --- a/man/varying_args.Rd +++ b/man/varying_args.Rd @@ -39,18 +39,20 @@ rand_forest() \%>\% varying_args(id = "plain") rand_forest(mtry = varying()) \%>\% varying_args(id = "one arg") -rand_forest(others = list(sample.fraction = varying())) \%>\% - varying_args(id = "only others") - -rand_forest( - others = list( - strata = expr(Class), +rand_forest() \%>\% + set_engine("ranger", sample.fraction = varying()) \%>\% + varying_args(id = "only eng_args") + +rand_forest() \%>\% + set_engine( + "ranger", + strata = expr(Class), sampsize = c(varying(), varying()) - ) -) \%>\% - varying_args(id = "add an expr") + ) \%>\% + varying_args(id = "add an expr") - rand_forest(others = list(classwt = c(class1 = 1, class2 = varying()))) \%>\% + rand_forest() \%>\% + set_engine("ranger", classwt = c(class1 = 1, class2 = varying())) \%>\% varying_args(id = "list of values") } diff --git a/tests/testthat/test_args_and_modes.R b/tests/testthat/test_args_and_modes.R index b3c9f46d7..2a0b86e63 100644 --- a/tests/testthat/test_args_and_modes.R +++ b/tests/testthat/test_args_and_modes.R @@ -7,7 +7,7 @@ context("changing arguments and engine") test_that('pipe arguments', { mod_1 <- rand_forest() %>% - set_args(mtry = 1, something = "blah") + set_args(mtry = 1) expect_equal( quo_get_expr(mod_1$args$mtry), 1 @@ -16,18 +16,9 @@ test_that('pipe arguments', { quo_get_env(mod_1$args$mtry), empty_env() ) - expect_equal( - quo_get_expr(mod_1$others$something), - "blah" - ) - expect_equal( - quo_get_env(mod_1$others$something), - empty_env() - ) - x <- 1:10 - mod_2 <- rand_forest(mtry = 2, var = x) %>% - set_args(mtry = 1, something = "blah") + mod_2 <- rand_forest(mtry = 2) %>% + set_args(mtry = 1) var_env <- rlang::current_env() @@ -39,18 +30,6 @@ test_that('pipe arguments', { quo_get_env(mod_2$args$mtry), empty_env() ) - expect_equal( - quo_get_expr(mod_2$others$something), - "blah" - ) - expect_equal( - quo_get_env(mod_2$others$something), - empty_env() - ) - expect_equal( - quo_get_env(mod_2$others$var), - var_env - ) expect_error(rand_forest() %>% set_args()) diff --git a/tests/testthat/test_boost_tree.R b/tests/testthat/test_boost_tree.R index 4c3a0bf91..a159387fd 100644 --- a/tests/testthat/test_boost_tree.R +++ b/tests/testthat/test_boost_tree.R @@ -11,8 +11,8 @@ source("helpers.R") test_that('primary arguments', { basic <- boost_tree(mode = "classification") - basic_xgboost <- translate(basic, engine = "xgboost") - basic_C5.0 <- translate(basic, engine = "C5.0") + basic_xgboost <- translate(basic %>% set_engine("xgboost")) + basic_C5.0 <- translate(basic %>% set_engine("C5.0")) expect_equal(basic_xgboost$method$fit$args, list( x = expr(missing_arg()), @@ -30,8 +30,8 @@ test_that('primary arguments', { ) trees <- boost_tree(trees = 15, mode = "classification") - trees_C5.0 <- translate(trees, engine = "C5.0") - trees_xgboost <- translate(trees, engine = "xgboost") + trees_C5.0 <- translate(trees %>% set_engine("C5.0")) + trees_xgboost <- translate(trees %>% set_engine("xgboost")) expect_equal(trees_C5.0$method$fit$args, list( x = expr(missing_arg()), @@ -51,8 +51,8 @@ test_that('primary arguments', { ) split_num <- boost_tree(min_n = 15, mode = "classification") - split_num_C5.0 <- translate(split_num, engine = "C5.0") - split_num_xgboost <- translate(split_num, engine = "xgboost") + split_num_C5.0 <- translate(split_num %>% set_engine("C5.0")) + 
split_num_xgboost <- translate(split_num %>% set_engine("xgboost")) expect_equal(split_num_C5.0$method$fit$args, list( x = expr(missing_arg()), @@ -74,65 +74,60 @@ test_that('primary arguments', { }) test_that('engine arguments', { - xgboost_print <- boost_tree(mode = "regression", print_every_n = 10L) - expect_equal(translate(xgboost_print, engine = "xgboost")$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - print_every_n = new_empty_quosure(10L), - nthread = 1, - verbose = 0 - ) + xgboost_print <- boost_tree(mode = "regression") + expect_equal( + translate( + xgboost_print %>% + set_engine("xgboost", print_every_n = 10L))$method$fit$args, + list( + x = expr(missing_arg()), + y = expr(missing_arg()), + print_every_n = new_empty_quosure(10L), + nthread = 1, + verbose = 0 + ) ) - C5.0_rules <- boost_tree(mode = "classification", rules = TRUE) - expect_equal(translate(C5.0_rules, engine = "C5.0")$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - rules = new_empty_quosure(TRUE) - ) + C5.0_rules <- boost_tree(mode = "classification") + expect_equal( + translate( + C5.0_rules %>% set_engine("C5.0", rules = TRUE))$method$fit$args, + list( + x = expr(missing_arg()), + y = expr(missing_arg()), + weights = expr(missing_arg()), + rules = new_empty_quosure(TRUE) + ) ) }) test_that('updating', { - expr1 <- boost_tree( verbose = 0) - expr1_exp <- boost_tree(trees = 10, verbose = 0) + expr1 <- boost_tree() %>% set_engine("xgboost", verbose = 0) + expr1_exp <- boost_tree(trees = 10) %>% set_engine("xgboost", verbose = 0) - expr2 <- boost_tree(trees = varying()) - expr2_exp <- boost_tree(trees = varying(), verbose = 0) + expr2 <- boost_tree(trees = varying()) %>% set_engine("xgboost") + expr2_exp <- boost_tree(trees = varying()) %>% set_engine("xgboost", verbose = 0) expr3 <- boost_tree(trees = 1, sample_size = varying()) expr3_exp <- boost_tree(trees = 1) - expr4 <- boost_tree(trees = 10, rules = TRUE) - expr4_exp <- boost_tree(trees = 10, rules = TRUE, earlyStopping = TRUE) - - expr5 <- boost_tree(trees = 1, rules = TRUE, earlyStopping = TRUE) - expect_equal(update(expr1, trees = 10), expr1_exp) - expect_equal(update(expr2, verbose = 0), expr2_exp) expect_equal(update(expr3, trees = 1, fresh = TRUE), expr3_exp) - expect_equal(update(expr4, rules = TRUE, earlyStopping = TRUE), expr4_exp) - expect_equal(update(expr5, rules = TRUE), expr5) - }) test_that('bad input', { expect_error(boost_tree(mode = "bogus")) expect_error({ bt <- boost_tree(trees = -1) - fit(bt, Species ~ ., iris, "xgboost") + fit(bt, Species ~ ., iris) %>% set_engine("xgboost") }) expect_error({ bt <- boost_tree(min_n = -10) - fit(bt, Species ~ ., iris, "xgboost") + fit(bt, Species ~ ., iris) %>% set_engine("xgboost") }) - expect_error(translate(boost_tree(), engine = "wat?")) - expect_warning(translate(boost_tree(), engine = NULL)) + expect_message(translate(boost_tree(), engine = NULL)) expect_error(translate(boost_tree(formula = y ~ x))) }) diff --git a/tests/testthat/test_boost_tree_C50.R b/tests/testthat/test_boost_tree_C50.R index f758d80a8..5c3e6096e 100644 --- a/tests/testthat/test_boost_tree_C50.R +++ b/tests/testthat/test_boost_tree_C50.R @@ -9,7 +9,9 @@ context("boosted tree execution with C5.0") data("lending_club") lending_club <- head(lending_club, 200) num_pred <- c("funded_amnt", "annual_inc", "num_il_tl") -lc_basic <- boost_tree(mode = "classification") +lc_basic <- + boost_tree(mode = "classification") %>% + set_engine("C5.0") 
ctrl <- fit_control(verbosity = 1, catch = FALSE) caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) @@ -26,8 +28,7 @@ test_that('C5.0 execution', { lc_basic, Class ~ log(funded_amnt) + int_rate, data = lending_club, - control = ctrl, - engine = "C5.0" + control = ctrl ), regexp = NA ) @@ -36,7 +37,6 @@ test_that('C5.0 execution', { lc_basic, x = lending_club[, num_pred], y = lending_club$Class, - engine = "C5.0", control = ctrl ), regexp = NA @@ -55,14 +55,12 @@ test_that('C5.0 execution', { lc_basic, funded_amnt ~ term, data = lending_club, - engine = "C5.0", control = caught_ctrl ) expect_true(inherits(C5.0_form_catch$fit, "try-error")) C5.0_xy_catch <- fit_xy( lc_basic, - engine = "C5.0", control = caught_ctrl, x = lending_club[, num_pred], y = lending_club$total_bal_il @@ -78,7 +76,6 @@ test_that('C5.0 prediction', { lc_basic, x = lending_club[, num_pred], y = lending_club$Class, - engine = "C5.0", control = ctrl ) @@ -95,7 +92,6 @@ test_that('C5.0 probabilities', { lc_basic, x = lending_club[, num_pred], y = lending_club$Class, - engine = "C5.0", control = ctrl ) @@ -116,11 +112,9 @@ test_that('submodel prediction', { vars <- c("female", "tenure", "total_charges", "phone_service", "monthly_charges") class_fit <- - boost_tree(trees = 20, mode = "classification", - others = list(control = C5.0Control(earlyStopping = FALSE))) %>% - fit(churn ~ ., - data = wa_churn[-(1:4), c("churn", vars)], - engine = "C5.0") + boost_tree(trees = 20, mode = "classification") %>% + set_engine("C5.0", control = C5.0Control(earlyStopping = FALSE)) %>% + fit(churn ~ ., data = wa_churn[-(1:4), c("churn", vars)]) pred_class <- predict(class_fit$fit, wa_churn[1:4, vars], trials = 4, type = "prob") diff --git a/tests/testthat/test_boost_tree_spark.R b/tests/testthat/test_boost_tree_spark.R index dc7c68818..6c517e3b8 100644 --- a/tests/testthat/test_boost_tree_spark.R +++ b/tests/testthat/test_boost_tree_spark.R @@ -30,12 +30,8 @@ test_that('spark execution', { expect_error( spark_reg_fit <- fit( - boost_tree( - trees = 5, - mode = "regression", - seed = 12 - ), - engine = "spark", + boost_tree(trees = 5, mode = "regression") %>% + set_engine("spark", seed = 12), control = ctrl, Sepal_Length ~ ., data = iris_bt_tr @@ -47,12 +43,8 @@ test_that('spark execution', { expect_error( spark_reg_fit_dup <- fit( - boost_tree( - trees = 5, - mode = "regression", - seed = 12 - ), - engine = "spark", + boost_tree(trees = 5, mode = "regression") %>% + set_engine("spark", seed = 12), control = ctrl, Sepal_Length ~ ., data = iris_bt_tr @@ -104,12 +96,8 @@ test_that('spark execution', { expect_error( spark_class_fit <- fit( - boost_tree( - trees = 5, - mode = "classification", - seed = 12 - ), - engine = "spark", + boost_tree(trees = 5, mode = "classification") %>% + set_engine("spark", seed = 12), control = ctrl, churn ~ ., data = churn_bt_tr @@ -121,12 +109,8 @@ test_that('spark execution', { expect_error( spark_class_fit_dup <- fit( - boost_tree( - trees = 5, - mode = "classification", - seed = 12 - ), - engine = "spark", + boost_tree(trees = 5, mode = "classification") %>% + set_engine("spark", seed = 12), control = ctrl, churn ~ ., data = churn_bt_tr @@ -186,7 +170,7 @@ test_that('spark execution', { regexp = NA ) - expect_equal(colnames(spark_class_prob), c("pred_0", "pred_1")) + expect_equal(colnames(spark_class_prob), c("pred_No", "pred_Yes")) expect_equivalent( as.data.frame(spark_class_prob), diff --git a/tests/testthat/test_boost_tree_xgboost.R b/tests/testthat/test_boost_tree_xgboost.R index 2c8898df1..f8a7f7aa1 
100644 --- a/tests/testthat/test_boost_tree_xgboost.R +++ b/tests/testthat/test_boost_tree_xgboost.R @@ -7,7 +7,9 @@ context("boosted tree execution with xgboost") num_pred <- names(iris)[1:4] -iris_xgboost <- boost_tree(trees = 2) +iris_xgboost <- + boost_tree(trees = 2) %>% + set_engine("xgboost") ctrl <- fit_control(verbosity = 1, catch = FALSE) caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) @@ -24,7 +26,6 @@ test_that('xgboost execution, classification', { iris_xgboost, Species ~ Sepal.Width + Sepal.Length, data = iris, - engine = "xgboost", control = ctrl ), regexp = NA @@ -34,7 +35,6 @@ test_that('xgboost execution, classification', { iris_xgboost, x = iris[, num_pred], y = iris$Species, - engine = "xgboost", control = ctrl ), regexp = NA @@ -45,7 +45,6 @@ test_that('xgboost execution, classification', { iris_xgboost, Species ~ novar, data = iris, - engine = "xgboost", control = ctrl ) ) @@ -61,7 +60,6 @@ test_that('xgboost classification prediction', { iris_xgboost, x = iris[, num_pred], y = iris$Species, - engine = "xgboost", control = ctrl ) @@ -74,7 +72,6 @@ test_that('xgboost classification prediction', { iris_xgboost, Species ~ ., data = iris, - engine = "xgboost", control = ctrl ) @@ -89,12 +86,17 @@ test_that('xgboost classification prediction', { num_pred <- names(mtcars)[3:6] -car_basic <- boost_tree(mode = "regression") +car_basic <- + boost_tree(mode = "regression") %>% + set_engine("xgboost") -bad_xgboost_reg <- boost_tree(mode = "regression", - others = list(min.node.size = -10)) -bad_rf_reg <- boost_tree(mode = "regression", - others = list(sampsize = -10)) +bad_xgboost_reg <- + boost_tree(mode = "regression") %>% + set_engine("xgboost", min.node.size = -10) + +bad_rf_reg <- + boost_tree(mode = "regression") %>% + set_engine("xgboost", sampsize = -10) ctrl <- list(verbosity = 1, catch = FALSE) caught_ctrl <- list(verbosity = 1, catch = TRUE) @@ -109,7 +111,6 @@ test_that('xgboost execution, regression', { car_basic, mpg ~ ., data = mtcars, - engine = "xgboost", control = ctrl ), regexp = NA @@ -120,7 +121,6 @@ test_that('xgboost execution, regression', { car_basic, x = mtcars[, num_pred], y = mtcars$mpg, - engine = "xgboost", control = ctrl ), regexp = NA @@ -137,7 +137,6 @@ test_that('xgboost regression prediction', { car_basic, x = mtcars[, -1], y = mtcars$mpg, - engine = "xgboost", control = ctrl ) @@ -148,7 +147,6 @@ test_that('xgboost regression prediction', { car_basic, mpg ~ ., data = mtcars, - engine = "xgboost", control = ctrl ) @@ -164,11 +162,9 @@ test_that('submodel prediction', { library(xgboost) reg_fit <- - boost_tree( - trees = 20, - mode = "regression" - ) %>% - fit(mpg ~ ., data = mtcars[-(1:4), ], engine = "xgboost") + boost_tree(trees = 20, mode = "regression") %>% + set_engine("xgboost") %>% + fit(mpg ~ ., data = mtcars[-(1:4), ]) x <- xgboost::xgb.DMatrix(as.matrix(mtcars[1:4, -1])) @@ -182,9 +178,8 @@ test_that('submodel prediction', { vars <- c("female", "tenure", "total_charges", "phone_service", "monthly_charges") class_fit <- boost_tree(trees = 20, mode = "classification") %>% - fit(churn ~ ., - data = wa_churn[-(1:4), c("churn", vars)], - engine = "xgboost") + set_engine("xgboost") %>% + fit(churn ~ ., data = wa_churn[-(1:4), c("churn", vars)]) x <- xgboost::xgb.DMatrix(as.matrix(wa_churn[1:4, vars])) diff --git a/tests/testthat/test_descriptors.R b/tests/testthat/test_descriptors.R index 8210f6d3d..36d1dc009 100644 --- a/tests/testthat/test_descriptors.R +++ b/tests/testthat/test_descriptors.R @@ -54,28 +54,24 @@ 
test_that("requires_descrs", { expect_true(parsnip:::requires_descrs(rand_forest(mtry = fn()))) expect_true(parsnip:::requires_descrs(rand_forest(mtry = fn2()))) - # descriptors in `others` - expect_false(parsnip:::requires_descrs(rand_forest(arrrg = 3))) - expect_false(parsnip:::requires_descrs(rand_forest(arrrg = varying()))) - expect_true(parsnip:::requires_descrs(rand_forest(arrrg = .obs()))) - expect_false(parsnip:::requires_descrs(rand_forest(arrrg = expr(3)))) - expect_true(parsnip:::requires_descrs(rand_forest(arrrg = fn()))) - expect_true(parsnip:::requires_descrs(rand_forest(arrrg = fn2()))) + # descriptors in `eng_args` + expect_false(parsnip:::requires_descrs(rand_forest() %>% set_engine("ranger", arrrg = 3))) + expect_false(parsnip:::requires_descrs(rand_forest() %>% set_engine("ranger", arrrg = varying()))) + expect_true(parsnip:::requires_descrs(rand_forest() %>% set_engine("ranger", arrrg = .obs()))) + expect_false(parsnip:::requires_descrs(rand_forest() %>% set_engine("ranger", arrrg = expr(3)))) + expect_true(parsnip:::requires_descrs(rand_forest() %>% set_engine("ranger", arrrg = fn()))) + expect_true(parsnip:::requires_descrs(rand_forest() %>% set_engine("ranger", arrrg = fn2()))) # mixed expect_true( parsnip:::requires_descrs( - rand_forest( - mtry = 3, - arrrg = fn2()) + rand_forest(mtry = 3) %>% set_engine("ranger", arrrg = fn2()) ) ) expect_true( parsnip:::requires_descrs( - rand_forest( - mtry = .cols(), - arrrg = 3) + rand_forest(mtry = .cols()) %>% set_engine("ranger", arrrg = 3) ) ) }) diff --git a/tests/testthat/test_linear_reg.R b/tests/testthat/test_linear_reg.R index 01f79e540..9df468114 100644 --- a/tests/testthat/test_linear_reg.R +++ b/tests/testthat/test_linear_reg.R @@ -11,10 +11,10 @@ source("helpers.R") test_that('primary arguments', { basic <- linear_reg() - basic_lm <- translate(basic, engine = "lm") - basic_glmnet <- translate(basic, engine = "glmnet") - basic_stan <- translate(basic, engine = "stan") - basic_spark <- translate(basic, engine = "spark") + basic_lm <- translate(basic %>% set_engine("lm")) + basic_glmnet <- translate(basic %>% set_engine("glmnet")) + basic_stan <- translate(basic %>% set_engine("stan")) + basic_spark <- translate(basic %>% set_engine("spark")) expect_equal(basic_lm$method$fit$args, list( formula = expr(missing_arg()), @@ -47,8 +47,8 @@ test_that('primary arguments', { ) mixture <- linear_reg(mixture = 0.128) - mixture_glmnet <- translate(mixture, engine = "glmnet") - mixture_spark <- translate(mixture, engine = "spark") + mixture_glmnet <- translate(mixture %>% set_engine("glmnet")) + mixture_spark <- translate(mixture %>% set_engine("spark")) expect_equal(mixture_glmnet$method$fit$args, list( x = expr(missing_arg()), @@ -68,8 +68,8 @@ test_that('primary arguments', { ) penalty <- linear_reg(penalty = 1) - penalty_glmnet <- translate(penalty, engine = "glmnet") - penalty_spark <- translate(penalty, engine = "spark") + penalty_glmnet <- translate(penalty %>% set_engine("glmnet")) + penalty_spark <- translate(penalty %>% set_engine("spark")) expect_equal(penalty_glmnet$method$fit$args, list( x = expr(missing_arg()), @@ -89,8 +89,8 @@ test_that('primary arguments', { ) mixture_v <- linear_reg(mixture = varying()) - mixture_v_glmnet <- translate(mixture_v, engine = "glmnet") - mixture_v_spark <- translate(mixture_v, engine = "spark") + mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")) + mixture_v_spark <- translate(mixture_v %>% set_engine("spark")) expect_equal(mixture_v_glmnet$method$fit$args, list( x = 
expr(missing_arg()), @@ -112,8 +112,8 @@ test_that('primary arguments', { }) test_that('engine arguments', { - lm_fam <- linear_reg(model = FALSE) - expect_equal(translate(lm_fam, engine = "lm")$method$fit$args, + lm_fam <- linear_reg() %>% set_engine("lm", model = FALSE) + expect_equal(translate(lm_fam)$method$fit$args, list( formula = expr(missing_arg()), data = expr(missing_arg()), @@ -122,8 +122,8 @@ test_that('engine arguments', { ) ) - glmnet_nlam <- linear_reg(nlambda = 10) - expect_equal(translate(glmnet_nlam, engine = "glmnet")$method$fit$args, + glmnet_nlam <- linear_reg() %>% set_engine("glmnet", nlambda = 10) + expect_equal(translate(glmnet_nlam)$method$fit$args, list( x = expr(missing_arg()), y = expr(missing_arg()), @@ -133,8 +133,8 @@ test_that('engine arguments', { ) ) - stan_samp <- linear_reg(chains = 1, iter = 5) - expect_equal(translate(stan_samp, engine = "stan")$method$fit$args, + stan_samp <- linear_reg() %>% set_engine("stan", chains = 1, iter = 5) + expect_equal(translate(stan_samp)$method$fit$args, list( formula = expr(missing_arg()), data = expr(missing_arg()), @@ -145,8 +145,8 @@ test_that('engine arguments', { ) ) - spark_iter <- linear_reg(max_iter = 20) - expect_equal(translate(spark_iter, engine = "spark")$method$fit$args, + spark_iter <- linear_reg() %>% set_engine("spark", max_iter = 20) + expect_equal(translate(spark_iter)$method$fit$args, list( x = expr(missing_arg()), formula = expr(missing_arg()), @@ -159,26 +159,23 @@ test_that('engine arguments', { test_that('updating', { - expr1 <- linear_reg( model = FALSE) - expr1_exp <- linear_reg(mixture = 0, model = FALSE) + expr1 <- linear_reg() %>% set_engine("lm", model = FALSE) + expr1_exp <- linear_reg(mixture = 0) %>% set_engine("lm", model = FALSE) - expr2 <- linear_reg(mixture = varying()) - expr2_exp <- linear_reg(mixture = varying(), nlambda = 10) + expr2 <- linear_reg(mixture = varying()) %>% set_engine("glmnet") + expr2_exp <- linear_reg(mixture = varying()) %>% set_engine("glmnet", nlambda = 10) - expr3 <- linear_reg(mixture = 0, penalty = varying()) - expr3_exp <- linear_reg(mixture = 1) + expr3 <- linear_reg(mixture = 0, penalty = varying()) %>% set_engine("glmnet") + expr3_exp <- linear_reg(mixture = 1) %>% set_engine("glmnet") - expr4 <- linear_reg(mixture = 0, nlambda = 10) - expr4_exp <- linear_reg(mixture = 0, nlambda = 10, pmax = 2) + expr4 <- linear_reg(mixture = 0) %>% set_engine("glmnet", nlambda = 10) + expr4_exp <- linear_reg(mixture = 0) %>% set_engine("glmnet", nlambda = 10, pmax = 2) - expr5 <- linear_reg(mixture = 1, nlambda = 10) - expr5_exp <- linear_reg(mixture = 1, nlambda = 10, pmax = 2) + expr5 <- linear_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10) + expr5_exp <- linear_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10, pmax = 2) expect_equal(update(expr1, mixture = 0), expr1_exp) - expect_equal(update(expr2, nlambda = 10), expr2_exp) expect_equal(update(expr3, mixture = 1, fresh = TRUE), expr3_exp) - expect_equal(update(expr4, pmax = 2), expr4_exp) - expect_equal(update(expr5, nlambda = 10, pmax = 2), expr5_exp) }) @@ -187,17 +184,17 @@ test_that('bad input', { # expect_error(linear_reg(penalty = -1)) # expect_error(linear_reg(mixture = -1)) expect_error(translate(linear_reg(), engine = "wat?")) - expect_warning(translate(linear_reg(), engine = NULL)) + expect_error(translate(linear_reg(), engine = NULL)) expect_error(translate(linear_reg(formula = y ~ x))) - expect_warning(translate(linear_reg(x = iris[,1:3], y = iris$Species), engine = "glmnet")) - 
expect_warning(translate(linear_reg(formula = y ~ x), engine = "lm")) + expect_error(translate(linear_reg(x = iris[,1:3], y = iris$Species) %>% set_engine("glmnet"))) + expect_error(translate(linear_reg(formula = y ~ x) %>% set_engine("lm"))) }) # ------------------------------------------------------------------------------ num_pred <- c("Sepal.Width", "Petal.Width", "Petal.Length") iris_bad_form <- as.formula(Species ~ term) -iris_basic <- linear_reg() +iris_basic <- linear_reg() %>% set_engine("lm") ctrl <- fit_control(verbosity = 1, catch = FALSE) caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) @@ -212,8 +209,7 @@ test_that('lm execution', { iris_basic, Sepal.Length ~ log(Sepal.Width) + Species, data = iris, - control = ctrl, - engine = "lm" + control = ctrl ), regexp = NA ) @@ -222,7 +218,6 @@ test_that('lm execution', { iris_basic, x = iris[, num_pred], y = iris$Sepal.Length, - engine = "lm", control = ctrl ), regexp = NA @@ -233,7 +228,6 @@ test_that('lm execution', { iris_basic, iris_bad_form, data = iris, - engine = "lm", control = ctrl ) ) @@ -242,7 +236,6 @@ test_that('lm execution', { iris_basic, iris_bad_form, data = iris, - engine = "lm", control = caught_ctrl ) expect_true(inherits(lm_form_catch$fit, "try-error")) @@ -254,8 +247,7 @@ test_that('lm execution', { iris_basic, cbind(Sepal.Width, Petal.Width) ~ ., data = iris, - control = ctrl, - engine = "lm" + control = ctrl ), regexp = NA ) @@ -274,7 +266,6 @@ test_that('lm prediction', { iris_basic, x = iris[, num_pred], y = iris$Sepal.Length, - engine = "lm", control = ctrl ) @@ -284,7 +275,6 @@ test_that('lm prediction', { iris_basic, Sepal.Length ~ log(Sepal.Width) + Species, data = iris, - engine = "lm", control = ctrl ) expect_equal(inl_pred, predict_num(res_form, iris[1:5, ])) @@ -293,8 +283,7 @@ test_that('lm prediction', { iris_basic, cbind(Sepal.Width, Petal.Width) ~ ., data = iris, - control = ctrl, - engine = "lm" + control = ctrl ) expect_equal(mv_pred, predict_num(res_mv, iris[1:5,])) }) @@ -308,10 +297,9 @@ test_that('lm intervals', { level = 0.93, interval = "prediction") res_xy <- fit_xy( - linear_reg(), + linear_reg() %>% set_engine("lm"), x = iris[, num_pred], y = iris$Sepal.Length, - engine = "lm", control = ctrl ) diff --git a/tests/testthat/test_linear_reg_glmnet.R b/tests/testthat/test_linear_reg_glmnet.R index 812aa8685..ba7458038 100644 --- a/tests/testthat/test_linear_reg_glmnet.R +++ b/tests/testthat/test_linear_reg_glmnet.R @@ -8,8 +8,10 @@ context("linear regression execution with glmnet") num_pred <- c("Sepal.Width", "Petal.Width", "Petal.Length") iris_bad_form <- as.formula(Species ~ term) -iris_basic <- linear_reg(penalty = .1, mixture = .3, nlambda = 15) -no_lambda <- linear_reg(mixture = .3) +iris_basic <- linear_reg(penalty = .1, mixture = .3) %>% + set_engine("glmnet", nlambda = 15) +no_lambda <- linear_reg(mixture = .3) %>% + set_engine("glmnet") ctrl <- fit_control(verbosity = 1, catch = FALSE) caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) @@ -24,7 +26,6 @@ test_that('glmnet execution', { expect_error( fit_xy( iris_basic, - engine = "glmnet", control = ctrl, x = iris[, num_pred], y = iris$Sepal.Length @@ -37,7 +38,6 @@ test_that('glmnet execution', { iris_basic, iris_bad_form, data = iris, - engine = "glm", control = ctrl ) ) @@ -46,7 +46,6 @@ test_that('glmnet execution', { iris_basic, x = iris[, num_pred], y = factor(iris$Sepal.Length), - engine = "glmnet", control = caught_ctrl ) expect_true(inherits(glmnet_xy_catch$fit, "try-error")) @@ -59,7 +58,6 @@ 
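# Sketch of the division exercised below: `penalty` and `mixture` remain main
# arguments to linear_reg(), while glmnet-specific options such as `nlambda`
# are supplied through set_engine(), e.g.
#   linear_reg(penalty = .1, mixture = .3) %>% set_engine("glmnet", nlambda = 15)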
test_that('glmnet prediction, single lambda', { res_xy <- fit_xy( iris_basic, - engine = "glmnet", control = ctrl, x = iris[, num_pred], y = iris$Sepal.Length @@ -77,7 +75,6 @@ test_that('glmnet prediction, single lambda', { iris_basic, Sepal.Length ~ log(Sepal.Width) + Species, data = iris, - engine = "glmnet", control = ctrl ) @@ -100,11 +97,11 @@ test_that('glmnet prediction, multiple lambda', { lams <- c(.01, 0.1) - iris_mult <- linear_reg(penalty = lams, mixture = .3) + iris_mult <- linear_reg(penalty = lams, mixture = .3) %>% + set_engine("glmnet") res_xy <- fit_xy( iris_mult, - engine = "glmnet", control = ctrl, x = iris[, num_pred], y = iris$Sepal.Length @@ -124,7 +121,6 @@ test_that('glmnet prediction, multiple lambda', { iris_mult, Sepal.Length ~ log(Sepal.Width) + Species, data = iris, - engine = "glmnet", control = ctrl ) @@ -146,11 +142,11 @@ test_that('glmnet prediction, all lambda', { skip_if_not_installed("glmnet") - iris_all <- linear_reg(mixture = .3) + iris_all <- linear_reg(mixture = .3) %>% + set_engine("glmnet") res_xy <- fit_xy( iris_all, - engine = "glmnet", control = ctrl, x = iris[, num_pred], y = iris$Sepal.Length @@ -173,7 +169,6 @@ test_that('glmnet prediction, all lambda', { iris_all, Sepal.Length ~ log(Sepal.Width) + Species, data = iris, - engine = "glmnet", control = ctrl ) @@ -195,7 +190,8 @@ test_that('submodel prediction', { reg_fit <- linear_reg() %>% - fit(mpg ~ ., data = mtcars[-(1:4), ], engine = "glmnet") + set_engine("glmnet") %>% + fit(mpg ~ ., data = mtcars[-(1:4), ]) pred_glmn <- predict(reg_fit$fit, as.matrix(mtcars[1:4, -1]), s = .1) diff --git a/tests/testthat/test_linear_reg_spark.R b/tests/testthat/test_linear_reg_spark.R index 804bbc0cc..2ed55d7d6 100644 --- a/tests/testthat/test_linear_reg_spark.R +++ b/tests/testthat/test_linear_reg_spark.R @@ -28,8 +28,7 @@ test_that('spark execution', { expect_error( spark_fit <- fit( - linear_reg(), - engine = "spark", + linear_reg() %>% set_engine("spark"), control = ctrl, Sepal_Length ~ ., data = iris_linreg_tr diff --git a/tests/testthat/test_linear_reg_stan.R b/tests/testthat/test_linear_reg_stan.R index 372468350..4f7288d90 100644 --- a/tests/testthat/test_linear_reg_stan.R +++ b/tests/testthat/test_linear_reg_stan.R @@ -8,17 +8,16 @@ context("linear regression execution with stan") num_pred <- c("Sepal.Width", "Petal.Width", "Petal.Length") iris_bad_form <- as.formula(Species ~ term) -iris_basic <- linear_reg(seed = 10, chains = 1) +iris_basic <- linear_reg() %>% + set_engine("stan", seed = 10, chains = 1) -ctrl <- fit_control(verbosity = 1, catch = FALSE) -caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) -quiet_ctrl <- fit_control(verbosity = 0, catch = TRUE) +ctrl <- fit_control(verbosity = 0L, catch = FALSE) +caught_ctrl <- fit_control(verbosity = 0L, catch = TRUE) +quiet_ctrl <- fit_control(verbosity = 0L, catch = TRUE) # ------------------------------------------------------------------------------ test_that('stan_glm execution', { - - skip("currently have an issue with environments not finding model.frame.") skip_if_not_installed("rstanarm") library(rstanarm) @@ -28,8 +27,7 @@ test_that('stan_glm execution', { iris_basic, Sepal.Width ~ log(Sepal.Length) + Species, data = iris, - control = ctrl, - engine = "stan" + control = ctrl ), regexp = NA ) @@ -38,7 +36,6 @@ test_that('stan_glm execution', { iris_basic, x = iris[, num_pred], y = iris$Sepal.Length, - engine = "stan", control = ctrl ), regexp = NA @@ -49,7 +46,6 @@ test_that('stan_glm execution', { iris_basic, Species ~ term, 
data = iris, - engine = "stan", control = ctrl ) ) @@ -58,8 +54,6 @@ test_that('stan_glm execution', { test_that('stan prediction', { - - skip("currently have an issue with environments not finding model.frame.") skip_if_not_installed("rstanarm") library(rstanarm) @@ -69,10 +63,10 @@ test_that('stan prediction', { inl_pred <- unname(predict(inl_stan, newdata = iris[1:5, c("Sepal.Length", "Species")])) res_xy <- fit_xy( - linear_reg(seed = 123, chains = 1), + linear_reg() %>% + set_engine("stan", seed = 10, chains = 1), x = iris[, num_pred], y = iris$Sepal.Length, - engine = "stan", control = quiet_ctrl ) @@ -82,7 +76,6 @@ test_that('stan prediction', { iris_basic, Sepal.Width ~ log(Sepal.Length) + Species, data = iris, - engine = "stan", control = quiet_ctrl ) expect_equal(inl_pred, predict_num(res_form, iris[1:5, ]), tolerance = 0.001) @@ -90,15 +83,14 @@ test_that('stan prediction', { test_that('stan intervals', { - skip("currently have an issue with environments not finding model.frame.") skip_if_not_installed("rstanarm") library(rstanarm) res_xy <- fit_xy( - linear_reg(seed = 1333, chains = 10, iter = 1000), + linear_reg() %>% + set_engine("stan", seed = 1333, chains = 10, iter = 1000), x = iris[, num_pred], y = iris$Sepal.Length, - engine = "stan", control = quiet_ctrl ) diff --git a/tests/testthat/test_logistic_reg.R b/tests/testthat/test_logistic_reg.R index a0877fefb..c74b0c492 100644 --- a/tests/testthat/test_logistic_reg.R +++ b/tests/testthat/test_logistic_reg.R @@ -1,6 +1,7 @@ library(testthat) library(parsnip) library(rlang) +library(tibble) # ------------------------------------------------------------------------------ @@ -11,10 +12,10 @@ source("helpers.R") test_that('primary arguments', { basic <- logistic_reg() - basic_glm <- translate(basic, engine = "glm") - basic_glmnet <- translate(basic, engine = "glmnet") - basic_stan <- translate(basic, engine = "stan") - basic_spark <- translate(basic, engine = "spark") + basic_glm <- translate(basic %>% set_engine("glm")) + basic_glmnet <- translate(basic %>% set_engine("glmnet")) + basic_stan <- translate(basic %>% set_engine("stan")) + basic_spark <- translate(basic %>% set_engine("spark")) expect_equal(basic_glm$method$fit$args, list( formula = expr(missing_arg()), @@ -49,8 +50,8 @@ test_that('primary arguments', { ) mixture <- logistic_reg(mixture = 0.128) - mixture_glmnet <- translate(mixture, engine = "glmnet") - mixture_spark <- translate(mixture, engine = "spark") + mixture_glmnet <- translate(mixture %>% set_engine("glmnet")) + mixture_spark <- translate(mixture %>% set_engine("spark")) expect_equal(mixture_glmnet$method$fit$args, list( x = expr(missing_arg()), @@ -71,8 +72,8 @@ test_that('primary arguments', { ) penalty <- logistic_reg(penalty = 1) - penalty_glmnet <- translate(penalty, engine = "glmnet") - penalty_spark <- translate(penalty, engine = "spark") + penalty_glmnet <- translate(penalty %>% set_engine("glmnet")) + penalty_spark <- translate(penalty %>% set_engine("spark")) expect_equal(penalty_glmnet$method$fit$args, list( x = expr(missing_arg()), @@ -93,8 +94,8 @@ test_that('primary arguments', { ) mixture_v <- logistic_reg(mixture = varying()) - mixture_v_glmnet <- translate(mixture_v, engine = "glmnet") - mixture_v_spark <- translate(mixture_v, engine = "spark") + mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")) + mixture_v_spark <- translate(mixture_v %>% set_engine("spark")) expect_equal(mixture_v_glmnet$method$fit$args, list( x = expr(missing_arg()), @@ -117,86 +118,87 @@ 
test_that('primary arguments', { }) test_that('engine arguments', { - glm_fam <- logistic_reg(family = binomial(link = "probit")) - expect_equal(translate(glm_fam, engine = "glm")$method$fit$args, - list( - formula = expr(missing_arg()), - data = expr(missing_arg()), - weights = expr(missing_arg()), - family = new_empty_quosure(expr(binomial(link = "probit"))) - ) - ) + glm_fam <- logistic_reg() + expect_equal( + translate( + glm_fam %>% + set_engine("glm", family = binomial(link = "probit")))$method$fit$args, + list( + formula = expr(missing_arg()), + data = expr(missing_arg()), + weights = expr(missing_arg()), + family = new_empty_quosure(expr(binomial(link = "probit"))) + ) + ) - glmnet_nlam <- logistic_reg(nlambda = 10) - expect_equal(translate(glmnet_nlam, engine = "glmnet")$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - nlambda = new_empty_quosure(10), - family = "binomial" - ) + glmnet_nlam <- logistic_reg() + expect_equal( + translate(glmnet_nlam %>% set_engine("glmnet", nlambda = 10))$method$fit$args, + list( + x = expr(missing_arg()), + y = expr(missing_arg()), + weights = expr(missing_arg()), + nlambda = new_empty_quosure(10), + family = "binomial" + ) ) - stan_samp <- logistic_reg(chains = 1, iter = 5) - expect_equal(translate(stan_samp, engine = "stan")$method$fit$args, - list( - formula = expr(missing_arg()), - data = expr(missing_arg()), - weights = expr(missing_arg()), - chains = new_empty_quosure(1), - iter = new_empty_quosure(5), - family = expr(stats::binomial) - ) + stan_samp <- logistic_reg() + expect_equal( + translate(stan_samp %>% set_engine("stan", chains = 1, iter = 5))$method$fit$args, + list( + formula = expr(missing_arg()), + data = expr(missing_arg()), + weights = expr(missing_arg()), + chains = new_empty_quosure(1), + iter = new_empty_quosure(5), + family = expr(stats::binomial) + ) ) - spark_iter <- logistic_reg(max_iter = 20) - expect_equal(translate(spark_iter, engine = "spark")$method$fit$args, - list( - x = expr(missing_arg()), - formula = expr(missing_arg()), - weight_col = expr(missing_arg()), - max_iter = new_empty_quosure(20), - family = "binomial" - ) + spark_iter <- logistic_reg() + expect_equal( + translate(spark_iter %>% set_engine("spark", max_iter = 20))$method$fit$args, + list( + x = expr(missing_arg()), + formula = expr(missing_arg()), + weight_col = expr(missing_arg()), + max_iter = new_empty_quosure(20), + family = "binomial" + ) ) }) test_that('updating', { - expr1 <- logistic_reg( family = expr(binomial(link = "probit"))) - expr1_exp <- logistic_reg(mixture = 0, family = expr(binomial(link = "probit"))) + expr1 <- logistic_reg() %>% + set_engine("glm", family = expr(binomial(link = "probit"))) + expr1_exp <- logistic_reg(mixture = 0) %>% + set_engine("glm", family = expr(binomial(link = "probit"))) - expr2 <- logistic_reg(mixture = varying()) - expr2_exp <- logistic_reg(mixture = varying(), nlambda = 10) + expr2 <- logistic_reg(mixture = varying()) %>% set_engine("glmnet") + expr2_exp <- logistic_reg(mixture = varying()) %>% set_engine("glmnet", nlambda = 10) expr3 <- logistic_reg(mixture = 0, penalty = varying()) expr3_exp <- logistic_reg(mixture = 1) - expr4 <- logistic_reg(mixture = 0, nlambda = 10) - expr4_exp <- logistic_reg(mixture = 0, nlambda = 10, pmax = 2) + expr4 <- logistic_reg(mixture = 0) %>% set_engine("glmnet", nlambda = 10) + expr4_exp <- logistic_reg(mixture = 0) %>% set_engine("glmnet", nlambda = 10, pmax = 2) - expr5 <- logistic_reg(mixture = 1, nlambda 
= 10) - expr5_exp <- logistic_reg(mixture = 1, nlambda = 10, pmax = 2) + expr5 <- logistic_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10) + expr5_exp <- logistic_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10, pmax = 2) expect_equal(update(expr1, mixture = 0), expr1_exp) - expect_equal(update(expr2, nlambda = 10), expr2_exp) expect_equal(update(expr3, mixture = 1, fresh = TRUE), expr3_exp) - expect_equal(update(expr4, pmax = 2), expr4_exp) - expect_equal(update(expr5, nlambda = 10, pmax = 2), expr5_exp) }) test_that('bad input', { expect_error(logistic_reg(mode = "regression")) - # expect_error(logistic_reg(penalty = -1)) - # expect_error(logistic_reg(mixture = -1)) - expect_error(translate(logistic_reg(), engine = "wat?")) - expect_warning(translate(logistic_reg(), engine = NULL)) expect_error(translate(logistic_reg(formula = y ~ x))) - expect_warning(translate(logistic_reg(x = iris[,1:3], y = iris$Species), engine = "glmnet")) - expect_error(translate(logistic_reg(formula = y ~ x)), engine = "glm") + expect_error(translate(logistic_reg(x = iris[,1:3], y = iris$Species) %>% set_engine(engine = "glmnet"))) + expect_error(translate(logistic_reg(formula = y ~ x) %>% set_engine(engine = "glm"))) }) # ------------------------------------------------------------------------------ @@ -205,7 +207,7 @@ data("lending_club") lending_club <- head(lending_club, 200) lc_form <- as.formula(Class ~ log(funded_amnt) + int_rate) num_pred <- c("funded_amnt", "annual_inc", "num_il_tl") -lc_basic <- logistic_reg() +lc_basic <- logistic_reg() %>% set_engine("glm") ctrl <- fit_control(verbosity = 1, catch = FALSE) caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) quiet_ctrl <- fit_control(verbosity = 0, catch = TRUE) @@ -229,7 +231,6 @@ test_that('glm execution', { lc_basic, x = lending_club[, num_pred], y = lending_club$Class, - engine = "glm", control = ctrl ), regexp = NA @@ -239,7 +240,6 @@ test_that('glm execution', { lc_basic, funded_amnt ~ term, data = lending_club, - engine = "glm", control = ctrl ) ) @@ -249,14 +249,13 @@ test_that('glm execution', { # lc_basic, # funded_amnt ~ term, # data = lending_club, - # engine = "glm", + # # control = caught_ctrl # ) # expect_true(inherits(glm_form_catch$fit, "try-error")) glm_xy_catch <- fit_xy( lc_basic, - engine = "glm", control = caught_ctrl, x = lending_club[, num_pred], y = lending_club$total_bal_il @@ -269,7 +268,6 @@ test_that('glm prediction', { lc_basic, x = lending_club[, num_pred], y = lending_club$Class, - engine = "glm", control = ctrl ) @@ -286,7 +284,6 @@ test_that('glm probabilities', { lc_basic, x = lending_club[, num_pred], y = lending_club$Class, - engine = "glm", control = ctrl ) @@ -313,10 +310,9 @@ test_that('glm intervals', { upper_glm <- stats_glm$family$linkinv(upper_glm) res <- fit( - logistic_reg(), + logistic_reg() %>% set_engine("glm"), Class ~ log(funded_amnt) + int_rate, data = lending_club, - engine = "glm", control = ctrl ) diff --git a/tests/testthat/test_logistic_reg_glmnet.R b/tests/testthat/test_logistic_reg_glmnet.R index 62b4b0b42..9b751cf46 100644 --- a/tests/testthat/test_logistic_reg_glmnet.R +++ b/tests/testthat/test_logistic_reg_glmnet.R @@ -12,7 +12,7 @@ lending_club <- head(lending_club, 200) lc_form <- as.formula(Class ~ log(funded_amnt) + int_rate) num_pred <- c("funded_amnt", "annual_inc", "num_il_tl") lc_bad_form <- as.formula(funded_amnt ~ term) -lc_basic <- logistic_reg() +lc_basic <- logistic_reg() %>% set_engine("glmnet") ctrl <- fit_control(verbosity = 1, catch = FALSE) caught_ctrl <- 
fit_control(verbosity = 1, catch = TRUE)
@@ -27,7 +27,6 @@ test_that('glmnet execution', {
  expect_error(
    fit_xy(
      lc_basic,
-     engine = "glmnet",
      control = ctrl,
      x = lending_club[, num_pred],
      y = lending_club$Class
@@ -39,7 +38,6 @@ test_that('glmnet execution', {
    lc_basic,
    x = lending_club[, num_pred],
    y = lending_club$total_bal_il,
-   engine = "glmnet",
    control = caught_ctrl
  )
  expect_true(inherits(glmnet_xy_catch$fit, "try-error"))
@@ -51,8 +49,7 @@ test_that('glmnet prediction, one lambda', {
  skip_if_not_installed("glmnet")
  xy_fit <- fit_xy(
-   logistic_reg(penalty = 0.1),
-   engine = "glmnet",
+   logistic_reg(penalty = 0.1) %>% set_engine("glmnet"),
    control = ctrl,
    x = lending_club[, num_pred],
    y = lending_club$Class
@@ -70,10 +67,9 @@ test_that('glmnet prediction, one lambda', {
  expect_equal(uni_pred, predict_class(xy_fit, lending_club[1:7, num_pred]))
  res_form <- fit(
-   logistic_reg(penalty = 0.1),
+   logistic_reg(penalty = 0.1) %>% set_engine("glmnet"),
    Class ~ log(funded_amnt) + int_rate,
    data = lending_club,
-   engine = "glmnet",
    control = ctrl
  )
@@ -100,8 +96,7 @@ test_that('glmnet prediction, multiple lambda', {
  lams <- c(0.01, 0.1)
  xy_fit <- fit_xy(
-   logistic_reg(penalty = lams),
-   engine = "glmnet",
+   logistic_reg(penalty = lams) %>% set_engine("glmnet"),
    control = ctrl,
    x = lending_club[, num_pred],
    y = lending_club$Class
@@ -120,10 +115,9 @@ test_that('glmnet prediction, multiple lambda', {
  expect_equal(mult_pred, predict_class(xy_fit, lending_club[1:7, num_pred]))
  res_form <- fit(
-   logistic_reg(penalty = lams),
+   logistic_reg(penalty = lams) %>% set_engine("glmnet"),
    Class ~ log(funded_amnt) + int_rate,
    data = lending_club,
-   engine = "glmnet",
    control = ctrl
  )
@@ -149,8 +143,7 @@ test_that('glmnet prediction, no lambda', {
  skip_if_not_installed("glmnet")
  xy_fit <- fit_xy(
-   logistic_reg(nlambda = 11),
-   engine = "glmnet",
+   logistic_reg() %>% set_engine("glmnet", nlambda = 11),
    control = ctrl,
    x = lending_club[, num_pred],
    y = lending_club$Class
@@ -169,10 +162,9 @@ test_that('glmnet prediction, no lambda', {
  expect_equal(mult_pred, predict_class(xy_fit, lending_club[1:7, num_pred]))
  res_form <- fit(
-   logistic_reg(nlambda = 11),
+   logistic_reg() %>% set_engine("glmnet", nlambda = 11),
    Class ~ log(funded_amnt) + int_rate,
    data = lending_club,
-   engine = "glmnet",
    control = ctrl
  )
@@ -198,8 +190,7 @@ test_that('glmnet probabilities, one lambda', {
  skip_if_not_installed("glmnet")
  xy_fit <- fit_xy(
-   logistic_reg(penalty = 0.1),
-   engine = "glmnet",
+   logistic_reg(penalty = 0.1) %>% set_engine("glmnet"),
    control = ctrl,
    x = lending_club[, num_pred],
    y = lending_club$Class
@@ -214,10 +205,9 @@ test_that('glmnet probabilities, one lambda', {
  expect_equal(uni_pred, predict_classprob(xy_fit, lending_club[1:7, num_pred]))
  res_form <- fit(
-   logistic_reg(penalty = 0.1),
+   logistic_reg(penalty = 0.1) %>% set_engine("glmnet"),
    Class ~ log(funded_amnt) + int_rate,
    data = lending_club,
-   engine = "glmnet",
    control = ctrl
  )
@@ -243,8 +233,7 @@ test_that('glmnet probabilities, multiple lambda', {
  lams <- c(0.01, 0.1)
  xy_fit <- fit_xy(
-   logistic_reg(penalty = lams),
-   engine = "glmnet",
+   logistic_reg(penalty = lams) %>% set_engine("glmnet"),
    control = ctrl,
    x = lending_club[, num_pred],
    y = lending_club$Class
@@ -261,10 +250,9 @@ test_that('glmnet probabilities, multiple lambda', {
  expect_equal(mult_pred, predict_classprob(xy_fit, lending_club[1:7, num_pred]))
  res_form <- fit(
-   logistic_reg(penalty = lams),
+   logistic_reg(penalty = lams) %>% set_engine("glmnet"),
    Class ~ log(funded_amnt)
+ int_rate, data = lending_club, - engine = "glmnet", control = ctrl ) @@ -289,8 +277,7 @@ test_that('glmnet probabilities, no lambda', { skip_if_not_installed("glmnet") xy_fit <- fit_xy( - logistic_reg(), - engine = "glmnet", + logistic_reg() %>% set_engine("glmnet"), control = ctrl, x = lending_club[, num_pred], y = lending_club$Class @@ -307,10 +294,9 @@ test_that('glmnet probabilities, no lambda', { expect_equal(mult_pred, predict_classprob(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( - logistic_reg(), + logistic_reg() %>% set_engine("glmnet"), Class ~ log(funded_amnt) + int_rate, data = lending_club, - engine = "glmnet", control = ctrl ) @@ -337,9 +323,8 @@ test_that('submodel prediction', { vars <- c("female", "tenure", "total_charges", "phone_service", "monthly_charges") class_fit <- logistic_reg() %>% - fit(churn ~ ., - data = wa_churn[-(1:4), c("churn", vars)], - engine = "glmnet") + set_engine("glmnet") %>% + fit(churn ~ ., data = wa_churn[-(1:4), c("churn", vars)]) pred_glmn <- predict(class_fit$fit, as.matrix(wa_churn[1:4, vars]), s = .1, type = "response") diff --git a/tests/testthat/test_logistic_reg_spark.R b/tests/testthat/test_logistic_reg_spark.R index c7dbf09fb..9d435dc85 100644 --- a/tests/testthat/test_logistic_reg_spark.R +++ b/tests/testthat/test_logistic_reg_spark.R @@ -30,8 +30,7 @@ test_that('spark execution', { expect_error( spark_class_fit <- fit( - logistic_reg(), - engine = "spark", + logistic_reg() %>% set_engine("spark"), control = ctrl, churn ~ ., data = churn_logit_tr @@ -43,8 +42,7 @@ test_that('spark execution', { expect_error( spark_class_fit_dup <- fit( - logistic_reg(), - engine = "spark", + logistic_reg() %>% set_engine("spark"), control = ctrl, churn ~ ., data = churn_logit_tr @@ -79,7 +77,7 @@ test_that('spark execution', { regexp = NA ) - expect_equal(colnames(spark_class_prob), c("pred_Yes", "pred_No")) + expect_equal(colnames(spark_class_prob), c("pred_No", "pred_Yes")) expect_equivalent( as.data.frame(spark_class_prob), diff --git a/tests/testthat/test_logistic_reg_stan.R b/tests/testthat/test_logistic_reg_stan.R index e822bfd77..648fa2442 100644 --- a/tests/testthat/test_logistic_reg_stan.R +++ b/tests/testthat/test_logistic_reg_stan.R @@ -11,17 +11,17 @@ data("lending_club") lending_club <- head(lending_club, 200) lc_form <- as.formula(Class ~ log(funded_amnt) + int_rate) num_pred <- c("funded_amnt", "annual_inc", "num_il_tl") -lc_basic <- logistic_reg(seed = 1333, chains = 1) +lc_basic <- + logistic_reg() %>% + set_engine("stan", seed = 1333, chains = 1) -ctrl <- fit_control(verbosity = 1, catch = FALSE) -caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) +ctrl <- fit_control(verbosity = 0, catch = FALSE) +caught_ctrl <- fit_control(verbosity = 0, catch = TRUE) quiet_ctrl <- fit_control(verbosity = 0, catch = TRUE) # ------------------------------------------------------------------------------ test_that('stan_glm execution', { - skip("currently have an issue with environments not finding model.frame.") - skip_if_not_installed("rstanarm") expect_error( @@ -29,14 +29,12 @@ test_that('stan_glm execution', { lc_basic, funded_amnt ~ term, data = lending_club, - engine = "stan", control = ctrl ) ) stan_xy_catch <- fit_xy( lc_basic, - engine = "stan", control = caught_ctrl, x = lending_club[, num_pred], y = lending_club$total_bal_il @@ -47,14 +45,12 @@ test_that('stan_glm execution', { test_that('stan_glm prediction', { - skip("currently have an issue with environments not finding model.frame.") - skip_if_not_installed("rstanarm") 
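# Engine-specific sampler settings (e.g. `seed` and `chains` for rstanarm) are
# now supplied via set_engine() rather than as arguments to logistic_reg();
# a minimal sketch of the pattern used in the fits below:
#   logistic_reg() %>% set_engine("stan", seed = 1333, chains = 1)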
library(rstanarm) xy_fit <- fit_xy( - logistic_reg(seed = 11, chains = 1), - engine = "stan", + logistic_reg() %>% + set_engine("stan", seed = 1333, chains = 1), control = ctrl, x = lending_club[, num_pred], y = lending_club$Class @@ -71,10 +67,10 @@ test_that('stan_glm prediction', { expect_equal(xy_pred, predict_class(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( - logistic_reg(seed = 11, chains = 1), + logistic_reg() %>% + set_engine("stan", seed = 1333, chains = 1), Class ~ log(funded_amnt) + int_rate, data = lending_club, - engine = "stan", control = ctrl ) @@ -91,13 +87,11 @@ test_that('stan_glm prediction', { test_that('stan_glm probability', { - skip("currently have an issue with environments not finding model.frame.") - skip_if_not_installed("rstanarm") xy_fit <- fit_xy( - logistic_reg(seed = 11, chains = 1), - engine = "stan", + logistic_reg() %>% + set_engine("stan", seed = 1333, chains = 1), control = ctrl, x = lending_club[, num_pred], y = lending_club$Class @@ -112,10 +106,10 @@ test_that('stan_glm probability', { expect_equal(xy_pred, predict_classprob(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( - logistic_reg(seed = 11, chains = 1), + logistic_reg() %>% + set_engine("stan", seed = 1333, chains = 1), Class ~ log(funded_amnt) + int_rate, data = lending_club, - engine = "stan", control = ctrl ) @@ -129,16 +123,14 @@ test_that('stan_glm probability', { test_that('stan intervals', { - skip("currently have an issue with environments not finding model.frame.") - skip_if_not_installed("rstanarm") library(rstanarm) res_form <- fit( - logistic_reg(seed = 11, chains = 1), + logistic_reg() %>% + set_engine("stan", seed = 1333, chains = 1), Class ~ log(funded_amnt) + int_rate, data = lending_club, - engine = "stan", control = ctrl ) diff --git a/tests/testthat/test_mars.R b/tests/testthat/test_mars.R index cdefc41ce..bf3d35c11 100644 --- a/tests/testthat/test_mars.R +++ b/tests/testthat/test_mars.R @@ -1,5 +1,4 @@ library(testthat) - library(parsnip) library(rlang) @@ -12,7 +11,7 @@ source("helpers.R") test_that('primary arguments', { basic <- mars(mode = "regression") - basic_mars <- translate(basic, engine = "earth") + basic_mars <- translate(basic %>% set_engine("earth")) expect_equal(basic_mars$method$fit$args, list( x = expr(missing_arg()), @@ -23,7 +22,7 @@ test_that('primary arguments', { ) num_terms <- mars(num_terms = 4, mode = "classification") - num_terms_mars <- translate(num_terms, engine = "earth") + num_terms_mars <- translate(num_terms %>% set_engine("earth")) expect_equal(num_terms_mars$method$fit$args, list( x = expr(missing_arg()), @@ -36,7 +35,7 @@ test_that('primary arguments', { ) prod_degree <- mars(prod_degree = 1, mode = "regression") - prod_degree_mars <- translate(prod_degree, engine = "earth") + prod_degree_mars <- translate(prod_degree %>% set_engine("earth")) expect_equal(prod_degree_mars$method$fit$args, list( x = expr(missing_arg()), @@ -48,7 +47,7 @@ test_that('primary arguments', { ) prune_method_v <- mars(prune_method = varying(), mode = "regression") - prune_method_v_mars <- translate(prune_method_v, engine = "earth") + prune_method_v_mars <- translate(prune_method_v %>% set_engine("earth")) expect_equal(prune_method_v_mars$method$fit$args, list( x = expr(missing_arg()), @@ -61,8 +60,8 @@ test_that('primary arguments', { }) test_that('engine arguments', { - mars_keep <- mars(mode = "regression", keepxy = FALSE) - expect_equal(translate(mars_keep, engine = "earth")$method$fit$args, + mars_keep <- mars(mode = "regression") + 
expect_equal(translate(mars_keep %>% set_engine("earth", keepxy = FALSE))$method$fit$args, list( x = expr(missing_arg()), y = expr(missing_arg()), @@ -74,39 +73,35 @@ test_that('engine arguments', { test_that('updating', { - expr1 <- mars( model = FALSE) - expr1_exp <- mars(num_terms = 1, model = FALSE) + expr1 <- mars() %>% set_engine("earth", model = FALSE) + expr1_exp <- mars(num_terms = 1) %>% set_engine("earth", model = FALSE) - expr2 <- mars(num_terms = varying()) - expr2_exp <- mars(num_terms = varying(), nk = 10) + expr2 <- mars(num_terms = varying()) %>% set_engine("earth") + expr2_exp <- mars(num_terms = varying()) %>% set_engine("earth", nk = 10) - expr3 <- mars(num_terms = 1, prod_degree = varying()) - expr3_exp <- mars(num_terms = 1) + expr3 <- mars(num_terms = 1, prod_degree = varying()) %>% set_engine("earth") + expr3_exp <- mars(num_terms = 1) %>% set_engine("earth") - expr4 <- mars(num_terms = 0, nk = 10) - expr4_exp <- mars(num_terms = 0, nk = 10, trace = 2) + expr4 <- mars(num_terms = 0) %>% set_engine("earth", nk = 10) + expr4_exp <- mars(num_terms = 0) %>% set_engine("earth", nk = 10, trace = 2) - expr5 <- mars(num_terms = 1, nk = 10) - expr5_exp <- mars(num_terms = 1, nk = 10, trace = 2) + expr5 <- mars(num_terms = 1) %>% set_engine("earth", nk = 10) + expr5_exp <- mars(num_terms = 1) %>% set_engine("earth", nk = 10, trace = 2) expect_equal(update(expr1, num_terms = 1), expr1_exp) - expect_equal(update(expr2, nk = 10), expr2_exp) expect_equal(update(expr3, num_terms = 1, fresh = TRUE), expr3_exp) - expect_equal(update(expr4, trace = 2), expr4_exp) - expect_equal(update(expr5, nk = 10, trace = 2), expr5_exp) - }) test_that('bad input', { # expect_error(mars(prod_degree = -1)) # expect_error(mars(num_terms = -1)) - expect_error(translate(mars(), engine = "wat?")) - expect_warning(translate(mars(mode = "regression"), engine = NULL)) + expect_error(translate(mars() %>% set_engine("wat?"))) + expect_error(translate(mars(mode = "regression") %>% set_engine())) expect_error(translate(mars(formula = y ~ x))) expect_warning( translate( - mars(mode = "regression", x = iris[,1:3], y = iris$Species), - engine = "earth") + mars(mode = "regression") %>% set_engine("earth", x = iris[,1:3], y = iris$Species) + ) ) }) @@ -114,7 +109,7 @@ test_that('bad input', { num_pred <- c("Sepal.Width", "Petal.Width", "Petal.Length") iris_bad_form <- as.formula(Species ~ term) -iris_basic <- mars(mode = "regression") +iris_basic <- mars(mode = "regression") %>% set_engine("earth") ctrl <- fit_control(verbosity = 1, catch = FALSE) caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) @@ -123,7 +118,6 @@ quiet_ctrl <- fit_control(verbosity = 0, catch = TRUE) # ------------------------------------------------------------------------------ test_that('mars execution', { - skip("currently have an issue with environments not finding model.frame.") skip_if_not_installed("earth") expect_error( @@ -131,8 +125,7 @@ test_that('mars execution', { iris_basic, Sepal.Length ~ log(Sepal.Width) + Species, data = iris, - control = ctrl, - engine = "earth" + control = ctrl ), regexp = NA ) @@ -141,7 +134,6 @@ test_that('mars execution', { iris_basic, x = iris[, num_pred], y = iris$Sepal.Length, - engine = "earth", control = ctrl ), regexp = NA @@ -152,7 +144,6 @@ test_that('mars execution', { iris_basic, iris_bad_form, data = iris, - engine = "earth", control = ctrl ) ) @@ -164,8 +155,7 @@ test_that('mars execution', { iris_basic, cbind(Sepal.Width, Petal.Width) ~ ., data = iris, - control = ctrl, - engine = 
"earth" + control = ctrl ), regexp = NA ) @@ -173,7 +163,6 @@ test_that('mars execution', { }) test_that('mars prediction', { - skip("currently have an issue with environments not finding model.frame.") skip_if_not_installed("earth") library(earth) @@ -188,7 +177,6 @@ test_that('mars prediction', { iris_basic, x = iris[, num_pred], y = iris$Sepal.Length, - engine = "earth", control = ctrl ) @@ -198,7 +186,6 @@ test_that('mars prediction', { iris_basic, Sepal.Length ~ log(Sepal.Width) + Species, data = iris, - engine = "earth", control = ctrl ) expect_equal(inl_pred, predict_num(res_form, iris[1:5, ])) @@ -207,47 +194,54 @@ test_that('mars prediction', { iris_basic, cbind(Sepal.Width, Petal.Width) ~ ., data = iris, - control = ctrl, - engine = "earth" + control = ctrl ) expect_equal(mv_pred, predict_num(res_mv, iris[1:5,])) }) test_that('submodel prediction', { - skip("currently have an issue with environments not finding model.frame.") skip_if_not_installed("earth") library(earth) reg_fit <- mars( num_terms = 20, - prune_method = "none", mode = "regression", - keepxy = TRUE + prune_method = "none" ) %>% - fit(mpg ~ ., data = mtcars[-(1:4), ], engine = "earth") + set_engine("earth", keepxy = TRUE) %>% + fit(mpg ~ ., data = mtcars[-(1:4), ]) + + tmp_reg <- reg_fit$fit + tmp_reg$call[["pmethod"]] <- eval_tidy(tmp_reg$call[["pmethod"]]) + tmp_reg$call[["keepxy"]] <- eval_tidy(tmp_reg$call[["keepxy"]]) + tmp_reg$call[["nprune"]] <- eval_tidy(tmp_reg$call[["nprune"]]) - pruned_fit <- update(reg_fit$fit, nprune = 5) - pruned_pred <- predict(pruned_fit, mtcars[1:4, -1])[,1] + pruned_reg <- update(tmp_reg, nprune = 5) + pruned_reg_pred <- predict(pruned_reg, mtcars[1:4, -1])[,1] mp_res <- multi_predict(reg_fit, new_data = mtcars[1:4, -1], num_terms = 5) mp_res <- do.call("rbind", mp_res$.pred) - expect_equal(mp_res[[".pred"]], pruned_pred) + expect_equal(mp_res[[".pred"]], pruned_reg_pred) vars <- c("female", "tenure", "total_charges", "phone_service", "monthly_charges") class_fit <- - mars(mode = "classification", prune_method = "none", keepxy = TRUE) %>% + mars(mode = "classification", prune_method = "none") %>% + set_engine("earth", keepxy = TRUE) %>% fit(churn ~ ., - data = wa_churn[-(1:4), c("churn", vars)], - engine = "earth") + data = wa_churn[-(1:4), c("churn", vars)]) + + cls_fit <- class_fit$fit + cls_fit$call[["pmethod"]] <- eval_tidy(cls_fit$call[["pmethod"]]) + cls_fit$call[["keepxy"]] <- eval_tidy(cls_fit$call[["keepxy"]]) - pruned_fit <- update(class_fit$fit, nprune = 5) - pruned_pred <- predict(pruned_fit, wa_churn[1:4, vars], type = "response")[,1] + pruned_cls <- update(cls_fit, nprune = 5) + pruned_cls_pred <- predict(pruned_cls, wa_churn[1:4, vars], type = "response")[,1] mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], num_terms = 5, type = "prob") mp_res <- do.call("rbind", mp_res$.pred) - expect_equal(mp_res[[".pred_No"]], pruned_pred) + expect_equal(mp_res[[".pred_No"]], pruned_cls_pred) }) @@ -256,12 +250,12 @@ test_that('submodel prediction', { data("lending_club") test_that('classification', { - skip("currently have an issue with environments not finding model.frame.") skip_if_not_installed("earth") expect_error( - glm_mars <- mars(mode = "classification") %>% - fit(Class ~ ., data = lending_club[-(1:5),], engine = "earth"), + glm_mars <- mars(mode = "classification") %>% + set_engine("earth") %>% + fit(Class ~ ., data = lending_club[-(1:5),]), regexp = NA ) expect_true(!is.null(glm_mars$fit$glm.list)) diff --git a/tests/testthat/test_mlp.R 
b/tests/testthat/test_mlp.R index c04560ec7..46006c600 100644 --- a/tests/testthat/test_mlp.R +++ b/tests/testthat/test_mlp.R @@ -12,8 +12,8 @@ source("helpers.R") test_that('primary arguments', { hidden_units <- mlp(mode = "regression", hidden_units = 4) - hidden_units_nnet <- translate(hidden_units, engine = "nnet") - hidden_units_keras <- translate(hidden_units, engine = "keras") + hidden_units_nnet <- translate(hidden_units %>% set_engine("nnet")) + hidden_units_keras <- translate(hidden_units %>% set_engine("keras")) expect_equal(hidden_units_nnet$method$fit$args, list( formula = expr(missing_arg()), @@ -33,7 +33,7 @@ test_that('primary arguments', { ) no_hidden_units <- mlp(mode = "regression") - no_hidden_units_nnet <- translate(no_hidden_units, engine = "nnet") + no_hidden_units_nnet <- translate(no_hidden_units %>% set_engine("nnet")) expect_equal(no_hidden_units_nnet$method$fit$args, list( formula = expr(missing_arg()), @@ -58,8 +58,8 @@ test_that('primary arguments', { epochs = 2, hidden_units = 4, penalty = 0.0001, dropout = 0, activation = "softmax" ) - all_args_nnet <- translate(all_args, engine = "nnet") - all_args_keras <- translate(all_args, engine = "keras") + all_args_nnet <- translate(all_args %>% set_engine("nnet")) + all_args_keras <- translate(all_args %>% set_engine("keras")) expect_equal(all_args_nnet$method$fit$args, list( formula = expr(missing_arg()), @@ -87,8 +87,8 @@ test_that('primary arguments', { }) test_that('engine arguments', { - nnet_hess <- mlp(mode = "classification", Hess = TRUE) - expect_equal(translate(nnet_hess, engine = "nnet")$method$fit$args, + nnet_hess <- mlp(mode = "classification") %>% set_engine("nnet", Hess = TRUE) + expect_equal(translate(nnet_hess)$method$fit$args, list( formula = expr(missing_arg()), data = expr(missing_arg()), @@ -100,8 +100,8 @@ test_that('engine arguments', { ) ) - keras_val <- mlp(mode = "regression", validation_split = 0.2) - expect_equal(translate(keras_val, engine = "keras")$method$fit$args, + keras_val <- mlp(mode = "regression") %>% set_engine("keras", validation_split = 0.2) + expect_equal(translate(keras_val)$method$fit$args, list( x = expr(missing_arg()), y = expr(missing_arg()), @@ -110,8 +110,8 @@ test_that('engine arguments', { ) - nnet_tol <- mlp(mode = "regression", abstol = varying()) - expect_equal(translate(nnet_tol, engine = "nnet")$method$fit$args, + nnet_tol <- mlp(mode = "regression") %>% set_engine("nnet", abstol = varying()) + expect_equal(translate(nnet_tol)$method$fit$args, list( formula = expr(missing_arg()), data = expr(missing_arg()), @@ -126,36 +126,34 @@ test_that('engine arguments', { test_that('updating', { - expr1 <- mlp(mode = "regression", Hess = FALSE, abstol = varying()) - expr1_exp <- mlp(mode = "regression", hidden_units = 2, Hess = FALSE, abstol = varying()) + expr1 <- mlp(mode = "regression") %>% + set_engine("nnet", Hess = FALSE, abstol = varying()) + expr1_exp <- mlp(mode = "regression", hidden_units = 2) %>% + set_engine("nnet", Hess = FALSE, abstol = varying()) - expr2 <- mlp(mode = "regression", hidden_units = 7) - expr2_exp <- mlp(mode = "regression", hidden_units = 7, Hess = FALSE) + expr2 <- mlp(mode = "regression", hidden_units = 7) %>% set_engine("nnet") + expr2_exp <- mlp(mode = "regression", hidden_units = 7) %>% set_engine("nnet", Hess = FALSE) - expr3 <- mlp(mode = "regression", hidden_units = 7, epochs = varying()) + expr3 <- mlp(mode = "regression", hidden_units = 7, epochs = varying()) %>% set_engine("keras") - expr3_exp <- mlp(mode = "regression", 
hidden_units = 2) + expr3_exp <- mlp(mode = "regression", hidden_units = 2) %>% set_engine("keras") - expr4 <- mlp(mode = "classification", hidden_units = 2, Hess = TRUE, abstol = varying()) - expr4_exp <- mlp(mode = "classification", hidden_units = 2, Hess = FALSE, abstol = varying()) + expr4 <- mlp(mode = "classification", hidden_units = 2) %>% set_engine("nnet", Hess = FALSE, abstol = varying()) + expr4_exp <- mlp(mode = "classification", hidden_units = 2) %>% set_engine("nnet", Hess = FALSE, abstol = varying()) - expr5 <- mlp(mode = "classification", hidden_units = 2, Hess = FALSE) - expr5_exp <- mlp(mode = "classification", hidden_units = 2, Hess = FALSE, abstol = varying()) + expr5 <- mlp(mode = "classification", hidden_units = 2) %>% set_engine("nnet", Hess = FALSE) + expr5_exp <- mlp(mode = "classification", hidden_units = 2) %>% set_engine("nnet", Hess = FALSE, abstol = varying()) expect_equal(update(expr1, hidden_units = 2), expr1_exp) - expect_equal(update(expr2,Hess = FALSE), expr2_exp) expect_equal(update(expr3, hidden_units = 2, fresh = TRUE), expr3_exp) - expect_equal(update(expr4,Hess = FALSE), expr4_exp) - expect_equal(update(expr5,Hess = FALSE, abstol = varying()), expr5_exp) }) test_that('bad input', { expect_error(mlp(mode = "time series")) - expect_error(translate(mlp(mode = "classification"), engine = "wat?")) - expect_error(translate(mlp(mode = "classification",ytest = 2))) - expect_error(translate(mlp(mode = "regression", formula = y ~ x))) - expect_warning(translate(mlp(mode = "classification", x = x, y = y), engine = "keras")) - expect_error(translate(mlp(mode = "regression", formula = y ~ x), engine = "")) + expect_error(translate(mlp(mode = "classification") %>% set_engine("wat?"))) + expect_warning(translate(mlp(mode = "regression") %>% set_engine("nnet", formula = y ~ x))) + expect_error(translate(mlp(mode = "classification", x = x, y = y) %>% set_engine("keras"))) + expect_error(translate(mlp(mode = "regression", formula = y ~ x) %>% set_engine())) }) diff --git a/tests/testthat/test_mlp_keras.R b/tests/testthat/test_mlp_keras.R index 335bb3d1d..9e1a55448 100644 --- a/tests/testthat/test_mlp_keras.R +++ b/tests/testthat/test_mlp_keras.R @@ -8,7 +8,9 @@ context("simple neural network execution with keras") num_pred <- names(iris)[1:4] -iris_keras <- mlp(mode = "classification", hidden_units = 2, verbose = 0, epochs = 10) +iris_keras <- + mlp(mode = "classification", hidden_units = 2, epochs = 10) %>% + set_engine("keras", verbose = 0) ctrl <- fit_control(verbosity = 1, catch = FALSE) caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) @@ -25,7 +27,6 @@ test_that('keras execution, classification', { iris_keras, Species ~ Sepal.Width + Sepal.Length, data = iris, - engine = "keras", control = ctrl ), regexp = NA @@ -38,7 +39,6 @@ test_that('keras execution, classification', { iris_keras, x = iris[, num_pred], y = iris$Species, - engine = "keras", control = ctrl ), regexp = NA @@ -51,7 +51,6 @@ test_that('keras execution, classification', { iris_keras, Species ~ novar, data = iris, - engine = "keras", control = ctrl ) ) @@ -67,7 +66,6 @@ test_that('keras classification prediction', { iris_keras, x = iris[, num_pred], y = iris$Species, - engine = "keras", control = ctrl ) @@ -81,7 +79,6 @@ test_that('keras classification prediction', { iris_keras, Species ~ ., data = iris, - engine = "keras", control = ctrl ) @@ -101,7 +98,6 @@ test_that('keras classification probabilities', { iris_keras, x = iris[, num_pred], y = iris$Species, - engine = "keras", control = 
ctrl ) @@ -116,7 +112,6 @@ test_that('keras classification probabilities', { iris_keras, Species ~ ., data = iris, - engine = "keras", control = ctrl ) @@ -135,9 +130,12 @@ mtcars <- as.data.frame(scale(mtcars)) num_pred <- names(mtcars)[3:6] -car_basic <- mlp(mode = "regression", verbose = 0, epochs = 10) +car_basic <- mlp(mode = "regression", epochs = 10) %>% + set_engine("keras", verbose = 0) -bad_keras_reg <- mlp(mode = "regression", min.node.size = -10) +bad_keras_reg <- + mlp(mode = "regression") %>% + set_engine("keras", min.node.size = -10) ctrl <- list(verbosity = 1, catch = FALSE) caught_ctrl <- list(verbosity = 1, catch = TRUE) @@ -155,7 +153,6 @@ test_that('keras execution, regression', { car_basic, mpg ~ ., data = mtcars, - engine = "keras", control = ctrl ), regexp = NA @@ -168,7 +165,6 @@ test_that('keras execution, regression', { car_basic, x = mtcars[, num_pred], y = mtcars$mpg, - engine = "keras", control = ctrl ), regexp = NA @@ -180,10 +176,10 @@ test_that('keras regression prediction', { skip_if_not_installed("keras") xy_fit <- parsnip::fit_xy( - mlp(mode = "regression", hidden_units = 2, epochs = 500, penalty = .1, verbose = 0), + mlp(mode = "regression", hidden_units = 2, epochs = 500, penalty = .1) %>% + set_engine("keras", verbose = 0), x = mtcars[, c("cyl", "disp")], y = mtcars$mpg, - engine = "keras", control = ctrl ) @@ -195,8 +191,7 @@ test_that('keras regression prediction', { form_fit <- parsnip::fit( car_basic, mpg ~ ., - data = mtcars[, c("cyl", "disp", "mpg")],, - engine = "keras", + data = mtcars[, c("cyl", "disp", "mpg")], control = ctrl ) @@ -216,16 +211,11 @@ test_that('multivariate nnet formula', { skip_if_not_installed("keras") nnet_form <- - mlp( - mode = "regression", - hidden_units = 3, - penalty = 0.01, - verbose = 0 - ) %>% + mlp(mode = "regression", hidden_units = 3, penalty = 0.01) %>% + set_engine("keras", verbose = 0) %>% parsnip::fit( cbind(V1, V2, V3) ~ ., - data = nn_dat[-(1:5),], - engine = "keras" + data = nn_dat[-(1:5),] ) expect_equal(length(unlist(keras::get_weights(nnet_form$fit))), 24) nnet_form_pred <- predict_num(nnet_form, new_data = nn_dat[1:5, -(1:3)]) @@ -236,16 +226,11 @@ test_that('multivariate nnet formula', { keras::backend()$clear_session() nnet_xy <- - mlp( - mode = "regression", - hidden_units = 3, - penalty = 0.01, - verbose = 0 - ) %>% + mlp(mode = "regression", hidden_units = 3, penalty = 0.01) %>% + set_engine("keras", verbose = 0) %>% parsnip::fit_xy( x = nn_dat[-(1:5), -(1:3)], - y = nn_dat[-(1:5), 1:3 ], - engine = "keras" + y = nn_dat[-(1:5), 1:3 ] ) expect_equal(length(unlist(keras::get_weights(nnet_xy$fit))), 24) nnet_form_xy <- predict_num(nnet_xy, new_data = nn_dat[1:5, -(1:3)]) diff --git a/tests/testthat/test_mlp_nnet.R b/tests/testthat/test_mlp_nnet.R index b0fa3d7b8..24dd3fa9e 100644 --- a/tests/testthat/test_mlp_nnet.R +++ b/tests/testthat/test_mlp_nnet.R @@ -7,7 +7,9 @@ context("simple neural network execution with nnet") num_pred <- names(iris)[1:4] -iris_nnet <- mlp(mode = "classification", hidden_units = 2) +iris_nnet <- + mlp(mode = "classification", hidden_units = 2) %>% + set_engine("nnet") ctrl <- fit_control(verbosity = 1, catch = FALSE) caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) @@ -24,7 +26,6 @@ test_that('nnet execution, classification', { iris_nnet, Species ~ Sepal.Width + Sepal.Length, data = iris, - engine = "nnet", control = ctrl ), regexp = NA @@ -34,7 +35,6 @@ test_that('nnet execution, classification', { iris_nnet, x = iris[, num_pred], y = iris$Species, - engine = 
"nnet", control = ctrl ), regexp = NA @@ -45,7 +45,6 @@ test_that('nnet execution, classification', { iris_nnet, Species ~ novar, data = iris, - engine = "nnet", control = ctrl ) ) @@ -60,7 +59,6 @@ test_that('nnet classification prediction', { iris_nnet, x = iris[, num_pred], y = iris$Species, - engine = "nnet", control = ctrl ) @@ -72,7 +70,6 @@ test_that('nnet classification prediction', { iris_nnet, Species ~ ., data = iris, - engine = "nnet", control = ctrl ) @@ -86,10 +83,16 @@ test_that('nnet classification prediction', { num_pred <- names(mtcars)[3:6] -car_basic <- mlp(mode = "regression") +car_basic <- + mlp(mode = "regression") %>% + set_engine("nnet") -bad_nnet_reg <- mlp(mode = "regression", min.node.size = -10) -bad_rf_reg <- mlp(mode = "regression", sampsize = -10) +bad_nnet_reg <- + mlp(mode = "regression") %>% + set_engine("nnet", min.node.size = -10) +bad_rf_reg <- + mlp(mode = "regression") %>% + set_engine("nnet", sampsize = -10) ctrl <- list(verbosity = 1, catch = FALSE) caught_ctrl <- list(verbosity = 1, catch = TRUE) @@ -107,7 +110,6 @@ test_that('nnet execution, regression', { car_basic, mpg ~ ., data = mtcars, - engine = "nnet", control = ctrl ), regexp = NA @@ -118,7 +120,6 @@ test_that('nnet execution, regression', { car_basic, x = mtcars[, num_pred], y = mtcars$mpg, - engine = "nnet", control = ctrl ), regexp = NA @@ -135,7 +136,6 @@ test_that('nnet regression prediction', { car_basic, x = mtcars[, -1], y = mtcars$mpg, - engine = "nnet", control = ctrl ) @@ -147,7 +147,6 @@ test_that('nnet regression prediction', { car_basic, mpg ~ ., data = mtcars, - engine = "nnet", control = ctrl ) @@ -169,11 +168,11 @@ test_that('multivariate nnet formula', { mode = "regression", hidden_units = 3, penalty = 0.01 - ) %>% + ) %>% + set_engine("nnet") %>% parsnip::fit( cbind(V1, V2, V3) ~ ., - data = nn_dat[-(1:5),], - engine = "nnet" + data = nn_dat[-(1:5),] ) expect_equal(length(nnet_form$fit$wts), 24) nnet_form_pred <- predict_num(nnet_form, new_data = nn_dat[1:5, -(1:3)]) @@ -187,10 +186,10 @@ test_that('multivariate nnet formula', { hidden_units = 3, penalty = 0.01 ) %>% + set_engine("nnet") %>% parsnip::fit_xy( x = nn_dat[-(1:5), -(1:3)], - y = nn_dat[-(1:5), 1:3 ], - engine = "nnet" + y = nn_dat[-(1:5), 1:3 ] ) expect_equal(length(nnet_xy$fit$wts), 24) nnet_form_xy <- predict_num(nnet_xy, new_data = nn_dat[1:5, -(1:3)]) diff --git a/tests/testthat/test_multinom_reg.R b/tests/testthat/test_multinom_reg.R index 74c67a1e4..31b8c72bb 100644 --- a/tests/testthat/test_multinom_reg.R +++ b/tests/testthat/test_multinom_reg.R @@ -11,7 +11,7 @@ source("helpers.R") test_that('primary arguments', { basic <- multinom_reg() - basic_glmnet <- translate(basic, engine = "glmnet") + basic_glmnet <- translate(basic %>% set_engine("glmnet")) expect_equal(basic_glmnet$method$fit$args, list( x = expr(missing_arg()), @@ -22,7 +22,7 @@ test_that('primary arguments', { ) mixture <- multinom_reg(mixture = 0.128) - mixture_glmnet <- translate(mixture, engine = "glmnet") + mixture_glmnet <- translate(mixture %>% set_engine("glmnet")) expect_equal(mixture_glmnet$method$fit$args, list( x = expr(missing_arg()), @@ -34,7 +34,7 @@ test_that('primary arguments', { ) penalty <- multinom_reg(penalty = 1) - penalty_glmnet <- translate(penalty, engine = "glmnet") + penalty_glmnet <- translate(penalty %>% set_engine("glmnet")) expect_equal(penalty_glmnet$method$fit$args, list( x = expr(missing_arg()), @@ -46,7 +46,7 @@ test_that('primary arguments', { ) mixture_v <- multinom_reg(mixture = varying()) - 
mixture_v_glmnet <- translate(mixture_v, engine = "glmnet") + mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")) expect_equal(mixture_v_glmnet$method$fit$args, list( x = expr(missing_arg()), @@ -60,50 +60,48 @@ test_that('primary arguments', { }) test_that('engine arguments', { - glmnet_nlam <- multinom_reg(nlambda = 10) - expect_equal(translate(glmnet_nlam, engine = "glmnet")$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - nlambda = new_empty_quosure(10), - family = "multinomial" - ) + glmnet_nlam <- multinom_reg() + expect_equal( + translate(glmnet_nlam %>% set_engine("glmnet", nlambda = 10))$method$fit$args, + list( + x = expr(missing_arg()), + y = expr(missing_arg()), + weights = expr(missing_arg()), + nlambda = new_empty_quosure(10), + family = "multinomial" + ) ) }) test_that('updating', { - expr1 <- multinom_reg( intercept = TRUE) - expr1_exp <- multinom_reg(mixture = 0, intercept = TRUE) + expr1 <- multinom_reg() %>% set_engine("glmnet", intercept = TRUE) + expr1_exp <- multinom_reg(mixture = 0) %>% set_engine("glmnet", intercept = TRUE) - expr2 <- multinom_reg(mixture = varying()) - expr2_exp <- multinom_reg(mixture = varying(), nlambda = 10) + expr2 <- multinom_reg(mixture = varying()) %>% set_engine("glmnet") + expr2_exp <- multinom_reg(mixture = varying()) %>% set_engine("glmnet", nlambda = 10) - expr3 <- multinom_reg(mixture = 0, penalty = varying()) - expr3_exp <- multinom_reg(mixture = 1) + expr3 <- multinom_reg(mixture = 0, penalty = varying()) %>% set_engine("glmnet") + expr3_exp <- multinom_reg(mixture = 1) %>% set_engine("glmnet") - expr4 <- multinom_reg(mixture = 0, nlambda = 10) - expr4_exp <- multinom_reg(mixture = 0, nlambda = 10, pmax = 2) + expr4 <- multinom_reg(mixture = 0) %>% set_engine("glmnet", nlambda = 10) + expr4_exp <- multinom_reg(mixture = 0) %>% set_engine("glmnet", nlambda = 10, pmax = 2) - expr5 <- multinom_reg(mixture = 1, nlambda = 10) - expr5_exp <- multinom_reg(mixture = 1, nlambda = 10, pmax = 2) + expr5 <- multinom_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10) + expr5_exp <- multinom_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10, pmax = 2) - expect_equal(update(expr1, mixture = 0), expr1_exp) - expect_equal(update(expr2, nlambda = 10), expr2_exp) - expect_equal(update(expr3, mixture = 1, fresh = TRUE), expr3_exp) - expect_equal(update(expr4, pmax = 2), expr4_exp) - expect_equal(update(expr5, nlambda = 10, pmax = 2), expr5_exp) + # expect_equal(update(expr1 %>% set_engine("glmnet"), mixture = 0), expr1_exp) + expect_equal(update(expr2) %>% set_engine("glmnet", nlambda = 10), expr2_exp) + expect_equal(update(expr3, mixture = 1, fresh = TRUE) %>% set_engine("glmnet"), expr3_exp) + # expect_equal(update(expr4 %>% set_engine("glmnet", pmax = 2)), expr4_exp) + expect_equal(update(expr5) %>% set_engine("glmnet", nlambda = 10, pmax = 2), expr5_exp) }) test_that('bad input', { expect_error(multinom_reg(mode = "regression")) - # expect_error(multinom_reg(penalty = -1)) - # expect_error(multinom_reg(mixture = -1)) - expect_error(translate(multinom_reg(), engine = "wat?")) - expect_warning(translate(multinom_reg(), engine = NULL)) - expect_error(translate(multinom_reg(formula = y ~ x))) - expect_warning(translate(multinom_reg(x = iris[,1:3], y = iris$Species), engine = "glmnet")) + expect_error(translate(multinom_reg() %>% set_engine("wat?"))) + expect_error(translate(multinom_reg() %>% set_engine())) + expect_warning(translate(multinom_reg() %>% 
set_engine("glmnet", x = iris[,1:3], y = iris$Species))) }) diff --git a/tests/testthat/test_multinom_reg_glmnet.R b/tests/testthat/test_multinom_reg_glmnet.R index bf45b7310..c658feb6f 100644 --- a/tests/testthat/test_multinom_reg_glmnet.R +++ b/tests/testthat/test_multinom_reg_glmnet.R @@ -21,8 +21,7 @@ test_that('glmnet execution', { expect_error( fit_xy( - multinom_reg(), - engine = "glmnet", + multinom_reg() %>% set_engine("glmnet"), control = ctrl, x = iris[, 1:4], y = iris$Species @@ -31,10 +30,10 @@ test_that('glmnet execution', { ) glmnet_xy_catch <- fit_xy( - multinom_reg(), + multinom_reg() %>% set_engine("glmnet"), x = iris[, 2:5], y = iris$Sepal.Length, - engine = "glmnet", + , control = caught_ctrl ) expect_true(inherits(glmnet_xy_catch$fit, "try-error")) @@ -46,8 +45,7 @@ test_that('glmnet prediction, one lambda', { skip_if_not_installed("glmnet") xy_fit <- fit_xy( - multinom_reg(penalty = 0.1), - engine = "glmnet", + multinom_reg(penalty = 0.1) %>% set_engine("glmnet"), control = ctrl, x = iris[, 1:4], y = iris$Species @@ -64,10 +62,9 @@ test_that('glmnet prediction, one lambda', { expect_equal(uni_pred, predict(xy_fit, iris[rows, 1:4], type = "class")$.pred_class) res_form <- fit( - multinom_reg(penalty = 0.1), + multinom_reg(penalty = 0.1) %>% set_engine("glmnet"), Species ~ log(Sepal.Width) + Petal.Width, data = iris, - engine = "glmnet", control = ctrl ) @@ -93,8 +90,7 @@ test_that('glmnet probabilities, mulitiple lambda', { lams <- c(0.01, 0.1) xy_fit <- fit_xy( - multinom_reg(penalty = lams), - engine = "glmnet", + multinom_reg(penalty = lams) %>% set_engine("glmnet"), control = ctrl, x = iris[, 1:4], y = iris$Species diff --git a/tests/testthat/test_multinom_reg_spark.R b/tests/testthat/test_multinom_reg_spark.R index 0b3f15206..0954e52a0 100644 --- a/tests/testthat/test_multinom_reg_spark.R +++ b/tests/testthat/test_multinom_reg_spark.R @@ -31,8 +31,7 @@ test_that('spark execution', { expect_error( spark_class_fit <- fit( - multinom_reg(), - engine = "spark", + multinom_reg() %>% set_engine("spark"), control = ctrl, Species ~ ., data = iris_tr diff --git a/tests/testthat/test_nearest_neighbor.R b/tests/testthat/test_nearest_neighbor.R index c9defcb04..797c55d0f 100644 --- a/tests/testthat/test_nearest_neighbor.R +++ b/tests/testthat/test_nearest_neighbor.R @@ -11,7 +11,7 @@ source("helpers.R") test_that('primary arguments', { basic <- nearest_neighbor() - basic_kknn <- translate(basic, engine = "kknn") + basic_kknn <- translate(basic %>% set_engine( "kknn")) expect_equal( object = basic_kknn$method$fit$args, @@ -23,7 +23,7 @@ test_that('primary arguments', { ) neighbors <- nearest_neighbor(neighbors = 5) - neighbors_kknn <- translate(neighbors, engine = "kknn") + neighbors_kknn <- translate(neighbors %>% set_engine( "kknn")) expect_equal( object = neighbors_kknn$method$fit$args, @@ -36,7 +36,7 @@ test_that('primary arguments', { ) weight_func <- nearest_neighbor(weight_func = "triangular") - weight_func_kknn <- translate(weight_func, engine = "kknn") + weight_func_kknn <- translate(weight_func %>% set_engine( "kknn")) expect_equal( object = weight_func_kknn$method$fit$args, @@ -49,7 +49,7 @@ test_that('primary arguments', { ) dist_power <- nearest_neighbor(dist_power = 2) - dist_power_kknn <- translate(dist_power, engine = "kknn") + dist_power_kknn <- translate(dist_power %>% set_engine( "kknn")) expect_equal( object = dist_power_kknn$method$fit$args, @@ -65,7 +65,7 @@ test_that('primary arguments', { test_that('engine arguments', { - kknn_scale <- 
nearest_neighbor(scale = FALSE) + kknn_scale <- nearest_neighbor() %>% set_engine("kknn", scale = FALSE) expect_equal( object = translate(kknn_scale, "kknn")$method$fit$args, @@ -82,30 +82,21 @@ test_that('engine arguments', { test_that('updating', { - expr1 <- nearest_neighbor( scale = FALSE) - expr1_exp <- nearest_neighbor(neighbors = 5, scale = FALSE) + expr1 <- nearest_neighbor() %>% set_engine("kknn", scale = FALSE) + expr1_exp <- nearest_neighbor(neighbors = 5) %>% set_engine("kknn", scale = FALSE) - expr2 <- nearest_neighbor(neighbors = varying()) - expr2_exp <- nearest_neighbor(neighbors = varying(), weight_func = "triangular") + expr2 <- nearest_neighbor(neighbors = varying()) %>% set_engine("kknn") + expr2_exp <- nearest_neighbor(neighbors = varying(), weight_func = "triangular") %>% set_engine("kknn") - expr3 <- nearest_neighbor(neighbors = 2, weight_func = varying()) - expr3_exp <- nearest_neighbor(neighbors = 3) - - expr4 <- nearest_neighbor(neighbors = 1, scale = TRUE) - expr4_exp <- nearest_neighbor(neighbors = 1, scale = TRUE, ykernel = 2) + expr3 <- nearest_neighbor(neighbors = 2, weight_func = varying()) %>% set_engine("kknn") + expr3_exp <- nearest_neighbor(neighbors = 3) %>% set_engine("kknn") expect_equal(update(expr1, neighbors = 5), expr1_exp) expect_equal(update(expr2, weight_func = "triangular"), expr2_exp) expect_equal(update(expr3, neighbors = 3, fresh = TRUE), expr3_exp) - expect_equal(update(expr4, ykernel = 2), expr4_exp) - }) test_that('bad input', { - # expect_error(nearest_neighbor(eighbor = 7)) expect_error(nearest_neighbor(mode = "reallyunknown")) - # expect_error(nearest_neighbor(neighbors = -5)) - # expect_error(nearest_neighbor(neighbors = 5.5)) - # expect_error(nearest_neighbor(neighbors = c(5.5, 6))) - expect_warning(translate(nearest_neighbor(), engine = NULL)) + expect_error(translate(nearest_neighbor() %>% set_engine(NULL))) }) diff --git a/tests/testthat/test_nearest_neighbor_kknn.R b/tests/testthat/test_nearest_neighbor_kknn.R index b94ebf535..8cdddd2dc 100644 --- a/tests/testthat/test_nearest_neighbor_kknn.R +++ b/tests/testthat/test_nearest_neighbor_kknn.R @@ -8,7 +8,8 @@ context("nearest neighbor execution with kknn") num_pred <- c("Sepal.Width", "Petal.Width", "Petal.Length") iris_bad_form <- as.formula(Species ~ term) -iris_basic <- nearest_neighbor(neighbors = 8, weight_func = "triangular") +iris_basic <- nearest_neighbor(neighbors = 8, weight_func = "triangular") %>% + set_engine("kknn") ctrl <- fit_control(verbosity = 1, catch = FALSE) caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) @@ -25,7 +26,6 @@ test_that('kknn execution', { expect_error( fit_xy( iris_basic, - engine = "kknn", control = ctrl, x = iris[, num_pred], y = iris$Sepal.Length @@ -38,7 +38,6 @@ test_that('kknn execution', { expect_error( fit_xy( iris_basic, - engine = "kknn", control = ctrl, x = iris[, c("Sepal.Length", "Petal.Width")], y = iris$Species @@ -51,7 +50,7 @@ test_that('kknn execution', { iris_basic, iris_bad_form, data = iris, - engine = "kknn", + control = ctrl ) ) @@ -65,7 +64,6 @@ test_that('kknn prediction', { # continuous res_xy <- fit_xy( iris_basic, - engine = "kknn", control = ctrl, x = iris[, num_pred], y = iris$Sepal.Length @@ -81,7 +79,6 @@ test_that('kknn prediction', { # nominal res_xy_nom <- fit_xy( iris_basic, - engine = "kknn", control = ctrl, x = iris[, c("Sepal.Length", "Petal.Width")], y = iris$Species @@ -94,12 +91,12 @@ test_that('kknn prediction', { expect_equal(uni_pred_nom, predict_class(res_xy_nom, iris[1:5, 
c("Sepal.Length", "Petal.Width")])) + library(kknn) # see https://github.com/KlausVigo/kknn/issues/16 # continuous - formula interface res_form <- fit( iris_basic, Sepal.Length ~ log(Sepal.Width) + Species, data = iris, - engine = "kknn", control = ctrl ) diff --git a/tests/testthat/test_predict_formats.R b/tests/testthat/test_predict_formats.R index 3101b6c79..4b4c24d77 100644 --- a/tests/testthat/test_predict_formats.R +++ b/tests/testthat/test_predict_formats.R @@ -8,21 +8,24 @@ context("check predict output structures") lm_fit <- linear_reg(mode = "regression") %>% - fit(Sepal.Length ~ ., data = iris, engine = "lm") + set_engine("lm") %>% + fit(Sepal.Length ~ ., data = iris) class_dat <- airquality[complete.cases(airquality),] class_dat$Ozone <- factor(ifelse(class_dat$Ozone >= 31, "high", "low")) lr_fit <- logistic_reg() %>% - fit(Ozone ~ ., data = class_dat, engine = "glm") + set_engine("glm") %>% + fit(Ozone ~ ., data = class_dat) class_dat2 <- airquality[complete.cases(airquality),] class_dat2$Ozone <- factor(ifelse(class_dat2$Ozone >= 31, "high+values", "2low")) lr_fit_2 <- logistic_reg() %>% - fit(Ozone ~ ., data = class_dat2, engine = "glm") + set_engine("glm") %>% + fit(Ozone ~ ., data = class_dat2) # ------------------------------------------------------------------------------ diff --git a/tests/testthat/test_rand_forest.R b/tests/testthat/test_rand_forest.R index 49f9a902e..e7996d3d5 100644 --- a/tests/testthat/test_rand_forest.R +++ b/tests/testthat/test_rand_forest.R @@ -11,9 +11,9 @@ source("helpers.R") test_that('primary arguments', { mtry <- rand_forest(mode = "regression", mtry = 4) - mtry_ranger <- translate(mtry, engine = "ranger") - mtry_randomForest <- translate(mtry, engine = "randomForest") - mtry_spark <- translate(mtry, engine = "spark") + mtry_ranger <- translate(mtry %>% set_engine("ranger")) + mtry_randomForest <- translate(mtry %>% set_engine("randomForest")) + mtry_spark <- translate(mtry %>% set_engine("spark")) expect_equal(mtry_ranger$method$fit$args, list( formula = expr(missing_arg()), @@ -42,9 +42,9 @@ test_that('primary arguments', { ) ) trees <- rand_forest(mode = "classification", trees = 1000) - trees_ranger <- translate(trees, engine = "ranger") - trees_randomForest <- translate(trees, engine = "randomForest") - trees_spark <- translate(trees, engine = "spark") + trees_ranger <- translate(trees %>% set_engine("ranger")) + trees_randomForest <- translate(trees %>% set_engine("randomForest")) + trees_spark <- translate(trees %>% set_engine("spark")) expect_equal(trees_ranger$method$fit$args, list( formula = expr(missing_arg()), @@ -75,9 +75,9 @@ test_that('primary arguments', { ) min_n <- rand_forest(mode = "regression", min_n = 5) - min_n_ranger <- translate(min_n, engine = "ranger") - min_n_randomForest <- translate(min_n, engine = "randomForest") - min_n_spark <- translate(min_n, engine = "spark") + min_n_ranger <- translate(min_n %>% set_engine("ranger")) + min_n_randomForest <- translate(min_n %>% set_engine("randomForest")) + min_n_spark <- translate(min_n %>% set_engine("spark")) expect_equal(min_n_ranger$method$fit$args, list( formula = expr(missing_arg()), @@ -107,9 +107,9 @@ test_that('primary arguments', { ) mtry_v <- rand_forest(mode = "classification", mtry = varying()) - mtry_v_ranger <- translate(mtry_v, engine = "ranger") - mtry_v_randomForest <- translate(mtry_v, engine = "randomForest") - mtry_v_spark <- translate(mtry_v, engine = "spark") + mtry_v_ranger <- translate(mtry_v %>% set_engine("ranger")) + mtry_v_randomForest <- 
translate(mtry_v %>% set_engine("randomForest")) + mtry_v_spark <- translate(mtry_v %>% set_engine("spark")) expect_equal(mtry_v_ranger$method$fit$args, list( formula = expr(missing_arg()), @@ -140,9 +140,9 @@ test_that('primary arguments', { ) trees_v <- rand_forest(mode = "regression", trees = varying()) - trees_v_ranger <- translate(trees_v, engine = "ranger") - trees_v_randomForest <- translate(trees_v, engine = "randomForest") - trees_v_spark <- translate(trees_v, engine = "spark") + trees_v_ranger <- translate(trees_v %>% set_engine("ranger")) + trees_v_randomForest <- translate(trees_v %>% set_engine("randomForest")) + trees_v_spark <- translate(trees_v %>% set_engine("spark")) expect_equal(trees_v_ranger$method$fit$args, list( formula = expr(missing_arg()), @@ -172,9 +172,9 @@ test_that('primary arguments', { ) min_n_v <- rand_forest(mode = "classification", min_n = varying()) - min_n_v_ranger <- translate(min_n_v, engine = "ranger") - min_n_v_randomForest <- translate(min_n_v, engine = "randomForest") - min_n_v_spark <- translate(min_n_v, engine = "spark") + min_n_v_ranger <- translate(min_n_v %>% set_engine("ranger")) + min_n_v_randomForest <- translate(min_n_v %>% set_engine("randomForest")) + min_n_v_spark <- translate(min_n_v %>% set_engine("spark")) expect_equal(min_n_v_ranger$method$fit$args, list( formula = expr(missing_arg()), @@ -207,8 +207,8 @@ test_that('primary arguments', { }) test_that('engine arguments', { - ranger_imp <- rand_forest(mode = "classification", importance = "impurity") - expect_equal(translate(ranger_imp, engine = "ranger")$method$fit$args, + ranger_imp <- rand_forest(mode = "classification") + expect_equal(translate(ranger_imp %>% set_engine("ranger", importance = "impurity"))$method$fit$args, list( formula = expr(missing_arg()), data = expr(missing_arg()), @@ -221,8 +221,8 @@ test_that('engine arguments', { ) ) - randomForest_votes <- rand_forest(mode = "regression", norm.votes = FALSE) - expect_equal(translate(randomForest_votes, engine = "randomForest")$method$fit$args, + randomForest_votes <- rand_forest(mode = "regression") + expect_equal(translate(randomForest_votes %>% set_engine("randomForest", norm.votes = FALSE))$method$fit$args, list( x = expr(missing_arg()), y = expr(missing_arg()), @@ -230,8 +230,8 @@ test_that('engine arguments', { ) ) - spark_gain <- rand_forest(mode = "regression", min_info_gain = 2) - expect_equal(translate(spark_gain, engine = "spark")$method$fit$args, + spark_gain <- rand_forest(mode = "regression") + expect_equal(translate(spark_gain %>% set_engine("spark", min_info_gain = 2))$method$fit$args, list( x = expr(missing_arg()), formula = expr(missing_arg()), @@ -241,78 +241,88 @@ test_that('engine arguments', { ) ) - ranger_samp_frac <- rand_forest(mode = "regression", sample.fraction = varying()) - expect_equal(translate(ranger_samp_frac, engine = "ranger")$method$fit$args, - list( - formula = expr(missing_arg()), - data = expr(missing_arg()), - case.weights = expr(missing_arg()), - sample.fraction = new_empty_quosure(varying()), - num.threads = 1, - verbose = FALSE, - seed = expr(sample.int(10^5, 1)) - ) + ranger_samp_frac <- rand_forest(mode = "regression") + expect_equal( + translate(ranger_samp_frac %>% + set_engine("ranger", sample.fraction = varying()))$method$fit$args, + list( + formula = expr(missing_arg()), + data = expr(missing_arg()), + case.weights = expr(missing_arg()), + sample.fraction = new_empty_quosure(varying()), + num.threads = 1, + verbose = FALSE, + seed = expr(sample.int(10^5, 1)) + ) ) 
randomForest_votes_v <- - rand_forest(mode = "regression", norm.votes = FALSE, sampsize = varying()) - expect_equal(translate(randomForest_votes_v, engine = "randomForest")$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - norm.votes = new_empty_quosure(FALSE), - sampsize = new_empty_quosure(varying()) - ) + rand_forest(mode = "regression") + expect_equal( + translate(randomForest_votes_v %>% + set_engine("randomForest", norm.votes = FALSE, sampsize = varying()))$method$fit$args, + list( + x = expr(missing_arg()), + y = expr(missing_arg()), + norm.votes = new_empty_quosure(FALSE), + sampsize = new_empty_quosure(varying()) + ) ) spark_bins_v <- - rand_forest(mode = "regression", uid = "id label", max_bins = varying()) - expect_equal(translate(spark_bins_v, engine = "spark")$method$fit$args, - list( - x = expr(missing_arg()), - formula = expr(missing_arg()), - type = "regression", - uid = new_empty_quosure("id label"), - max_bins = new_empty_quosure(varying()), - seed = expr(sample.int(10^5, 1)) - ) + rand_forest(mode = "regression") + expect_equal( + translate(spark_bins_v %>% + set_engine("spark", uid = "id label", max_bins = varying()))$method$fit$args, + list( + x = expr(missing_arg()), + formula = expr(missing_arg()), + type = "regression", + uid = new_empty_quosure("id label"), + max_bins = new_empty_quosure(varying()), + seed = expr(sample.int(10^5, 1)) + ) ) }) test_that('updating', { - expr1 <- rand_forest(mode = "regression", norm.votes = FALSE, sampsize = varying()) - expr1_exp <- rand_forest(mode = "regression", mtry = 2, norm.votes = FALSE, sampsize = varying()) + expr1 <- rand_forest(mode = "regression") %>% + set_engine("randomForest", norm.votes = FALSE, sampsize = varying()) + expr1_exp <- rand_forest(mode = "regression", mtry = 2) %>% + set_engine("randomForest", norm.votes = FALSE, sampsize = varying()) - expr2 <- rand_forest(mode = "regression", mtry = 7, min_n = varying()) - expr2_exp <- rand_forest(mode = "regression", mtry = 7, min_n = varying(), norm.votes = FALSE) + expr2 <- rand_forest(mode = "regression", mtry = 7, min_n = varying()) %>% + set_engine("randomForest") + expr2_exp <- rand_forest(mode = "regression", mtry = 7, min_n = varying()) %>% + set_engine("randomForest", norm.votes = FALSE) - expr3 <- rand_forest(mode = "regression", mtry = 7, min_n = varying()) - expr3_exp <- rand_forest(mode = "regression", mtry = 2) + expr3 <- rand_forest(mode = "regression", mtry = 7, min_n = varying()) %>% + set_engine("randomForest") + expr3_exp <- rand_forest(mode = "regression", mtry = 2) %>% + set_engine("randomForest") - expr4 <- rand_forest(mode = "regression", mtry = 2, norm.votes = FALSE, sampsize = varying()) - expr4_exp <- rand_forest(mode = "regression", mtry = 2, norm.votes = TRUE, sampsize = varying()) + expr4 <- rand_forest(mode = "regression", mtry = 2) %>% + set_engine("randomForest", norm.votes = FALSE, sampsize = varying()) + expr4_exp <- rand_forest(mode = "regression", mtry = 2) %>% + set_engine("randomForest", norm.votes = TRUE, sampsize = varying()) - expr5 <- rand_forest(mode = "regression", mtry = 2, norm.votes = FALSE) - expr5_exp <- rand_forest(mode = "regression", mtry = 2, norm.votes = TRUE, sampsize = varying()) + expr5 <- rand_forest(mode = "regression", mtry = 2) %>% + set_engine("randomForest", norm.votes = FALSE) + expr5_exp <- rand_forest(mode = "regression", mtry = 2) %>% + set_engine("randomForest", norm.votes = TRUE, sampsize = varying()) expect_equal(update(expr1, mtry = 2), expr1_exp) - 
expect_equal(update(expr2, norm.votes = FALSE), expr2_exp) expect_equal(update(expr3, mtry = 2, fresh = TRUE), expr3_exp) - expect_equal(update(expr4, norm.votes = TRUE), expr4_exp) - expect_equal(update(expr5, norm.votes = TRUE, sampsize = varying()), expr5_exp) }) test_that('bad input', { expect_error(rand_forest(mode = "time series")) - expect_error(translate(rand_forest(mode = "classification"), engine = "wat?")) - expect_warning(translate(rand_forest(mode = "classification"), engine = NULL)) + expect_error(translate(rand_forest(mode = "classification") %>% set_engine("wat?"))) + expect_error(translate(rand_forest(mode = "classification") %>% set_engine(NULL))) expect_error(translate(rand_forest(mode = "classification", ytest = 2))) - expect_error(translate(rand_forest(mode = "regression", formula = y ~ x))) - expect_error(translate(rand_forest(mode = "classification", x = x, y = y)), engine = "randomForest") - expect_error(translate(rand_forest(mode = "regression", formula = y ~ x)), engine = "") }) diff --git a/tests/testthat/test_rand_forest_randomForest.R b/tests/testthat/test_rand_forest_randomForest.R index 56c95a367..9ae446edc 100644 --- a/tests/testthat/test_rand_forest_randomForest.R +++ b/tests/testthat/test_rand_forest_randomForest.R @@ -12,8 +12,10 @@ data("lending_club") lending_club <- head(lending_club, 200) num_pred <- c("funded_amnt", "annual_inc", "num_il_tl") -lc_basic <- rand_forest(mode = "classification") -bad_rf_cls <- rand_forest(mode = "classification", sampsize = -10) +lc_basic <- rand_forest(mode = "classification") %>% + set_engine("randomForest") +bad_rf_cls <- rand_forest(mode = "classification") %>% + set_engine("randomForest", sampsize = -10) ctrl <- fit_control(verbosity = 1, catch = FALSE) caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) @@ -31,7 +33,6 @@ test_that('randomForest classification execution', { # lc_basic, # Class ~ funded_amnt + term, # data = lending_club, - # engine = "randomForest", # control = ctrl # ), # regexp = NA @@ -40,7 +41,6 @@ test_that('randomForest classification execution', { expect_error( fit_xy( lc_basic, - engine = "randomForest", control = ctrl, x = lending_club[, num_pred], y = lending_club$Class @@ -53,7 +53,6 @@ test_that('randomForest classification execution', { bad_rf_cls, funded_amnt ~ term, data = lending_club, - engine = "randomForest", control = ctrl ) ) @@ -63,7 +62,6 @@ test_that('randomForest classification execution', { # bad_rf_cls, # funded_amnt ~ term, # data = lending_club, - # engine = "randomForest", # control = caught_ctrl # ) # expect_true(inherits(randomForest_form_catch$fit, "try-error")) @@ -72,7 +70,6 @@ test_that('randomForest classification execution', { bad_rf_cls, x = lending_club[, num_pred], y = lending_club$total_bal_il, - engine = "randomForest", control = caught_ctrl ) expect_true(inherits(randomForest_xy_catch$fit, "try-error")) @@ -88,7 +85,6 @@ test_that('randomForest classification prediction', { lc_basic, x = lending_club[, num_pred], y = lending_club$Class, - engine = "randomForest", control = ctrl ) @@ -100,7 +96,6 @@ test_that('randomForest classification prediction', { lc_basic, Class ~ funded_amnt + int_rate, data = lending_club, - engine = "randomForest", control = ctrl ) @@ -117,7 +112,6 @@ test_that('randomForest classification probabilities', { lc_basic, x = lending_club[, num_pred], y = lending_club$Class, - engine = "randomForest", control = ctrl ) @@ -132,7 +126,6 @@ test_that('randomForest classification probabilities', { lc_basic, Class ~ funded_amnt + 
int_rate, data = lending_club, - engine = "randomForest", control = ctrl ) @@ -147,10 +140,12 @@ test_that('randomForest classification probabilities', { car_form <- as.formula(mpg ~ .) num_pred <- names(mtcars)[3:6] -car_basic <- rand_forest(mode = "regression") +car_basic <- rand_forest(mode = "regression") %>% set_engine("randomForest") -bad_ranger_reg <- rand_forest(mode = "regression", min.node.size = -10) -bad_rf_reg <- rand_forest(mode = "regression", sampsize = -10) +bad_ranger_reg <- rand_forest(mode = "regression") %>% + set_engine("randomForest", min.node.size = -10) +bad_rf_reg <- rand_forest(mode = "regression") %>% + set_engine("randomForest", sampsize = -10) ctrl <- list(verbosity = 1, catch = FALSE) caught_ctrl <- list(verbosity = 1, catch = TRUE) @@ -167,7 +162,6 @@ test_that('randomForest regression execution', { car_basic, car_form, data = mtcars, - engine = "randomForest", control = ctrl ), regexp = NA @@ -178,7 +172,6 @@ test_that('randomForest regression execution', { car_basic, x = mtcars, y = mtcars$mpg, - engine = "randomForest", control = ctrl ), regexp = NA @@ -188,7 +181,6 @@ test_that('randomForest regression execution', { bad_rf_reg, car_form, data = mtcars, - engine = "randomForest", control = caught_ctrl ) expect_true(inherits(randomForest_form_catch$fit, "try-error")) @@ -197,7 +189,6 @@ test_that('randomForest regression execution', { bad_rf_reg, x = mtcars, y = mtcars$mpg, - engine = "randomForest", control = caught_ctrl ) expect_true(inherits(randomForest_xy_catch$fit, "try-error")) @@ -212,7 +203,6 @@ test_that('randomForest regression prediction', { car_basic, x = mtcars, y = mtcars$mpg, - engine = "randomForest", control = ctrl ) diff --git a/tests/testthat/test_rand_forest_ranger.R b/tests/testthat/test_rand_forest_ranger.R index 054233d02..7e90bbc43 100644 --- a/tests/testthat/test_rand_forest_ranger.R +++ b/tests/testthat/test_rand_forest_ranger.R @@ -13,11 +13,11 @@ data("lending_club") lending_club <- head(lending_club, 200) num_pred <- c("funded_amnt", "annual_inc", "num_il_tl") -lc_basic <- rand_forest() -lc_ranger <- rand_forest(seed = 144) +lc_basic <- rand_forest() %>% set_engine("ranger") +lc_ranger <- rand_forest() %>% set_engine("ranger", seed = 144) -bad_ranger_cls <- rand_forest(replace = "bad") -bad_rf_cls <- rand_forest(sampsize = -10) +bad_ranger_cls <- rand_forest() %>% set_engine("ranger", replace = "bad") +bad_rf_cls <- rand_forest() %>% set_engine("ranger", sampsize = -10) ctrl <- fit_control(verbosity = 1, catch = FALSE) caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) @@ -34,7 +34,7 @@ test_that('ranger classification execution', { lc_ranger, Class ~ funded_amnt + term, data = lending_club, - engine = "ranger", + control = ctrl ), regexp = NA @@ -45,7 +45,7 @@ test_that('ranger classification execution', { lc_ranger, x = lending_club[, num_pred], y = lending_club$Class, - engine = "ranger", + control = ctrl ), regexp = NA @@ -56,7 +56,7 @@ test_that('ranger classification execution', { bad_ranger_cls, funded_amnt ~ term, data = lending_club, - engine = "ranger", + control = ctrl ) ) @@ -65,14 +65,14 @@ test_that('ranger classification execution', { bad_ranger_cls, funded_amnt ~ term, data = lending_club, - engine = "ranger", + control = caught_ctrl ) expect_true(inherits(ranger_form_catch$fit, "try-error")) ranger_xy_catch <- fit_xy( bad_ranger_cls, - engine = "ranger", + control = caught_ctrl, x = lending_club[, num_pred], y = lending_club$total_bal_il @@ -86,10 +86,10 @@ test_that('ranger classification prediction', { 
skip_if_not_installed("ranger") xy_fit <- fit_xy( - rand_forest(), + rand_forest() %>% set_engine("ranger"), x = lending_club[, num_pred], y = lending_club$Class, - engine = "ranger", + control = ctrl ) @@ -99,10 +99,10 @@ test_that('ranger classification prediction', { expect_equal(xy_pred, predict_class(xy_fit, new_data = lending_club[1:6, num_pred])) form_fit <- fit( - rand_forest(), + rand_forest() %>% set_engine("ranger"), Class ~ funded_amnt + int_rate, data = lending_club, - engine = "ranger", + control = ctrl ) @@ -119,10 +119,10 @@ test_that('ranger classification probabilities', { skip_if_not_installed("ranger") xy_fit <- fit_xy( - rand_forest(seed = 3566), + rand_forest() %>% set_engine("ranger", seed = 3566), x = lending_club[, num_pred], y = lending_club$Class, - engine = "ranger", + control = ctrl ) @@ -134,10 +134,10 @@ test_that('ranger classification probabilities', { expect_equivalent(xy_pred[1,], one_row) form_fit <- fit( - rand_forest(seed = 3566), + rand_forest() %>% set_engine("ranger", seed = 3566), Class ~ funded_amnt + int_rate, data = lending_club, - engine = "ranger", + control = ctrl ) @@ -146,10 +146,10 @@ test_that('ranger classification probabilities', { expect_equal(form_pred, predict_classprob(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) no_prob_model <- fit_xy( - rand_forest(probability = FALSE), + rand_forest() %>% set_engine("ranger", probability = FALSE), x = lending_club[, num_pred], y = lending_club$Class, - engine = "ranger", + control = ctrl ) @@ -162,10 +162,10 @@ test_that('ranger classification probabilities', { num_pred <- names(mtcars)[3:6] -car_basic <- rand_forest() +car_basic <- rand_forest() %>% set_engine("ranger") -bad_ranger_reg <- rand_forest(replace = "bad") -bad_rf_reg <- rand_forest(sampsize = -10) +bad_ranger_reg <- rand_forest() %>% set_engine("ranger", replace = "bad") +bad_rf_reg <- rand_forest() %>% set_engine("ranger", sampsize = -10) ctrl <- list(verbosity = 1, catch = FALSE) caught_ctrl <- list(verbosity = 1, catch = TRUE) @@ -182,7 +182,6 @@ test_that('ranger regression execution', { car_basic, mpg ~ ., data = mtcars, - engine = "ranger", control = ctrl ), regexp = NA @@ -193,7 +192,6 @@ test_that('ranger regression execution', { car_basic, x = mtcars, y = mtcars$mpg, - engine = "ranger", control = ctrl ), regexp = NA @@ -204,14 +202,12 @@ test_that('ranger regression execution', { bad_ranger_reg, mpg ~ ., data = mtcars, - engine = "ranger", control = caught_ctrl ) expect_true(inherits(ranger_form_catch$fit, "try-error")) ranger_xy_catch <- fit_xy( bad_ranger_reg, - engine = "ranger", control = caught_ctrl, x = mtcars[, num_pred], y = mtcars$mpg @@ -228,7 +224,6 @@ test_that('ranger regression prediction', { car_basic, x = mtcars[, -1], y = mtcars$mpg, - engine = "ranger", control = ctrl ) @@ -244,10 +239,9 @@ test_that('ranger regression intervals', { skip_if_not_installed("ranger") xy_fit <- fit_xy( - rand_forest(keep.inbag = TRUE), + rand_forest() %>% set_engine("ranger", keep.inbag = TRUE), x = mtcars[, -1], y = mtcars$mpg, - engine = "ranger", control = ctrl ) @@ -276,35 +270,31 @@ test_that('additional descriptor tests', { skip_if_not_installed("ranger") descr_xy <- fit_xy( - rand_forest(mtry = floor(sqrt(.cols())) + 1), + rand_forest(mtry = floor(sqrt(.cols())) + 1) %>% set_engine("ranger"), x = mtcars[, -1], y = mtcars$mpg, - engine = "ranger", control = ctrl ) expect_equal(descr_xy$fit$mtry, 4) descr_f <- fit( - rand_forest(mtry = floor(sqrt(.cols())) + 1), + rand_forest(mtry = 
floor(sqrt(.cols())) + 1) %>% set_engine("ranger"), mpg ~ ., data = mtcars, - engine = "ranger", control = ctrl ) expect_equal(descr_f$fit$mtry, 4) descr_xy <- fit_xy( - rand_forest(mtry = floor(sqrt(.cols())) + 1), + rand_forest(mtry = floor(sqrt(.cols())) + 1) %>% set_engine("ranger"), x = mtcars[, -1], y = mtcars$mpg, - engine = "ranger", control = ctrl ) expect_equal(descr_xy$fit$mtry, 4) descr_f <- fit( - rand_forest(mtry = floor(sqrt(.cols())) + 1), + rand_forest(mtry = floor(sqrt(.cols())) + 1) %>% set_engine("ranger"), mpg ~ ., data = mtcars, - engine = "ranger", control = ctrl ) expect_equal(descr_f$fit$mtry, 4) @@ -314,50 +304,38 @@ test_that('additional descriptor tests', { exp_wts <- quo(c(min(.lvls()), 20, 10)) descr_other_xy <- fit_xy( - rand_forest( - mtry = 2, - class.weights = c(min(.lvls()), 20, 10) - ), + rand_forest(mtry = 2) %>% + set_engine("ranger", class.weights = c(min(.lvls()), 20, 10)), x = iris[, 1:4], y = iris$Species, - engine = "ranger", control = ctrl ) expect_equal(descr_other_xy$fit$mtry, 2) expect_equal(descr_other_xy$fit$call$class.weights, exp_wts) descr_other_f <- fit( - rand_forest( - mtry = 2, - class.weights = c(min(.lvls()), 20, 10) - ), + rand_forest(mtry = 2) %>% + set_engine("ranger", class.weights = c(min(.lvls()), 20, 10)), Species ~ ., data = iris, - engine = "ranger", control = ctrl ) expect_equal(descr_other_f$fit$mtry, 2) expect_equal(descr_other_f$fit$call$class.weights, exp_wts) descr_other_xy <- fit_xy( - rand_forest( - mtry = 2, - class.weights = c(min(.lvls()), 20, 10) - ), + rand_forest(mtry = 2) %>% + set_engine("ranger", class.weights = c(min(.lvls()), 20, 10)), x = iris[, 1:4], y = iris$Species, - engine = "ranger", control = ctrl ) expect_equal(descr_other_xy$fit$mtry, 2) expect_equal(descr_other_xy$fit$call$class.weights, exp_wts) descr_other_f <- fit( - rand_forest( - mtry = 2, - class.weights = c(min(.lvls()), 20, 10) - ), + rand_forest(mtry = 2) %>% + set_engine("ranger", class.weights = c(min(.lvls()), 20, 10)), Species ~ ., data = iris, - engine = "ranger", control = ctrl ) expect_equal(descr_other_f$fit$mtry, 2) @@ -370,11 +348,10 @@ test_that('ranger classification prediction', { skip_if_not_installed("ranger") xy_class_fit <- - rand_forest() %>% + rand_forest() %>% set_engine("ranger") %>% fit_xy( x = iris[, 1:4], y = iris$Species, - engine = "ranger", control = ctrl ) @@ -389,10 +366,10 @@ test_that('ranger classification prediction', { xy_prob_fit <- rand_forest() %>% + set_engine("ranger") %>% fit_xy( x = iris[, 1:4], y = iris$Species, - engine = "ranger", control = ctrl ) @@ -420,10 +397,10 @@ test_that('ranger classification intervals', { skip_if_not_installed("ranger") lc_fit <- fit( - rand_forest(keep.inbag = TRUE, probability = TRUE), + rand_forest() %>% + set_engine("ranger", keep.inbag = TRUE, probability = TRUE), Class ~ funded_amnt + int_rate, data = lending_club, - engine = "ranger", control = ctrl ) diff --git a/tests/testthat/test_rand_forest_spark.R b/tests/testthat/test_rand_forest_spark.R index 81ff73d0e..bcda93626 100644 --- a/tests/testthat/test_rand_forest_spark.R +++ b/tests/testthat/test_rand_forest_spark.R @@ -30,12 +30,8 @@ test_that('spark execution', { expect_error( spark_reg_fit <- fit( - rand_forest( - trees = 5, - mode = "regression", - seed = 12 - ), - engine = "spark", + rand_forest(trees = 5, mode = "regression") %>% + set_engine("spark", seed = 12), control = ctrl, Sepal_Length ~ ., data = iris_rf_tr @@ -47,12 +43,8 @@ test_that('spark execution', { expect_error( spark_reg_fit_dup <- 
fit( - rand_forest( - trees = 5, - mode = "regression", - seed = 12 - ), - engine = "spark", + rand_forest(trees = 5, mode = "regression") %>% + set_engine("spark", seed = 12), control = ctrl, Sepal_Length ~ ., data = iris_rf_tr @@ -104,12 +96,8 @@ test_that('spark execution', { expect_error( spark_class_fit <- fit( - rand_forest( - trees = 5, - mode = "classification", - seed = 12 - ), - engine = "spark", + rand_forest(trees = 5, mode = "classification") %>% + set_engine("spark", seed = 12), control = ctrl, churn ~ ., data = churn_rf_tr @@ -121,12 +109,8 @@ test_that('spark execution', { expect_error( spark_class_fit_dup <- fit( - rand_forest( - trees = 5, - mode = "classification", - seed = 12 - ), - engine = "spark", + rand_forest(trees = 5, mode = "classification") %>% + set_engine("spark", seed = 12), control = ctrl, churn ~ ., data = churn_rf_tr @@ -186,7 +170,7 @@ test_that('spark execution', { regexp = NA ) - expect_equal(colnames(spark_class_prob), c("pred_0", "pred_1")) + expect_equal(colnames(spark_class_prob), c("pred_No", "pred_Yes")) expect_equivalent( as.data.frame(spark_class_prob), diff --git a/tests/testthat/test_surv_reg.R b/tests/testthat/test_surv_reg.R index e67f6ce42..bc3888d13 100644 --- a/tests/testthat/test_surv_reg.R +++ b/tests/testthat/test_surv_reg.R @@ -12,7 +12,7 @@ source("helpers.R") test_that('primary arguments', { basic <- surv_reg() - basic_flexsurv <- translate(basic, engine = "flexsurv") + basic_flexsurv <- translate(basic %>% set_engine("flexsurv")) expect_equal(basic_flexsurv$method$fit$args, list( @@ -23,7 +23,7 @@ test_that('primary arguments', { ) normal <- surv_reg(dist = "lnorm") - normal_flexsurv <- translate(normal, engine = "flexsurv") + normal_flexsurv <- translate(normal %>% set_engine("flexsurv")) expect_equal(normal_flexsurv$method$fit$args, list( formula = expr(missing_arg()), @@ -34,7 +34,7 @@ test_that('primary arguments', { ) dist_v <- surv_reg(dist = varying()) - dist_v_flexsurv <- translate(dist_v, engine = "flexsurv") + dist_v_flexsurv <- translate(dist_v %>% set_engine("flexsurv")) expect_equal(dist_v_flexsurv$method$fit$args, list( formula = expr(missing_arg()), @@ -46,8 +46,8 @@ test_that('primary arguments', { }) test_that('engine arguments', { - fs_cl <- surv_reg(cl = .99) - expect_equal(translate(fs_cl, engine = "flexsurv")$method$fit$args, + fs_cl <- surv_reg() + expect_equal(translate(fs_cl %>% set_engine("flexsurv", cl = .99))$method$fit$args, list( formula = expr(missing_arg()), data = expr(missing_arg()), @@ -60,20 +60,13 @@ test_that('updating', { - expr1 <- surv_reg( cl = .99) - expr1_exp <- surv_reg(dist = "lnorm", cl = .99) - - expr2 <- surv_reg(dist = varying()) - expr2_exp <- surv_reg(dist = varying(), cl = .99) - + expr1 <- surv_reg() %>% set_engine("flexsurv", cl = .99) + expr1_exp <- surv_reg(dist = "lnorm") %>% set_engine("flexsurv", cl = .99) expect_equal(update(expr1, dist = "lnorm"), expr1_exp) - expect_equal(update(expr2, cl = .99), expr2_exp) }) test_that('bad input', { - expect_error(surv_reg(mode = "classification")) - expect_error(translate(surv_reg(), engine = "wat?")) - expect_warning(translate(surv_reg(), engine = NULL)) - expect_error(translate(surv_reg(formula = y ~ x))) - expect_warning(translate(surv_reg(formula = y ~ x), engine = "flexsurv")) + expect_error(surv_reg(mode = "classification")) + expect_error(translate(surv_reg() %>% set_engine("wat"))) + expect_error(translate(surv_reg() %>% set_engine(NULL))) }) diff --git 
a/tests/testthat/test_surv_reg_flexsurv.R b/tests/testthat/test_surv_reg_flexsurv.R index 6e0ad9944..c8cab6687 100644 --- a/tests/testthat/test_surv_reg_flexsurv.R +++ b/tests/testthat/test_surv_reg_flexsurv.R @@ -2,13 +2,14 @@ library(testthat) library(parsnip) library(rlang) library(survival) +library(tibble) # ------------------------------------------------------------------------------ basic_form <- Surv(recyrs, censrec) ~ group complete_form <- Surv(recyrs) ~ group -surv_basic <- surv_reg() +surv_basic <- surv_reg() %>% set_engine("flexsurv") ctrl <- fit_control(verbosity = 1, catch = FALSE) caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) quiet_ctrl <- fit_control(verbosity = 0, catch = TRUE) @@ -29,8 +30,7 @@ test_that('flexsurv execution', { surv_basic, Surv(recyrs, censrec) ~ group, data = bc, - control = ctrl, - engine = "flexsurv" + control = ctrl ), regexp = NA ) @@ -39,8 +39,7 @@ test_that('flexsurv execution', { surv_basic, Surv(recyrs) ~ group, data = bc, - control = ctrl, - engine = "flexsurv" + control = ctrl ), regexp = NA ) @@ -49,7 +48,6 @@ test_that('flexsurv execution', { surv_basic, x = bc[, "group", drop = FALSE], y = bc$recyrs, - engine = "flexsurv", control = ctrl ) ) @@ -68,8 +66,7 @@ test_that('flexsurv prediction', { surv_basic, Surv(recyrs, censrec) ~ group, data = bc, - control = ctrl, - engine = "flexsurv" + control = ctrl ) exp_pred <- summary(res$fit, head(bc), type = "mean") exp_pred <- do.call("rbind", unclass(exp_pred)) diff --git a/tests/testthat/test_surv_reg_survreg.R b/tests/testthat/test_surv_reg_survreg.R index c78b1a271..5a8e0a66e 100644 --- a/tests/testthat/test_surv_reg_survreg.R +++ b/tests/testthat/test_surv_reg_survreg.R @@ -8,8 +8,8 @@ library(tibble) basic_form <- Surv(time, status) ~ group complete_form <- Surv(time) ~ group -surv_basic <- surv_reg() -surv_lnorm <- surv_reg(dist = "lognormal") +surv_basic <- surv_reg() %>% set_engine("survreg") +surv_lnorm <- surv_reg(dist = "lognormal") %>% set_engine("survreg") ctrl <- fit_control(verbosity = 1, catch = FALSE) caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) @@ -24,8 +24,7 @@ test_that('survival execution', { surv_basic, Surv(time, status) ~ age + sex, data = lung, - control = ctrl, - engine = "survreg" + control = ctrl ), regexp = NA ) @@ -34,8 +33,7 @@ test_that('survival execution', { surv_lnorm, Surv(time) ~ age + sex, data = lung, - control = ctrl, - engine = "survreg" + control = ctrl ), regexp = NA ) @@ -44,7 +42,6 @@ test_that('survival execution', { surv_basic, x = lung[, c("age", "sex")], y = lung$time, - engine = "survreg", control = ctrl ) ) @@ -56,8 +53,7 @@ test_that('survival prediction', { surv_basic, Surv(time, status) ~ age + sex, data = lung, - control = ctrl, - engine = "survreg" + control = ctrl ) exp_pred <- predict(res$fit, head(lung)) exp_pred <- tibble(.pred = unname(exp_pred)) diff --git a/tests/testthat/test_varying.R b/tests/testthat/test_varying.R index 78f14f17f..eb63dff44 100644 --- a/tests/testthat/test_varying.R +++ b/tests/testthat/test_varying.R @@ -41,31 +41,34 @@ test_that('main parsnip arguments', { test_that('other parsnip arguments', { other_1 <- - rand_forest(sample.fraction = varying()) %>% - varying_args(id = "only others") + rand_forest() %>% + set_engine("ranger", sample.fraction = varying()) %>% + varying_args(id = "only engine args") exp_1 <- tibble( name = c("mtry", "trees", "min_n", "sample.fraction"), varying = c(rep(FALSE, 3), TRUE), - id = rep("only others", 4), + id = rep("only engine args", 4), type = rep("model_spec", 
4) ) expect_equal(other_1, exp_1) other_2 <- - rand_forest(min_n = varying(), sample.fraction = varying()) %>% - varying_args(id = "only others") + rand_forest(min_n = varying()) %>% + set_engine("ranger", sample.fraction = varying()) %>% + varying_args(id = "only engine args") exp_2 <- tibble( name = c("mtry", "trees", "min_n", "sample.fraction"), varying = c(rep(FALSE, 2), rep(TRUE, 2)), - id = rep("only others", 4), + id = rep("only engine args", 4), type = rep("model_spec", 4) ) expect_equal(other_2, exp_2) other_3 <- - rand_forest(strata = Class, sampsize = c(varying(), varying())) %>% + rand_forest() %>% + set_engine("ranger", strata = Class, sampsize = c(varying(), varying())) %>% varying_args(id = "add an expr") exp_3 <- tibble( @@ -77,7 +80,8 @@ test_that('other parsnip arguments', { expect_equal(other_3, exp_3) other_4 <- - rand_forest(strata = Class, sampsize = c(12, varying())) %>% + rand_forest() %>% + set_engine("ranger", strata = Class, sampsize = c(12, varying())) %>% varying_args(id = "num and varying in vec") exp_4 <- tibble( diff --git a/vignettes/articles/Classification.Rmd b/vignettes/articles/Classification.Rmd index 7f0f76846..bf123a73f 100644 --- a/vignettes/articles/Classification.Rmd +++ b/vignettes/articles/Classification.Rmd @@ -54,14 +54,12 @@ test_normalized <- bake(credit_rec, newdata = credit_test, all_predictors()) `keras` will be used to fit a model with 5 hidden units and uses a 10% dropout rate to regularize the model. At each training iteration (aka epoch) a random 20% of the data will be used to measure the cross-entropy of the model. ```{r credit-nnet} -library(keras) - set.seed(57974) -nnet_fit <- mlp( - epochs = 100, hidden_units = 5, dropout = 0.1, - others = list(verbose = 0, validation_split = .20) -) %>% - parsnip::fit.model_spec(Status ~ ., data = juice(credit_rec), engine = "keras") +nnet_fit <- + mlp(epochs = 100, hidden_units = 5, dropout = 0.1) %>% + # Also set engine-specific arguments: + set_engine("keras", verbose = 0, validation_split = .20) %>% + fit(Status ~ ., data = juice(credit_rec)) nnet_fit ``` diff --git a/vignettes/articles/Regression.Rmd b/vignettes/articles/Regression.Rmd index 1c137226e..21395e04c 100644 --- a/vignettes/articles/Regression.Rmd +++ b/vignettes/articles/Regression.Rmd @@ -61,11 +61,12 @@ The model will be fit with the `ranger` package. Since we didn't add any extra a ```{r rf-basic-xy} preds <- c("Longitude", "Latitude", "Lot_Area", "Neighborhood", "Year_Sold") -rf_xy_fit <- rf_defaults %>% +rf_xy_fit <- + rf_defaults %>% + set_engine("ranger") %>% fit_xy( x = ames_train[, preds], - y = log10(ames_train$Sale_Price), - engine = "ranger" + y = log10(ames_train$Sale_Price) ) rf_xy_fit ``` @@ -95,22 +96,22 @@ Now, for illustration, let's use the formula method using some new parameter val ```{r rf-basic-form} rand_forest(mode = "regression", mtry = 3, trees = 1000) %>% + set_engine("ranger") %>% fit( log10(Sale_Price) ~ Longitude + Latitude + Lot_Area + Neighborhood + Year_Sold, - data = ames_train, - engine = "ranger" + data = ames_train ) ``` -Suppose that there was some feature in the `randomForest` package that we'd like to evaluate. To do so, the only part of the syntaxt that needs to change is the `engine` argument: +Suppose that there was some feature in the `randomForest` package that we'd like to evaluate. 
To do so, the only part of the syntax that needs to change is the `set_engine()` call: ```{r rf-rf} rand_forest(mode = "regression", mtry = 3, trees = 1000) %>% + set_engine("randomForest") %>% fit( log10(Sale_Price) ~ Longitude + Latitude + Lot_Area + Neighborhood + Year_Sold, - data = ames_train, - engine = "randomForest" + data = ames_train ) ``` @@ -131,10 +132,10 @@ For example, let's use an expression with the `.preds()` descriptor to fit a bag ```{r bagged} rand_forest(mode = "regression", mtry = .preds(), trees = 1000) %>% + set_engine("ranger") %>% fit( log10(Sale_Price) ~ Longitude + Latitude + Lot_Area + Neighborhood + Year_Sold, - data = ames_train, - engine = "ranger" + data = ames_train ) ``` @@ -158,12 +159,10 @@ norm_recipe <- recipe( # Now let's fit the model using the processed version of the data -glmn_fit <- linear_reg(penalty = 0.001, mixture = 0.5) %>% - fit( - Sale_Price ~ ., - data = juice(norm_recipe), - engine = "glmnet" - ) +glmn_fit <- + linear_reg(penalty = 0.001, mixture = 0.5) %>% + set_engine("glmnet") %>% + fit(Sale_Price ~ ., data = juice(norm_recipe)) glmn_fit ``` diff --git a/vignettes/articles/Scratch.Rmd b/vignettes/articles/Scratch.Rmd index 7767d938f..22d17a63f 100644 --- a/vignettes/articles/Scratch.Rmd +++ b/vignettes/articles/Scratch.Rmd @@ -92,13 +92,13 @@ This is a fairly simple function that can follow a basic template. The main argu * The mode. If the model can do more than one mode, you might default this to "unknown". In our case, since it is only a classification model, it makes sense to default it to that mode. * The argument names (`sub_classes` here). These should be defaulted to `NULL`. - * `...` is used to pass in other arguments to the underlying model fit functions. + * `...` is _not_ used in the main model function. A basic version of the function is: ```{r model-fun} mixture_da <- - function(mode = "classification", sub_classes = NULL, ...) { + function(mode = "classification", sub_classes = NULL) { # Check for correct mode if (!(mode %in% mixture_da_modes)) stop("`mode` should be one of: ", @@ -106,14 +106,10 @@ mixture_da <- call. = FALSE) # Capture the arguments in quosures - others <- enquos(...) - args <- list(sub_classes = enquo(sub_classes)) + args <- list(sub_classes = rlang::enquo(sub_classes)) - # Save the other arguments but remove them if they are null. - no_value <- !vapply(others, is.null, logical(1)) - others <- others[no_value] - - out <- list(args = args, others = others, + # Save some empty slots for future parts of the specification + out <- list(args = args, eng_args = NULL, mode = mode, method = NULL, engine = NULL) # set classes in the correct order @@ -235,7 +231,8 @@ For example: library(tidymodels) mixture_da(sub_classes = 2) %>% - translate(engine = "mda") + set_engine("mda") %>% + translate() ``` Let's try it on the iris data: @@ -249,7 +246,8 @@ iris_test <- testing(iris_split) mda_spec <- mixture_da(sub_classes = 2) mda_fit <- mda_spec %>% - fit(Species ~ ., data = iris_train, engine = "mda") + set_engine("mda") %>% + fit(Species ~ ., data = iris_train) mda_fit predict(mda_fit, new_data = iris_test) %>% @@ -281,8 +279,9 @@ However, there are some models (e.g. `glmnet`, `plsr`, `Cubist`, etc.) 
that can For example, for a multinomial `glmnet` model, we leave `penalty` unspecified when fitting and get predictions on a sequence of values: ```{r mnom-glmnet-fit} -mod <- multinom_reg(mixture = 1/3) -mod_fit <- fit(mod, Species ~ ., data = iris, engine = "glmnet") +mod <- multinom_reg(mixture = 1/3) %>% + set_engine("glmnet") +mod_fit <- fit(mod, Species ~ ., data = iris) preds <- multi_predict(mod_fit, iris[1:3, -5], penalty = c(0, 0.01, 0.1), type = "prob") preds @@ -313,8 +312,9 @@ logistic_reg() %>% translate(engine = "glm") # but you can change it: -logistic_reg(family = stats::binomial(link = "probit")) %>% - translate(engine = "glm") +logistic_reg() %>% + set_engine("glm", family = stats::binomial(link = "probit")) %>% + translate() ``` That's what `defaults` are for. @@ -329,7 +329,7 @@ For example, the `ranger` and `randomForest` package functions have arguments fo ```{r rf-trans, eval = FALSE} # Simplified version -translate.rand_forest <- function (x, engine, ...){ +translate.rand_forest <- function (x, engine = x$engine, ...){ # Run the general method to get the real arguments in place x <- translate.default(x, engine, ...) @@ -337,7 +337,7 @@ translate.rand_forest <- function (x, engine, ...){ arg_vals <- x$method$fit$args # Check and see if they make sense for the engine and/or mode: - if (x$engine == "ranger") { + if (engine == "ranger") { if (any(names(arg_vals) == "importance")) # We want to check the type of `importance` but it is a quosure. We first # get the expression. If it is logical, the value of `quo_get_expr` will diff --git a/vignettes/parsnip_Intro.Rmd b/vignettes/parsnip_Intro.Rmd index 4448def91..1536be911 100644 --- a/vignettes/parsnip_Intro.Rmd +++ b/vignettes/parsnip_Intro.Rmd @@ -77,15 +77,12 @@ The arguments to the default function are: args(rand_forest) ``` -However, there might be other arguments that you would like to change or allow to vary. These are accessible using the `...` slot. This is a named list of arguments in the form of the underlying function being called. For example, `ranger` has an option to set the internal random number seed. To set this to a specific value: +However, there might be other arguments that you would like to change or allow to vary. These are accessible using `set_engine`. For example, `ranger` has an option to set the internal random number seed. To set this to a specific value: ```{r rf-seed} -rf_with_seed <- rand_forest( - trees = 2000, - mtry = varying(), - seed = 63233, - mode = "regression" -) +rf_with_seed <- + rand_forest(trees = 2000, mtry = varying(), mode = "regression") %>% + set_engine("ranger", seed = 63233) rf_with_seed ``` @@ -102,7 +99,8 @@ For example, `rf_with_seed` above is not ready for fitting due the `varying()` p ```{r, eval = FALSE} rf_with_seed %>% set_args(mtry = 4) %>% - fit(mpg ~ ., data = mtcars, engine = "ranger") + set_engine("ranger") %>% + fit(mpg ~ ., data = mtcars) ``` ``` @@ -131,7 +129,8 @@ Or, using the `randomForest` package: set.seed(56982) rf_with_seed %>% set_args(mtry = 4) %>% - fit(mpg ~ ., data = mtcars, engine = "randomForest") + set_engine("randomForest") %>% + fit(mpg ~ ., data = mtcars) ``` ```
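
As a minimal end-to-end sketch of the workflow this patch converges on (specify the model, then `set_engine()`, then `fit()`), the following illustrates the separation between main arguments and engine arguments. It assumes the `ranger` package is installed; the engine-side arguments shown (`seed`, `importance`) are illustrative pass-throughs to `ranger::ranger()`, not parsnip main arguments:

```r
library(parsnip)

# Main, engine-agnostic arguments go in the specification function ...
rf_spec <-
  rand_forest(mode = "regression", trees = 500) %>%
  # ... the engine choice and engine-specific arguments go in set_engine()
  set_engine("ranger", seed = 63233, importance = "impurity")

# Main arguments can still be adjusted later with set_args()
rf_spec <- set_args(rf_spec, mtry = 3)

# fit() no longer takes an `engine` argument; it is read from the spec
rf_fit <- fit(rf_spec, mpg ~ ., data = mtcars)

predict(rf_fit, new_data = mtcars[1:3, ])
```

Keeping the engine and its arguments out of `rand_forest()` is what lets the same specification be retargeted to `randomForest` or `spark` by swapping a single `set_engine()` call, as the vignette changes above demonstrate.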