From ee4b955ed8fd56c4ebeb9f064bb4da7336911139 Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 23 Oct 2018 15:05:47 -0400 Subject: [PATCH 1/6] function and documentation changes for set_engine --- DESCRIPTION | 4 +- NAMESPACE | 1 + R/boost_tree.R | 50 ++--- R/fit.R | 26 +-- R/linear_reg.R | 30 +-- R/logistic_reg.R | 27 +-- R/mars.R | 40 +--- R/mlp.R | 42 +--- R/multinom_reg.R | 26 +-- R/nearest_neighbor.R | 31 +-- R/predict.R | 2 +- R/rand_forest.R | 42 ++-- R/set_engine.R | 31 +++ R/surv_reg.R | 35 +-- R/translate.R | 7 +- docs/articles/articles/Classification.html | 95 ++++---- docs/articles/articles/Models.html | 28 +-- docs/articles/articles/Regression.html | 138 ++++++------ docs/articles/articles/Scratch.html | 238 ++++++++++----------- docs/articles/parsnip_Intro.html | 40 ++-- docs/reference/boost_tree.html | 26 +-- docs/reference/fit.html | 45 +--- docs/reference/linear_reg.html | 21 +- docs/reference/logistic_reg.html | 20 +- docs/reference/mars.html | 24 +-- docs/reference/mlp.html | 23 +- docs/reference/model_fit.html | 24 +-- docs/reference/multinom_reg.html | 20 +- docs/reference/nearest_neighbor.html | 22 +- docs/reference/predict.model_fit.html | 63 +----- docs/reference/rand_forest.html | 22 +- docs/reference/set_engine.html | 215 +++++++++++++++++++ docs/reference/surv_reg.html | 39 ++-- docs/reference/varying_args.html | 24 +-- man/boost_tree.Rd | 23 +- man/fit.Rd | 11 +- man/linear_reg.Rd | 17 +- man/logistic_reg.Rd | 16 +- man/mars.Rd | 19 +- man/mlp.Rd | 20 +- man/multinom_reg.Rd | 16 +- man/nearest_neighbor.Rd | 15 +- man/predict.model_fit.Rd | 2 +- man/rand_forest.Rd | 19 +- man/set_engine.Rd | 33 +++ man/surv_reg.Rd | 15 +- vignettes/articles/Classification.Rmd | 12 +- vignettes/articles/Regression.Rmd | 31 ++- vignettes/articles/Scratch.Rmd | 34 +-- vignettes/parsnip_Intro.Rmd | 17 +- 50 files changed, 817 insertions(+), 1004 deletions(-) create mode 100644 R/set_engine.R create mode 100644 docs/reference/set_engine.html create mode 100644 
man/set_engine.Rd diff --git a/DESCRIPTION b/DESCRIPTION index c7c333099..5215d6dc5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip -Version: 0.0.0.9004 -Title: A Common API to Modeling and analysis Functions +Version: 0.0.0.9005 +Title: A Common API to Modeling and Analysis Functions Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. R, spark, stan, etc). Authors@R: c( person("Max", "Kuhn", , "max@rstudio.com", c("aut", "cre")), diff --git a/NAMESPACE b/NAMESPACE index 88b02f030..7d4305157 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -102,6 +102,7 @@ export(predict_raw) export(predict_raw.model_fit) export(rand_forest) export(set_args) +export(set_engine) export(set_mode) export(show_call) export(surv_reg) diff --git a/R/boost_tree.R b/R/boost_tree.R index 61f2d0f0a..b058ecb1b 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -22,7 +22,7 @@ #' } #' These arguments are converted to their specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to their defaults +#' set using the `set_engine` function. If left to their defaults #' here (`NULL`), the values are taken from the underlying model #' functions. If parameters need to be modified, `update` can be used #' in lieu of recreating the object from scratch. @@ -46,11 +46,6 @@ #' @param sample_size An number for the number (or proportion) of data that is #' exposed to the fitting routine. For `xgboost`, the sampling is done at at #' each iteration while `C5.0` samples once during traning. -#' @param ... Other arguments to pass to the specific engine's -#' model fit function (see the Engine Details section below). This -#' should not include arguments defined by the main parameters to -#' this function. 
For the `update` function, the ellipses can -#' contain the primary arguments or any others. #' @details #' The data given to the function are not saved and are only used #' to determine the _mode_ of the model. For `boost_tree`, the @@ -63,17 +58,12 @@ #' \item \pkg{Spark}: `"spark"` #' } #' -#' Main parameter arguments (and those in `...`) can avoid -#' evaluation until the underlying function is executed by wrapping the -#' argument in [rlang::expr()] (e.g. `mtry = expr(floor(sqrt(p)))`). -#' #' #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of -#' model, the template of the fit calls are: +#' model fit call. For this type of model, the template of the +#' fit calls are: #' #' \pkg{xgboost} classification #' @@ -109,7 +99,7 @@ #' reloaded and reattached to the `parsnip` object. #' #' @importFrom purrr map_lgl -#' @seealso [varying()], [fit()] +#' @seealso [varying()], [fit()], [set_engine()] #' @examples #' boost_tree(mode = "classification", trees = 20) #' # Parameters can be represented by a placeholder: @@ -121,11 +111,7 @@ boost_tree <- mtry = NULL, trees = NULL, min_n = NULL, tree_depth = NULL, learn_rate = NULL, loss_reduction = NULL, - sample_size = NULL, - ...) { - - others <- enquos(...) - + sample_size = NULL) { args <- list( mtry = enquo(mtry), trees = enquo(trees), @@ -141,10 +127,7 @@ boost_tree <- paste0("'", boost_tree_modes, "'", collapse = ", "), call. 
= FALSE) - no_value <- !vapply(others, null_value, logical(1)) - others <- others[no_value] - - out <- list(args = args, others = others, + out <- list(args = args, others = NULL, mode = mode, method = NULL, engine = NULL) class(out) <- make_classes("boost_tree") out @@ -183,11 +166,7 @@ update.boost_tree <- mtry = NULL, trees = NULL, min_n = NULL, tree_depth = NULL, learn_rate = NULL, loss_reduction = NULL, sample_size = NULL, - fresh = FALSE, - ...) { - - others <- enquos(...) - + fresh = FALSE) { args <- list( mtry = enquo(mtry), trees = enquo(trees), @@ -209,23 +188,20 @@ update.boost_tree <- object$args[names(args)] <- args } - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - object } # ------------------------------------------------------------------------------ #' @export -translate.boost_tree <- function(x, engine, ...) { +translate.boost_tree <- function(x, engine = x$engine, ...) { + if (is.null(engine)) { + message("Used `engine = 'xgboost'` for translation.") + engine <- "xgboost" + } x <- translate.default(x, engine, ...) - if (x$engine == "spark") { + if (engine == "spark") { if (x$mode == "unknown") stop( "For spark boosted trees models, the mode cannot be 'unknown' ", diff --git a/R/fit.R b/R/fit.R index 4f240545a..bd1c89c0f 100644 --- a/R/fit.R +++ b/R/fit.R @@ -17,15 +17,11 @@ #' below). A data frame containing all relevant variables (e.g. #' outcome(s), predictors, case weights, etc). Note: when needed, a #' \emph{named argument} should be used. -#' @param engine A character string for the software that should -#' be used to fit the model. This is highly dependent on the type -#' of model (e.g. linear regression, random forest, etc.). #' @param control A named list with elements `verbosity` and #' `catch`. See [fit_control()]. #' @param ... Not currently used; values passed here will be #' ignored. 
Other options required to fit the model should be -#' passed using the `others` argument in the original model -#' specification. +#' passed using `set_engine`. #' @details `fit` and `fit_xy` substitute the current arguments in the model #' specification into the computational engine's code, checks them #' for validity, then fits the model using the data and the @@ -92,11 +88,13 @@ fit.model_spec <- function(object, formula = NULL, data = NULL, - engine = object$engine, control = fit_control(), ... ) { dots <- quos(...) + if (any(names(dots) == "engine")) + stop("Use `set_engine` to supply the engine.", call. = FALSE) + if (all(c("x", "y") %in% names(dots))) stop("`fit.model_spec` is for the formula methods. Use `fit_xy` instead.", call. = FALSE) @@ -109,10 +107,8 @@ fit.model_spec <- eval_env$formula <- formula fit_interface <- check_interface(eval_env$formula, eval_env$data, cl, object) - object$engine <- engine - object <- check_engine(object) - if (engine == "spark" && !inherits(eval_env$data, "tbl_spark")) + if (object$engine == "spark" && !inherits(eval_env$data, "tbl_spark")) stop( "spark objects can only be used with the formula interface to `fit` ", "with a spark data object.", call. = FALSE @@ -122,7 +118,7 @@ fit.model_spec <- object <- get_method(object, engine = object$engine) check_installs(object) # TODO rewrite with pkgman - # TODO Should probably just load the namespace + load_libs(object, control$verbosity < 2) interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_") @@ -178,20 +174,20 @@ fit_xy.model_spec <- function(object, x = NULL, y = NULL, - engine = object$engine, control = fit_control(), ... ) { + dots <- quos(...) + if (any(names(dots) == "engine")) + stop("Use `set_engine` to supply the engine.", call. 
= FALSE) cl <- match.call(expand.dots = TRUE) eval_env <- rlang::env() eval_env$x <- x eval_env$y <- y fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object) - object$engine <- engine - object <- check_engine(object) - if (engine == "spark") + if (object$engine == "spark") stop( "spark objects can only be used with the formula interface to `fit` ", "with a spark data object.", call. = FALSE @@ -201,7 +197,7 @@ fit_xy.model_spec <- object <- get_method(object, engine = object$engine) check_installs(object) # TODO rewrite with pkgman - # TODO Should probably just load the namespace + load_libs(object, control$verbosity < 2) interfaces <- paste(fit_interface, object$method$fit$interface, sep = "_") diff --git a/R/linear_reg.R b/R/linear_reg.R index f2e37817f..b35259273 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -12,7 +12,7 @@ #' } #' These arguments are converted to their specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to their defaults +#' set using `set_engine`. If left to their defaults #' here (`NULL`), the values are taken from the underlying model #' functions. If parameters need to be modified, `update` can be used #' in lieu of recreating the object from scratch. @@ -25,7 +25,6 @@ #' represents the proportion of regularization that is used for the #' L2 penalty (i.e. weight decay, or ridge regression) versus L1 #' (the lasso) (`glmnet` and `spark` only). -#' #' @details #' The data given to the function are not saved and are only used #' to determine the _mode_ of the model. For `linear_reg`, the @@ -42,8 +41,7 @@ #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of +#' model fit call. 
For this type of #' model, the template of the fit calls are: #' #' \pkg{lm} @@ -92,7 +90,7 @@ #' separately saved to disk. In a new session, the object can be #' reloaded and reattached to the `parsnip` object. #' -#' @seealso [varying()], [fit()] +#' @seealso [varying()], [fit()], [set_engine()] #' @examples #' linear_reg() #' # Parameters can be represented by a placeholder: @@ -102,10 +100,7 @@ linear_reg <- function(mode = "regression", penalty = NULL, - mixture = NULL, - ...) { - - others <- enquos(...) + mixture = NULL) { args <- list( penalty = enquo(penalty), @@ -119,13 +114,10 @@ linear_reg <- call. = FALSE ) - no_value <- !vapply(others, is.null, logical(1)) - others <- others[no_value] - # write a constructor function out <- list( args = args, - others = others, + others = NULL, mode = mode, method = NULL, engine = NULL @@ -162,10 +154,7 @@ print.linear_reg <- function(x, ...) { update.linear_reg <- function(object, penalty = NULL, mixture = NULL, - fresh = FALSE, - ...) { - - others <- enquos(...) + fresh = FALSE) { args <- list( penalty = enquo(penalty), @@ -182,13 +171,6 @@ update.linear_reg <- object$args[names(args)] <- args } - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - object } diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 29fb60bf3..1b26ce12f 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -12,7 +12,7 @@ #' } #' These arguments are converted to their specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to their defaults +#' set using `set_engine`. If left to their defaults #' here (`NULL`), the values are taken from the underlying model #' functions. If parameters need to be modified, `update` can be used #' in lieu of recreating the object from scratch. 
@@ -39,8 +39,7 @@ #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of +#' model fit call. For this type of #' model, the template of the fit calls are: #' #' \pkg{glm} @@ -100,10 +99,7 @@ logistic_reg <- function(mode = "classification", penalty = NULL, - mixture = NULL, - ...) { - - others <- enquos(...) + mixture = NULL) { args <- list( penalty = enquo(penalty), @@ -117,13 +113,10 @@ logistic_reg <- call. = FALSE ) - no_value <- !vapply(others, is.null, logical(1)) - others <- others[no_value] - # write a constructor function out <- list( args = args, - others = others, + others = NULL, mode = mode, method = NULL, engine = NULL @@ -160,10 +153,7 @@ print.logistic_reg <- function(x, ...) { update.logistic_reg <- function(object, penalty = NULL, mixture = NULL, - fresh = FALSE, - ...) { - - others <- enquos(...) + fresh = FALSE) { args <- list( penalty = enquo(penalty), @@ -180,13 +170,6 @@ update.logistic_reg <- object$args[names(args)] <- args } - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - object } diff --git a/R/mars.R b/R/mars.R index 6bc57b482..60cbe7228 100644 --- a/R/mars.R +++ b/R/mars.R @@ -17,7 +17,7 @@ #' } #' These arguments are converted to their specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to their defaults +#' set using `set_engine`. If left to their defaults #' here (`NULL`), the values are taken from the underlying model #' functions. If parameters need to be modified, `update` can be used #' in lieu of recreating the object from scratch. @@ -30,11 +30,7 @@ #' final model, including the intercept. #' @param prod_degree The highest possible interaction degree. #' @param prune_method The pruning method. 
-#' @details Main parameter arguments (and those in `...`) can avoid -#' evaluation until the underlying function is executed by wrapping the -#' argument in [rlang::expr()]. -#' -#' The model can be created using the `fit()` function using the +#' @details The model can be created using the `fit()` function using the #' following _engines_: #' \itemize{ #' \item \pkg{R}: `"earth"` @@ -43,8 +39,7 @@ #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of +#' model fit call. For this type of #' model, the template of the fit calls are: #' #' \pkg{earth} classification @@ -67,10 +62,7 @@ mars <- function(mode = "unknown", - num_terms = NULL, prod_degree = NULL, prune_method = NULL, - ...) { - - others <- enquos(...) + num_terms = NULL, prod_degree = NULL, prune_method = NULL) { args <- list( num_terms = enquo(num_terms), @@ -83,10 +75,7 @@ mars <- paste0("'", mars_modes, "'", collapse = ", "), call. = FALSE) - no_value <- !vapply(others, is.null, logical(1)) - others <- others[no_value] - - out <- list(args = args, others = others, + out <- list(args = args, others = NULL, mode = mode, method = NULL, engine = NULL) class(out) <- make_classes("mars") out @@ -120,10 +109,7 @@ print.mars <- function(x, ...) { update.mars <- function(object, num_terms = NULL, prod_degree = NULL, prune_method = NULL, - fresh = FALSE, - ...) { - - others <- enquos(...) + fresh = FALSE) { args <- list( num_terms = enquo(num_terms), @@ -141,21 +127,17 @@ update.mars <- object$args[names(args)] <- args } - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - object } # ------------------------------------------------------------------------------ #' @export -translate.mars <- function(x, engine, ...) { - +translate.mars <- function(x, engine = x$engine, ...) 
{ + if (is.null(engine)) { + message("Used `engine = 'earth'` for translation.") + engine <- "earth" + } # If classification is being done, the `glm` options should be used. Check to # see if it is there and, if not, add the default value. if (x$mode == "classification") { diff --git a/R/mlp.R b/R/mlp.R index a323b89c2..b4fcbf2dd 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -18,7 +18,7 @@ #' #' These arguments are converted to their specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to their defaults +#' set using `set_engine`. If left to their defaults #' here (see above), the values are taken from the underlying model #' functions. One exception is `hidden_units` when `nnet::nnet` is used; that #' function's `size` argument has no default so a value of 5 units will be @@ -51,18 +51,13 @@ #' \item \pkg{keras}: `"keras"` #' } #' -#' Main parameter arguments (and those in `...`) can avoid -#' evaluation until the underlying function is executed by wrapping the -#' argument in [rlang::expr()] (e.g. `hidden_units = expr(num_preds * 2)`). -#' #' An error is thrown if both `penalty` and `dropout` are specified for #' `keras` models. #' #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of +#' model fit call. For this type of #' model, the template of the fit calls are: #' #' \pkg{keras} classification @@ -92,10 +87,7 @@ mlp <- function(mode = "unknown", hidden_units = NULL, penalty = NULL, dropout = NULL, epochs = NULL, - activation = NULL, - ...) { - - others <- enquos(...) + activation = NULL) { args <- list( hidden_units = enquo(hidden_units), @@ -110,13 +102,10 @@ mlp <- paste0("'", mlp_modes, "'", collapse = ", "), call. 
= FALSE) - no_value <- !vapply(others, is.null, logical(1)) - others <- others[no_value] - # write a constructor function - out <- list(args = args, others = others, + out <- list(args = args, others = NULL, mode = mode, method = NULL, engine = NULL) - # TODO: make_classes has wrong order; go from specific to general + class(out) <- make_classes("mlp") out } @@ -155,9 +144,7 @@ update.mlp <- function(object, hidden_units = NULL, penalty = NULL, dropout = NULL, epochs = NULL, activation = NULL, - fresh = FALSE, - ...) { - others <- enquos(...) + fresh = FALSE) { args <- list( hidden_units = enquo(hidden_units), @@ -178,20 +165,17 @@ update.mlp <- object$args[names(args)] <- args } - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - object } # ------------------------------------------------------------------------------ #' @export -translate.mlp <- function(x, engine, ...) { +translate.mlp <- function(x, engine = x$engine, ...) { + if (is.null(engine)) { + message("Used `engine = 'keras'` for translation.") + engine <- "keras" + } if (engine == "nnet") { if(isTRUE(is.null(quo_get_expr(x$args$hidden_units)))) { @@ -219,10 +203,6 @@ check_args.mlp <- function(object) { args <- lapply(object$args, rlang::eval_tidy) - if (is.numeric(args$hidden_units)) - if (args$hidden_units < 2) - stop("There must be at least two hidden units", call. = FALSE) - if (is.numeric(args$penalty)) if (args$penalty < 0) stop("The amount of weight decay must be >= 0.", call. = FALSE) diff --git a/R/multinom_reg.R b/R/multinom_reg.R index d9505cf57..561eec506 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -12,7 +12,7 @@ #' } #' These arguments are converted to their specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to their defaults +#' set using `set_engine`. 
If left to their defaults #' here (`NULL`), the values are taken from the underlying model #' functions. If parameters need to be modified, `update` can be used #' in lieu of recreating the object from scratch. @@ -38,8 +38,7 @@ #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of +#' model fit call. For this type of #' model, the template of the fit calls are: #' #' \pkg{glmnet} @@ -83,9 +82,7 @@ multinom_reg <- function(mode = "classification", penalty = NULL, - mixture = NULL, - ...) { - others <- enquos(...) + mixture = NULL) { args <- list( penalty = enquo(penalty), @@ -99,13 +96,10 @@ multinom_reg <- call. = FALSE ) - no_value <- !vapply(others, is.null, logical(1)) - others <- others[no_value] - # write a constructor function out <- list( args = args, - others = others, + others = NULL, mode = mode, method = NULL, engine = NULL @@ -142,10 +136,7 @@ print.multinom_reg <- function(x, ...) { update.multinom_reg <- function(object, penalty = NULL, mixture = NULL, - fresh = FALSE, - ...) { - others <- enquos(...) - + fresh = FALSE) { args <- list( penalty = enquo(penalty), mixture = enquo(mixture) @@ -161,13 +152,6 @@ update.multinom_reg <- object$args[names(args)] <- args } - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - object } diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index 8b374b7f6..cc6e66b57 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -19,7 +19,7 @@ #' } #' These arguments are converted to their specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to their defaults +#' set using `set_engine`. If left to their defaults #' here (`NULL`), the values are taken from the underlying model #' functions. 
If parameters need to be modified, `update()` can be used #' in lieu of recreating the object from scratch. @@ -49,8 +49,7 @@ #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of +#' model fit call. For this type of #' model, the template of the fit calls are: #' #' \pkg{kknn} (classification or regression) @@ -67,16 +66,13 @@ #' @seealso [varying()], [fit()] #' #' @examples -#' nearest_neighbor() +#' nearest_neighbor(neighbors = 11) #' #' @export nearest_neighbor <- function(mode = "unknown", neighbors = NULL, weight_func = NULL, - dist_power = NULL, - ...) { - others <- enquos(...) - + dist_power = NULL) { args <- list( neighbors = enquo(neighbors), weight_func = enquo(weight_func), @@ -90,13 +86,10 @@ nearest_neighbor <- function(mode = "unknown", call. = FALSE) } - no_value <- !vapply(others, is.null, logical(1)) - others <- others[no_value] - # write a constructor function - out <- list(args = args, others = others, + out <- list(args = args, others = NULL, mode = mode, method = NULL, engine = NULL) - # TODO: make_classes has wrong order; go from specific to general + class(out) <- make_classes("nearest_neighbor") out } @@ -121,10 +114,7 @@ update.nearest_neighbor <- function(object, neighbors = NULL, weight_func = NULL, dist_power = NULL, - fresh = FALSE, - ...) { - - others <- enquos(...) + fresh = FALSE) { args <- list( neighbors = enquo(neighbors), @@ -142,13 +132,6 @@ update.nearest_neighbor <- function(object, object$args[names(args)] <- args } - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - object } diff --git a/R/predict.R b/R/predict.R index 5dfd42823..dc80c08d7 100644 --- a/R/predict.R +++ b/R/predict.R @@ -47,7 +47,7 @@ #' #' Quantile predictions return a tibble with a column `.pred`, which is #' a list-column. 
Each list element contains a tibble with columns -#' `.pred` and `.quantile` (and perhaps others). +#' `.pred` and `.quantile` (and perhaps other columns). #' #' Using `type = "raw"` with `predict.model_fit` (or using #' `predict_raw`) will return the unadulterated results of the diff --git a/R/rand_forest.R b/R/rand_forest.R index 3d81e897b..229f6e681 100644 --- a/R/rand_forest.R +++ b/R/rand_forest.R @@ -15,7 +15,7 @@ #' } #' These arguments are converted to their specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to their defaults +#' set using `set_engine`. If left to their defaults #' here (`NULL`), the values are taken from the underlying model #' functions. If parameters need to be modified, `update` can be used #' in lieu of recreating the object from scratch. @@ -38,15 +38,10 @@ #' \item \pkg{Spark}: `"spark"` #' } #' -#' Main parameter arguments (and those in `...`) can avoid -#' evaluation until the underlying function is executed by wrapping the -#' argument in [rlang::expr()] (e.g. `mtry = expr(floor(sqrt(p)))`). -#' #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of +#' model fit call. For this type of #' model, the template of the fit calls are:: #' #' \pkg{ranger} classification @@ -100,10 +95,7 @@ #' @export rand_forest <- - function(mode = "unknown", - mtry = NULL, trees = NULL, min_n = NULL, ...) { - - others <- enquos(...) + function(mode = "unknown", mtry = NULL, trees = NULL, min_n = NULL) { args <- list( mtry = enquo(mtry), @@ -117,13 +109,10 @@ rand_forest <- paste0("'", rand_forest_modes, "'", collapse = ", "), call. 
= FALSE) - no_value <- !vapply(others, null_value, logical(1)) - others <- others[no_value] - # write a constructor function - out <- list(args = args, others = others, + out <- list(args = args, others = NULL, mode = mode, method = NULL, engine = NULL) - # TODO: make_classes has wrong order; go from specific to general + class(out) <- make_classes("rand_forest") out } @@ -156,9 +145,7 @@ print.rand_forest <- function(x, ...) { update.rand_forest <- function(object, mtry = NULL, trees = NULL, min_n = NULL, - fresh = FALSE, - ...) { - others <- enquos(...) + fresh = FALSE) { args <- list( mtry = enquo(mtry), @@ -176,21 +163,18 @@ update.rand_forest <- if (length(args) > 0) object$args[names(args)] <- args } - - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - object } # ------------------------------------------------------------------------------ #' @export -translate.rand_forest <- function(x, engine, ...) { +translate.rand_forest <- function(x, engine = x$engine, ...) { + if (is.null(engine)) { + message("Used `engine = 'ranger'` for translation.") + engine <- "ranger" + } + x <- translate.default(x, engine, ...) # slightly cleaner code using @@ -217,7 +201,7 @@ translate.rand_forest <- function(x, engine, ...) { } # add checks to error trap or change things for this method - if (x$engine == "ranger") { + if (engine == "ranger") { if (any(names(arg_vals) == "importance")) if (isTRUE(is.logical(quo_get_expr(arg_vals$importance)))) stop("`importance` should be a character value. See ?ranger::ranger.", diff --git a/R/set_engine.R b/R/set_engine.R new file mode 100644 index 000000000..9234352e6 --- /dev/null +++ b/R/set_engine.R @@ -0,0 +1,31 @@ +#' Declare a computational engine and specific arguments +#' +#' `set_engine` is used to specify which package or system will be used +#' to fit the model, along with any arguments specific to that software. +#' +#' @param object A model specification. 
+#' @param engine A character string for the software that should +#' be used to fit the model. This is highly dependent on the type +#' of model (e.g. linear regression, random forest, etc.). +#' @param ... Any optional arguments associated with the chosen computational +#' engine. These are captured as quosures and can be `varying()`. +#' @return An updated model specification. +#' @examples +#' # First, set general arguments using the standardized names +#' mod <- +#' logistic_reg(mixture = 1/3) %>% +#' # now say how you want to fit the model and another other options +#' set_engine("glmnet", nlambda = 10) +#' translate(mod, engine = "glmnet") +#' @export +set_engine <- function(object, engine, ...) { + if (!is.character(engine) | length(engine) != 1) + stop("`engine` should be a single character value.", call. = FALSE) + + object$engine <- engine + object <- parsnip:::check_engine(object) + + + object$others <- enquos(...) + object +} \ No newline at end of file diff --git a/R/surv_reg.R b/R/surv_reg.R index 29c3489ab..42bac513d 100644 --- a/R/surv_reg.R +++ b/R/surv_reg.R @@ -9,7 +9,7 @@ #' } #' This argument is converted to its specific names at the #' time that the model is fit. Other options and argument can be -#' set using the `...` slot. If left to its default +#' set using `set_engine`. If left to its default #' here (`NULL`), the value is taken from the underlying model #' functions. #' @@ -42,8 +42,7 @@ #' @section Engine Details: #' #' Engines may have pre-set default arguments when executing the -#' model fit call. These can be changed by using the `...` -#' argument to pass in the preferred values. For this type of +#' model fit call. For this type of #' model, the template of the fit calls are: #' #' \pkg{flexsurv} @@ -67,11 +66,7 @@ #' surv_reg(dist = varying()) #' #' @export -surv_reg <- - function(mode = "regression", - dist = NULL, - ...) { - others <- enquos(...) 
+surv_reg <- function(mode = "regression", dist = NULL) { args <- list( dist = enquo(dist) @@ -84,13 +79,11 @@ surv_reg <- call. = FALSE ) - no_value <- !vapply(others, is.null, logical(1)) - others <- others[no_value] # write a constructor function out <- list( args = args, - others = others, + others = NULL, mode = mode, method = NULL, engine = NULL @@ -128,12 +121,7 @@ print.surv_reg <- function(x, ...) { #' @method update surv_reg #' @rdname surv_reg #' @export -update.surv_reg <- - function(object, - dist = NULL, - fresh = FALSE, - ...) { - others <- enquos(...) +update.surv_reg <- function(object, dist = NULL, fresh = FALSE) { args <- list( dist = enquo(dist) @@ -149,13 +137,6 @@ update.surv_reg <- object$args[names(args)] <- args } - if (length(others) > 0) { - if (fresh) - object$others <- others - else - object$others[names(others)] <- others - } - object } @@ -163,7 +144,11 @@ update.surv_reg <- # ------------------------------------------------------------------------------ #' @export -translate.surv_reg <- function(x, engine, ...) { +translate.surv_reg <- function(x, engine = x$engine, ...) { + if (is.null(engine)) { + message("Used `engine = 'survreg'` for translation.") + engine <- "survreg" + } x <- translate.default(x, engine, ...) x } diff --git a/R/translate.R b/R/translate.R index 7f1102c57..af577705a 100644 --- a/R/translate.R +++ b/R/translate.R @@ -42,8 +42,10 @@ translate <- function (x, ...) #' @importFrom utils getFromNamespace #' @importFrom purrr list_modify #' @export -translate.default <- function(x, engine, ...) { +translate.default <- function(x, engine = x$engine, ...) { check_empty_ellipse(...) + if (is.null(engine)) + stop("Please set an engine.", call. = FALSE) x$engine <- engine x <- check_engine(x) @@ -78,11 +80,10 @@ translate.default <- function(x, engine, ...) { x$method$fit$args <- c(protected, actual_args, x$others, x$defaults) - # put in correct order x } -get_method <- function(x, engine, ...) 
{ +get_method <- function(x, engine = x$engine, ...) { check_empty_ellipse(...) x$engine <- engine x <- check_engine(x) diff --git a/docs/articles/articles/Classification.html b/docs/articles/articles/Classification.html index ed83c3d28..7243ec47b 100644 --- a/docs/articles/articles/Classification.html +++ b/docs/articles/articles/Classification.html @@ -104,27 +104,26 @@

Classification Example

#> ── Attaching packages ──────────────────────────── tidymodels 0.0.1.9000 ── #> ✔ broom 0.5.0.9001 ✔ purrr 0.2.5 #> ✔ dials 0.0.1.9000 ✔ recipes 0.1.3.9002 -#> ✔ dplyr 0.7.6 ✔ rsample 0.0.2 +#> ✔ dplyr 0.7.7 ✔ rsample 0.0.2 #> ✔ infer 0.3.1 ✔ yardstick 0.0.1.9000 #> ✔ probably 0.0.0.9000 -#> Warning: package 'dplyr' was built under R version 3.5.1 -#> ── Conflicts ──────────────────────────────────── tidymodels_conflicts() ── -#> ✖ probably::as.factor() masks base::as.factor() -#> ✖ probably::as.ordered() masks base::as.ordered() -#> ✖ purrr::discard() masks scales::discard() -#> ✖ rsample::fill() masks tidyr::fill() -#> ✖ dplyr::filter() masks stats::filter() -#> ✖ dplyr::lag() masks stats::lag() -#> ✖ rsample::prepper() masks recipes::prepper() -#> ✖ recipes::step() masks stats::step() - -data(credit_data) - -set.seed(7075) -data_split <- initial_split(credit_data, strata = "Status", p = 0.75) - -credit_train <- training(data_split) -credit_test <- testing(data_split) +#> ── Conflicts ──────────────────────────────────── tidymodels_conflicts() ── +#> ✖ probably::as.factor() masks base::as.factor() +#> ✖ probably::as.ordered() masks base::as.ordered() +#> ✖ purrr::discard() masks scales::discard() +#> ✖ rsample::fill() masks tidyr::fill() +#> ✖ dplyr::filter() masks stats::filter() +#> ✖ dplyr::lag() masks stats::lag() +#> ✖ rsample::prepper() masks recipes::prepper() +#> ✖ recipes::step() masks stats::step() + +data(credit_data) + +set.seed(7075) +data_split <- initial_split(credit_data, strata = "Status", p = 0.75) + +credit_train <- training(data_split) +credit_test <- testing(data_split)

A single hidden layer neural network will be used to predict a person’s credit status. To do so, the columns of the predictor matrix should be numeric and on a common scale. recipes will be used to do so.

credit_rec <- recipe(Status ~ ., data = credit_train) %>%
   step_knnimpute(Home, Job, Marital, Income, Assets, Debt) %>%
@@ -137,34 +136,32 @@ 

Classification Example

test_normalized <- bake(credit_rec, newdata = credit_test, all_predictors())

keras will be used to fit a model with 5 hidden units and uses a 10% dropout rate to regularize the model. At each training iteration (aka epoch) a random 20% of the data will be used to measure the cross-entropy of the model.

-
library(keras)
-
-set.seed(57974)
-nnet_fit <- mlp(
-  epochs = 100, hidden_units = 5, dropout = 0.1, 
-  others = list(verbose = 0, validation_split = .20)
-) %>%
-  parsnip::fit.model_spec(Status ~ ., data = juice(credit_rec), engine = "keras")
-
-nnet_fit
-#> parsnip model object
-#> 
-#> Model
-#> ___________________________________________________________________________
-#> Layer (type)                     Output Shape                  Param #     
-#> ===========================================================================
-#> dense_1 (Dense)                  (None, 5)                     115         
+
+#> dense_3 (Dense)                  (None, 2)                     12          
+#> ===========================================================================
+#> Total params: 157
+#> Trainable params: 157
+#> Non-trainable params: 0
+#> ___________________________________________________________________________

In parsnip, the predict function is only appropriate for numeric outcomes while predict_class and predict_classprob can be used for categorical outcomes.

test_results <- credit_test %>%
   select(Status) %>%
@@ -178,17 +175,17 @@ 

Classification Example

#> # A tibble: 1 x 2 #> .metric .estimate #> <chr> <dbl> -#> 1 roc_auc 0.829 +#> 1 roc_auc 0.825 test_results %>% accuracy(truth = Status, estimate = `nnet class`) #> # A tibble: 1 x 2 #> .metric .estimate #> <chr> <dbl> -#> 1 accuracy 0.809 +#> 1 accuracy 0.801 test_results %>% conf_mat(truth = Status, estimate = `nnet class`) #> Truth #> Prediction bad good -#> bad 182 82 -#> good 131 718
+#> bad 174 82 +#> good 139 718 +#> ── Conflicts ──────────────────────────────────── tidymodels_conflicts() ── +#> ✖ probably::as.factor() masks base::as.factor() +#> ✖ probably::as.ordered() masks base::as.ordered() +#> ✖ dplyr::combine() masks randomForest::combine() +#> ✖ purrr::discard() masks scales::discard() +#> ✖ rsample::fill() masks tidyr::fill() +#> ✖ dplyr::filter() masks stats::filter() +#> ✖ dplyr::lag() masks stats::lag() +#> ✖ ggplot2::margin() masks randomForest::margin() +#> ✖ rsample::prepper() masks recipes::prepper() +#> ✖ recipes::step() masks stats::step() + +set.seed(4595) +data_split <- initial_split(ames, strata = "Sale_Price", p = 0.75) + +ames_train <- training(data_split) +ames_test <- testing(data_split)

Random Forests

@@ -141,30 +140,31 @@

parsnip gives two different interfaces to the models: the formula and non-formula interfaces. Let’s start with the non-formula interface:

+rf_xy_fit <- + rf_defaults %>% + set_engine("ranger") %>% + fit_xy( + x = ames_train[, preds], + y = log10(ames_train$Sale_Price) + ) +rf_xy_fit +#> parsnip model object +#> +#> Ranger result +#> +#> Call: +#> ranger::ranger(formula = formula, data = data, num.threads = 1, verbose = FALSE, seed = sample.int(10^5, 1)) +#> +#> Type: Regression +#> Number of trees: 500 +#> Sample size: 2199 +#> Number of independent variables: 5 +#> Mtry: 2 +#> Target node size: 5 +#> Variable importance mode: none +#> Splitrule: variance +#> OOB prediction error (MSE): 0.00866 +#> R squared (OOB): 0.727

The non-formula interface doesn’t do anything to the predictors before giving them to the underlying model function. This particular model does not require indicator variables to be created prior to the model fit (note that the output shows “Number of independent variables: 5”).

For regression models, the basic predict method can be used and returns a tibble with a column named .pred:

test_results <- ames_test %>%
@@ -204,10 +204,10 @@ 

Now, for illustration, let’s use the formula method using some new parameter values:

-

Suppose that there was some feature in the randomForest package that we’d like to evaluate. To do so, the only part of the syntaxt that needs to change is the engine argument:

+

Suppose that there was some feature in the randomForest package that we’d like to evaluate. To do so, the only part of the syntax that needs to change is the set_engine argument:

Look at the formula code that was printed out, one function uses the argument name ntree and the other uses num.trees. parsnip doesn’t require you to know the specific names of the main arguments.

Now suppose that we want to modify the value of mtry based on the number of predictors in the data. Usually, the default value would be floor(sqrt(num_predictors)). To use a pure bagging model would require an mtry value equal to the total number of predictors. There may be cases where you may not know how many predictors are going to be present (perhaps due to the generation of indicator variables or a variable filter) so that might be difficult to know exactly.

-

When the model it being fit by parsnip, data descriptors are made available. These attempt to let you know what you will have available when the model is fit. When a model object is created (say using rand_forest), the values of the arguments that you give it are immediately evaluated… unless you delay them. To delay the evaluation of any argument, you can used rlang::expr to make an expression.

+

When the model is being fit by parsnip, data descriptors are made available. These attempt to let you know what you will have available when the model is fit. When a model object is created (say using rand_forest), the values of the arguments that you give it are immediately evaluated… unless you delay them. To delay the evaluation of any argument, you can use rlang::expr to make an expression.

Two relevant descriptors for what we are about to do are:

If penalty were not specified, all of the lambda values would be computed.

To get the predictions for this specific value of lambda (aka penalty):

# First, get the processed version of the test set predictors:
diff --git a/docs/articles/articles/Scratch.html b/docs/articles/articles/Scratch.html
index 537074874..007ae1264 100644
--- a/docs/articles/articles/Scratch.html
+++ b/docs/articles/articles/Scratch.html
@@ -153,11 +153,11 @@ 

  • The mode. If the model can do more than one mode, you might default this to “unknown”. In our case, since it is only a classification model, it makes sense to default it to that mode.
  • The argument names (sub_classes here). These should be defaulted to NULL.
  • -... is used to pass in other arguments to the underlying model fit functions.
  • +... are not used in the main model function.

    A basic version of the function is:

    + args <- list(sub_classes = rlang::enquo(sub_classes)) + + # Save some empty slots for future parts of the specification + out <- list(args = args, others = NULL, + mode = mode, method = NULL, engine = NULL) + + # set classes in the correct order + class(out) <- make_classes("mixture_da") + out + }

    This is pretty simple since the data are not exposed to this function.

    @@ -236,7 +232,7 @@

  • func is the prediction function (in the same format as above). In many cases, packages have a predict method for their model’s class but this is typically not exported. In this case (and the example below), it is simple enough to make a generic call to predict with no associated package.
  • -args is a list of arguments to pass to the prediction function. These will mostly likely be wrapped in rlang::expr so that they are not evaluated when defining the method. For mda, the code would be predict(object, newdata, type = "class"). What is actually given to the function is the parsnip model fit object, which includes a sub-object called fit and this houses the mda model object. If the data need to be a matrix or data frame, you could also use new_data = quote(as.data.frame(new_data)) and so on.
  • +args is a list of arguments to pass to the prediction function. These will most likely be wrapped in rlang::expr so that they are not evaluated when defining the method. For mda, the code would be predict(object, newdata, type = "class"). What is actually given to the function is the parsnip model fit object, which includes a sub-object called fit and this houses the mda model object. If the data need to be a matrix or data frame, you could also use new_data = quote(as.data.frame(new_data)) and so on.

    Let’s try it on the iris data:

    set.seed(4622)
     iris_split <- initial_split(iris, prop = 0.90)
    @@ -307,64 +304,65 @@ 

    mda_spec <- mixture_da(sub_classes = 2) mda_fit <- mda_spec %>% - fit(Species ~ ., data = iris_train, engine = "mda") -mda_fit -#> parsnip model object -#> -#> Call: -#> mda::mda(formula = formula, data = data, sub_classes = ~2) -#> -#> Dimension: 4 -#> -#> Percent Between-Group Variance Explained: -#> v1 v2 v3 v4 -#> 94.9 97.9 99.8 100.0 -#> -#> Degrees of Freedom (per dimension): 5 -#> -#> Training Misclassification Error: 0.0221 ( N = 136 ) -#> -#> Deviance: 12.3 - -predict(mda_fit, new_data = iris_test) %>% - bind_cols(iris_test %>% select(Species)) -#> # A tibble: 14 x 2 -#> .pred_class Species -#> <fct> <fct> -#> 1 setosa setosa -#> 2 setosa setosa -#> 3 setosa setosa -#> 4 versicolor versicolor -#> 5 versicolor versicolor -#> 6 versicolor versicolor -#> 7 versicolor versicolor -#> 8 versicolor versicolor -#> 9 versicolor versicolor -#> 10 versicolor versicolor -#> 11 versicolor versicolor -#> 12 virginica virginica -#> 13 virginica virginica -#> 14 virginica virginica - -predict(mda_fit, new_data = iris_test, type = "prob") %>% - bind_cols(iris_test %>% select(Species)) -#> # A tibble: 14 x 4 -#> .pred_setosa .pred_versicolor .pred_virginica Species -#> <dbl> <dbl> <dbl> <fct> -#> 1 1.00e+ 0 2.62e-32 7.10e-65 setosa -#> 2 1.00e+ 0 1.36e-25 2.36e-56 setosa -#> 3 1.00e+ 0 9.11e-29 1.33e-60 setosa -#> 4 1.76e-38 10.00e- 1 1.97e- 7 versicolor -#> 5 5.64e-36 9.95e- 1 5.03e- 3 versicolor -#> 6 6.84e-22 10.00e- 1 9.83e- 9 versicolor -#> 7 2.54e-37 9.22e- 1 7.83e- 2 versicolor -#> 8 2.70e-37 9.99e- 1 1.34e- 3 versicolor -#> 9 1.81e-37 8.06e- 1 1.94e- 1 versicolor -#> 10 9.83e-35 9.93e- 1 7.27e- 3 versicolor -#> 11 4.04e-37 9.97e- 1 3.00e- 3 versicolor -#> 12 1.93e-55 1.44e- 1 8.56e- 1 virginica -#> 13 1.21e-50 4.19e- 1 5.81e- 1 virginica -#> 14 2.08e-50 2.07e- 1 7.93e- 1 virginica

    + set_engine("mda") %>% + fit(Species ~ ., data = iris_train) +mda_fit +#> parsnip model object +#> +#> Call: +#> mda::mda(formula = formula, data = data, sub_classes = ~2) +#> +#> Dimension: 4 +#> +#> Percent Between-Group Variance Explained: +#> v1 v2 v3 v4 +#> 94.9 97.9 99.8 100.0 +#> +#> Degrees of Freedom (per dimension): 5 +#> +#> Training Misclassification Error: 0.0221 ( N = 136 ) +#> +#> Deviance: 12.3 + +predict(mda_fit, new_data = iris_test) %>% + bind_cols(iris_test %>% select(Species)) +#> # A tibble: 14 x 2 +#> .pred_class Species +#> <fct> <fct> +#> 1 setosa setosa +#> 2 setosa setosa +#> 3 setosa setosa +#> 4 versicolor versicolor +#> 5 versicolor versicolor +#> 6 versicolor versicolor +#> 7 versicolor versicolor +#> 8 versicolor versicolor +#> 9 versicolor versicolor +#> 10 versicolor versicolor +#> 11 versicolor versicolor +#> 12 virginica virginica +#> 13 virginica virginica +#> 14 virginica virginica + +predict(mda_fit, new_data = iris_test, type = "prob") %>% + bind_cols(iris_test %>% select(Species)) +#> # A tibble: 14 x 4 +#> .pred_setosa .pred_versicolor .pred_virginica Species +#> <dbl> <dbl> <dbl> <fct> +#> 1 1.00e+ 0 2.62e-32 7.10e-65 setosa +#> 2 1.00e+ 0 1.36e-25 2.36e-56 setosa +#> 3 1.00e+ 0 9.11e-29 1.33e-60 setosa +#> 4 1.76e-38 10.00e- 1 1.97e- 7 versicolor +#> 5 5.64e-36 9.95e- 1 5.03e- 3 versicolor +#> 6 6.84e-22 10.00e- 1 9.83e- 9 versicolor +#> 7 2.54e-37 9.22e- 1 7.83e- 2 versicolor +#> 8 2.70e-37 9.99e- 1 1.34e- 3 versicolor +#> 9 1.81e-37 8.06e- 1 1.94e- 1 versicolor +#> 10 9.83e-35 9.93e- 1 7.27e- 3 versicolor +#> 11 4.04e-37 9.97e- 1 3.00e- 3 versicolor +#> 12 1.93e-55 1.44e- 1 8.56e- 1 virginica +#> 13 1.21e-50 4.19e- 1 5.81e- 1 virginica +#> 14 2.08e-50 2.07e- 1 7.93e- 1 virginica

    @@ -385,25 +383,26 @@

    and so on. These can be accommodated via predict.model_fit using different type arguments.

    However, there are some models (e.g. glmnet, plsr, Cubist, etc.) that can make predictions for different models from the same fitted model object. The regular predict method requires prediction from a single model, but multi_predict can return predictions for several models at once. The guideline is to always return the same number of rows as in new_data. This means that the .pred column is a list-column of tibbles.

    For example, for a multinomial glmnet model, we leave penalty unspecified when fitting and get predictions on a sequence of values:

    - +

    This can be easily expanded to remove the list columns:

    +logistic_reg() %>% + set_engine("glm", family = stats::binomial(link = "probit")) %>% + translate() +#> Logistic Regression Model Specification (classification) +#> +#> Engine-Specific Arguments: +#> family = stats::binomial(link = "probit") +#> +#> Computational engine: glm +#> +#> Model fit template: +#> stats::glm(formula = missing_arg(), data = missing_arg(), weights = missing_arg(), +#> family = stats::binomial(link = "probit"))

    That’s what defaults are for.

    Note that I wrapped binomial inside of expr. If I didn’t, it would substitute the results of executing binomial inside of the expression (and that’s a mess). Using namespaces is a good idea here.

    @@ -460,7 +460,7 @@

    The translate function can be used to check values or set defaults once the model’s mode is known. To do this, you can create a model-specific S3 method that first calls the general method (translate.model_spec) and then makes modifications or conducts error traps.

    For example, the ranger and randomForest package functions have arguments for calculating importance. One is a logical and the other is a string. Since this is likely to lead to a bunch of frustration and GH issues, we can put in a check:

    # Simplified version
    -translate.rand_forest <- function (x, engine, ...){
    +translate.rand_forest <- function (x, engine = x$engine, ...){
       # Run the general method to get the real arguments in place
       x <- translate.default(x, engine, ...)
       
    @@ -468,7 +468,7 @@ 

    arg_vals <- x$method$fit$args # Check and see if they make sense for the engine and/or mode: - if (x$engine == "ranger") { + if (engine == "ranger") { if (any(names(arg_vals) == "importance")) # We want to check the type of `importance` but it is a quosure. We first # get the expression. It is is logical, the value of `quo_get_expr` will diff --git a/docs/articles/parsnip_Intro.html b/docs/articles/parsnip_Intro.html index 9cf95f251..f9a75d439 100644 --- a/docs/articles/parsnip_Intro.html +++ b/docs/articles/parsnip_Intro.html @@ -152,25 +152,23 @@

    The arguments to the default function are:

    -

    However, there might be other arguments that you would like to change or allow to vary. These are accessible using the ... slot. This is a named list of arguments in the form of the underlying function being called. For example, ranger has an option to set the internal random number seed. To set this to a specific value:

    - +

    However, there might be other arguments that you would like to change or allow to vary. These are accessible using set_engine. For example, ranger has an option to set the internal random number seed. To set this to a specific value:

    + +#> Computational engine: ranger

    @@ -184,7 +182,8 @@

    For example, rf_with_seed above is not ready for fitting due to the varying() parameter. We can set that parameter’s value and then create the model fit:

    rf_with_seed %>% 
       set_args(mtry = 4) %>% 
    -  fit(mpg ~ ., data = mtcars, engine = "ranger")
    + set_engine("ranger") %>% + fit(mpg ~ ., data = mtcars)

    #> parsnip model object
     #> 
     #> Ranger result
    @@ -206,7 +205,8 @@ 

    set.seed(56982)
     rf_with_seed %>% 
       set_args(mtry = 4) %>% 
    -  fit(mpg ~ ., data = mtcars, engine = "randomForest")
    + set_engine("randomForest") %>% + fit(mpg ~ ., data = mtcars)

    #> parsnip model object
     #> 
     #> 
    diff --git a/docs/reference/boost_tree.html b/docs/reference/boost_tree.html
    index 6a7c2bc3e..52fb13806 100644
    --- a/docs/reference/boost_tree.html
    +++ b/docs/reference/boost_tree.html
    @@ -49,7 +49,7 @@
     sample_size: The amount of data exposed to the fitting routine.
     These arguments are converted to their specific names at the
     time that the model is fit. Other options and argument can be
    -set using the  ... slot. If left to their defaults
    +set using the  set_engine function. If left to their defaults
     here (NULL), the values are taken from the underlying model
     functions.  If parameters need to be modified, update can be used
     in lieu of recreating the object from scratch." />
    @@ -167,7 +167,7 @@ 

    General Interface for Boosted Trees

  • sample_size: The amount of data exposed to the fitting routine.

  • These arguments are converted to their specific names at the time that the model is fit. Other options and argument can be -set using the ... slot. If left to their defaults +set using the set_engine function. If left to their defaults here (NULL), the values are taken from the underlying model functions. If parameters need to be modified, update can be used in lieu of recreating the object from scratch.

    @@ -176,12 +176,12 @@

    General Interface for Boosted Trees

    boost_tree(mode = "unknown", mtry = NULL, trees = NULL,
       min_n = NULL, tree_depth = NULL, learn_rate = NULL,
    -  loss_reduction = NULL, sample_size = NULL, ...)
    +  loss_reduction = NULL, sample_size = NULL)
     
     # S3 method for boost_tree
     update(object, mtry = NULL, trees = NULL,
       min_n = NULL, tree_depth = NULL, learn_rate = NULL,
    -  loss_reduction = NULL, sample_size = NULL, fresh = FALSE, ...)
    + loss_reduction = NULL, sample_size = NULL, fresh = FALSE)

    Arguments

    @@ -228,14 +228,6 @@

    Arg

    - - - - @@ -262,9 +254,6 @@

    Details
  • R: "xgboost", "C5.0"

  • Spark: "spark"

  • -

    Main parameter arguments (and those in ...) can avoid -evaluation until the underlying function is executed by wrapping the -argument in rlang::expr() (e.g. mtry = expr(floor(sqrt(p)))).

    Note

    @@ -285,9 +274,8 @@

    xgboost classification

     parsnip::xgb_train(x = missing_arg(), y = missing_arg(), nthread = 1, 
    @@ -315,7 +303,7 @@ 

    See also

    - +

    Examples

    diff --git a/docs/reference/fit.html b/docs/reference/fit.html index 188e8ab72..50081f8b4 100644 --- a/docs/reference/fit.html +++ b/docs/reference/fit.html @@ -138,11 +138,11 @@

    Fit a Model Specification to a Dataset

    # S3 method for model_spec
     fit(object, formula = NULL, data = NULL,
    -  engine = object$engine, control = fit_control(), ...)
    +  control = fit_control(), ...)
     
     # S3 method for model_spec
     fit_xy(object, x = NULL, y = NULL,
    -  engine = object$engine, control = fit_control(), ...)
    + control = fit_control(), ...)

    Arguments

    A number for the number (or proportion) of data that is exposed to the fitting routine. For xgboost, the sampling is done at each iteration while C5.0 samples once during training.

    ...

    Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. For the update function, the ellipses can -contain the primary arguments or any others.

    object
    @@ -163,12 +163,6 @@

    Arg below). A data frame containing all relevant variables (e.g. outcome(s), predictors, case weights, etc). Note: when needed, a named argument should be used.

    - -

    - - @@ -179,8 +173,7 @@

    Arg

    +passed using set_engine.

    @@ -231,7 +224,7 @@

    Examp
    # Although `glm` only has a formula interface, different # methods for specifying the model can be used -library(dplyr)
    #> Warning: package ‘dplyr’ was built under R version 3.5.1
    #> +library(dplyr)
    #> #> Attaching package: ‘dplyr’
    #> The following objects are masked from ‘package:stats’: #> #> filter, lag
    #> The following objects are masked from ‘package:base’: @@ -246,37 +239,13 @@

    Examp lm_mod %>% fit(Class ~ funded_amnt + int_rate, data = lending_club, - engine = "glm") - + engine = "glm")

    #> Error in if (object$engine == "spark" && !inherits(eval_env$data, "tbl_spark")) stop("spark objects can only be used with the formula interface to `fit` ", "with a spark data object.", call. = FALSE): missing value where TRUE/FALSE needed
    using_xy <- lm_mod %>% fit_xy(x = lending_club[, c("funded_amnt", "int_rate")], y = lending_club$Class, - engine = "glm") - -using_formula
    #> parsnip model object -#> -#> -#> Call: stats::glm(formula = formula, family = stats::binomial, data = data) -#> -#> Coefficients: -#> (Intercept) funded_amnt int_rate -#> 5.131e+00 2.767e-06 -1.586e-01 -#> -#> Degrees of Freedom: 9856 Total (i.e. Null); 9854 Residual -#> Null Deviance: 4055 -#> Residual Deviance: 3698 AIC: 3704
    using_xy
    #> parsnip model object -#> -#> -#> Call: stats::glm(formula = formula, family = stats::binomial, data = data) -#> -#> Coefficients: -#> (Intercept) funded_amnt int_rate -#> 5.131e+00 2.767e-06 -1.586e-01 -#> -#> Degrees of Freedom: 9856 Total (i.e. Null); 9854 Residual -#> Null Deviance: 4055 -#> Residual Deviance: 3698 AIC: 3704
    + engine = "glm")
    #> Error in if (object$engine == "spark") stop("spark objects can only be used with the formula interface to `fit` ", "with a spark data object.", call. = FALSE): argument is of length zero
    +using_formula
    #> Error in eval(expr, envir, enclos): object 'using_formula' not found
    using_xy
    #> Error in eval(expr, envir, enclos): object 'using_xy' not found
    -
    linear_reg(mode = "regression", penalty = NULL, mixture = NULL, ...)
    +    
    linear_reg(mode = "regression", penalty = NULL, mixture = NULL)
     
     # S3 method for linear_reg
     update(object, penalty = NULL, mixture = NULL,
    -  fresh = FALSE, ...)
    + fresh = FALSE)

    Arguments

    engine

    A character string for the software that should -be used to fit the model. This is highly dependent on the type -of model (e.g. linear regression, random forest, etc.).

    control...

    Not currently used; values passed here will be ignored. Other options required to fit the model should be -passed using the others argument in the original model -specification.

    x
    @@ -183,14 +183,6 @@

    Arg represents the proportion of regularization that is used for the L2 penalty (i.e. weight decay, or ridge regression) versus L1 (the lasso) (glmnet and spark only).

    - -

    - - @@ -234,8 +226,7 @@

    lm

    @@ -274,7 +265,7 @@ 

    See also

    - +

    Examples

    diff --git a/docs/reference/logistic_reg.html b/docs/reference/logistic_reg.html index 5f26fb82c..608b1839f 100644 --- a/docs/reference/logistic_reg.html +++ b/docs/reference/logistic_reg.html @@ -41,7 +41,7 @@ the model. Note that this will be ignored for some engines. These arguments are converted to their specific names at the time that the model is fit. Other options and argument can be -set using the ... slot. If left to their defaults +set using set_engine. If left to their defaults here (NULL), the values are taken from the underlying model functions. If parameters need to be modified, update can be used in lieu of recreating the object from scratch." /> @@ -151,19 +151,18 @@

    General Interface for Logistic Regression Models

    the model. Note that this will be ignored for some engines.

    These arguments are converted to their specific names at the time that the model is fit. Other options and argument can be -set using the ... slot. If left to their defaults +set using set_engine. If left to their defaults here (NULL), the values are taken from the underlying model functions. If parameters need to be modified, update can be used in lieu of recreating the object from scratch.

    -
    logistic_reg(mode = "classification", penalty = NULL, mixture = NULL,
    -  ...)
    +    
    logistic_reg(mode = "classification", penalty = NULL, mixture = NULL)
     
     # S3 method for logistic_reg
     update(object, penalty = NULL, mixture = NULL,
    -  fresh = FALSE, ...)
    + fresh = FALSE)

    Arguments

    ...

    Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. For the update function, the ellipses can -contain the primary arguments or any others.

    object
    @@ -184,14 +183,6 @@

    Arg represents the proportion of regularization that is used for the L2 penalty (i.e. weight decay, or ridge regression) versus L1 (the lasso) (glmnet and spark only).

    - -

    - - @@ -233,8 +224,7 @@

    glm

    diff --git a/docs/reference/mars.html b/docs/reference/mars.html
    index b3b591335..31f88a494 100644
    --- a/docs/reference/mars.html
    +++ b/docs/reference/mars.html
    @@ -44,7 +44,7 @@
     in ?earth.
     These arguments are converted to their specific names at the
     time that the model is fit. Other options and argument can be
    -set using the  ... slot. If left to their defaults
    +set using set_engine. If left to their defaults
     here (NULL), the values are taken from the underlying model
     functions. If parameters need to be modified, update can be used
     in lieu of recreating the object from scratch." />
    @@ -157,7 +157,7 @@ 

    General Interface for MARS

    in ?earth.

    These arguments are converted to their specific names at the time that the model is fit. Other options and argument can be -set using the ... slot. If left to their defaults +set using set_engine. If left to their defaults here (NULL), the values are taken from the underlying model functions. If parameters need to be modified, update can be used in lieu of recreating the object from scratch.

    @@ -165,11 +165,11 @@

    General Interface for MARS

    mars(mode = "unknown", num_terms = NULL, prod_degree = NULL,
    -  prune_method = NULL, ...)
    +  prune_method = NULL)
     
     # S3 method for mars
     update(object, num_terms = NULL, prod_degree = NULL,
    -  prune_method = NULL, fresh = FALSE, ...)
    + prune_method = NULL, fresh = FALSE)

    Arguments

    ...

    Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. For the update function, the ellipses can -contain the primary arguments or any others.

    object
    @@ -193,14 +193,6 @@

    Arg

    - - - - @@ -214,10 +206,7 @@

    Arg

    Details

    -

    Main parameter arguments (and those in ...) can avoid -evaluation until the underlying function is executed by wrapping the -argument in rlang::expr().

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

    • R: "earth"

    @@ -226,8 +215,7 @@

    earth classification

    diff --git a/docs/reference/mlp.html b/docs/reference/mlp.html
    index 4976bf489..1c02241df 100644
    --- a/docs/reference/mlp.html
    +++ b/docs/reference/mlp.html
    @@ -165,12 +165,11 @@ 

    General Interface for Single Layer Neural Network

    mlp(mode = "unknown", hidden_units = NULL, penalty = NULL,
    -  dropout = NULL, epochs = NULL, activation = NULL, ...)
    +  dropout = NULL, epochs = NULL, activation = NULL)
     
     # S3 method for mlp
     update(object, hidden_units = NULL, penalty = NULL,
    -  dropout = NULL, epochs = NULL, activation = NULL, fresh = FALSE,
    -  ...)
    + dropout = NULL, epochs = NULL, activation = NULL, fresh = FALSE)

    Arguments

    prune_method

    The pruning method.

    ...

    Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. For the update function, the ellipses can -contain the primary arguments or any others.

    object

    A MARS model specification.

    @@ -206,14 +205,6 @@

    Arg function between the hidden and output layers is automatically set to either "linear" or "softmax" depending on the type of outcome. Possible values are: "linear", "softmax", "relu", and "elu"

    - -

    - - @@ -230,7 +221,7 @@

    Details

    These arguments are converted to their specific names at the time that the model is fit. Other options and argument can be -set using the ... slot. If left to their defaults +set using set_engine. If left to their defaults here (see above), the values are taken from the underlying model functions. One exception is hidden_units when nnet::nnet is used; that function's size argument has no default so a value of 5 units will be @@ -243,18 +234,14 @@

    Details
  • R: "nnet"

  • keras: "keras"

  • -

    Main parameter arguments (and those in ...) can avoid -evaluation until the underlying function is executed by wrapping the -argument in rlang::expr() (e.g. hidden_units = expr(num_preds * 2)).

    -

    An error is thrown if both penalty and dropout are specified for +

    An error is thrown if both penalty and dropout are specified for keras models.

    Engine Details

    Engines may have pre-set default arguments when executing the -model fit call. These can be changed by using the ... -argument to pass in the preferred values. For this type of +model fit call. For this type of model, the template of the fit calls are:

    keras classification

    diff --git a/docs/reference/model_fit.html b/docs/reference/model_fit.html
    index 7d0e22452..dffd15962 100644
    --- a/docs/reference/model_fit.html
    +++ b/docs/reference/model_fit.html
    @@ -164,27 +164,9 @@ 

    Details

    Examples

    # Keep the `x` matrix if the data are not too big. -spec_obj <- linear_reg(x = ifelse(.obs() < 500, TRUE, FALSE)) -spec_obj
    #> Linear Regression Model Specification (regression) -#> -#> Engine-Specific Arguments: -#> x = ifelse(.obs() < 500, TRUE, FALSE) -#>
    -fit_obj <- fit(spec_obj, mpg ~ ., data = mtcars, engine = "lm") -fit_obj
    #> parsnip model object -#> -#> -#> Call: -#> stats::lm(formula = formula, data = data, x = ~ifelse(.obs() < -#> 500, TRUE, FALSE)) -#> -#> Coefficients: -#> (Intercept) cyl disp hp drat wt -#> 12.30337 -0.11144 0.01334 -0.02148 0.78711 -3.71530 -#> qsec vs am gear carb -#> 0.82104 0.31776 2.52023 0.65541 -0.19942 -#>
    -nrow(fit_obj$fit$x)
    #> [1] 32
    +spec_obj <- linear_reg(x = ifelse(.obs() < 500, TRUE, FALSE))
    #> Error in linear_reg(x = ifelse(.obs() < 500, TRUE, FALSE)): unused argument (x = ifelse(.obs() < 500, TRUE, FALSE))
    spec_obj
    #> Error in eval(expr, envir, enclos): object 'spec_obj' not found
    +fit_obj <- fit(spec_obj, mpg ~ ., data = mtcars, engine = "lm")
    #> Error in fit(spec_obj, mpg ~ ., data = mtcars, engine = "lm"): object 'spec_obj' not found
    fit_obj
    #> Error in eval(expr, envir, enclos): object 'fit_obj' not found
    +nrow(fit_obj$fit$x)
    #> Error in nrow(fit_obj$fit$x): object 'fit_obj' not found
    -
    multinom_reg(mode = "classification", penalty = NULL, mixture = NULL,
    -  ...)
    +    
    multinom_reg(mode = "classification", penalty = NULL, mixture = NULL)
     
     # S3 method for multinom_reg
     update(object, penalty = NULL, mixture = NULL,
    -  fresh = FALSE, ...)
    + fresh = FALSE)

    Arguments

    ...

    Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. For the update function, the ellipses can -contain the primary arguments or any others.

    object
    @@ -184,14 +183,6 @@

    Arg represents the proportion of regularization that is used for the L2 penalty (i.e. weight decay, or ridge regression) versus L1 (the lasso) (glmnet only).

    - -

    - - @@ -232,8 +223,7 @@

    glmnet

    diff --git a/docs/reference/nearest_neighbor.html b/docs/reference/nearest_neighbor.html
    index 02a21f33d..9f4116b5a 100644
    --- a/docs/reference/nearest_neighbor.html
    +++ b/docs/reference/nearest_neighbor.html
    @@ -44,7 +44,7 @@
     and the Euclidean distance with dist_power = 2.
     These arguments are converted to their specific names at the
     time that the model is fit. Other options and argument can be
    -set using the  ... slot. If left to their defaults
    +set using set_engine. If left to their defaults
     here (NULL), the values are taken from the underlying model
     functions. If parameters need to be modified, update() can be used
     in lieu of recreating the object from scratch." />
    @@ -157,7 +157,7 @@ 

    General Interface for K-Nearest Neighbor Models

    and the Euclidean distance with dist_power = 2.

    These arguments are converted to their specific names at the time that the model is fit. Other options and argument can be -set using the ... slot. If left to their defaults +set using set_engine. If left to their defaults here (NULL), the values are taken from the underlying model functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -165,7 +165,7 @@

    General Interface for K-Nearest Neighbor Models

    nearest_neighbor(mode = "unknown", neighbors = NULL,
    -  weight_func = NULL, dist_power = NULL, ...)
    + weight_func = NULL, dist_power = NULL)

    Arguments

    ...

    Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. For the update function, the ellipses can -contain the primary arguments or any others.

    object
    @@ -192,14 +192,6 @@

    Arg

    - - - -
    dist_power

    A single number for the parameter used in calculating Minkowski distance.

    ...

    Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. For the update function, the ellipses can -contain the primary arguments or any others.

    @@ -222,8 +214,7 @@

    kknn (classification or regression)

    @@ -237,7 +228,10 @@ 

    See a

    Examples

    -
    nearest_neighbor()
    #> K-Nearest Neighbor Model Specification (unknown) +
    nearest_neighbor(neighbors = 11)
    #> K-Nearest Neighbor Model Specification (unknown) +#> +#> Main Arguments: +#> neighbors = 11 #>
    diff --git a/docs/reference/predict.model_fit.html b/docs/reference/predict.model_fit.html index 90090113a..bfa86efdf 100644 --- a/docs/reference/predict.model_fit.html +++ b/docs/reference/predict.model_fit.html @@ -159,8 +159,8 @@

    Arg type

    A single character value or NULL. Possible values -are "numeric", "class", "probs", "conf_int", "pred_int", or -"raw". When NULL, predict will choose an appropriate value +are "numeric", "class", "probs", "conf_int", "pred_int", "quantile", +or "raw". When NULL, predict will choose an appropriate value based on the model's mode.

    @@ -193,6 +193,9 @@

    Value

    the confidence level. In the case where intervals can be produced for class probabilities (or other non-scalar outputs), the columns will be named .pred_lower_classlevel and so on.

    +

    Quantile predictions return a tibble with a column .pred, which is +a list-column. Each list element contains a tibble with columns +.pred and .quantile (and perhaps other columns).

    Using type = "raw" with predict.model_fit (or using predict_raw) will return the unadulterated results of the prediction function.

    @@ -218,73 +221,25 @@

    Examp lm_model <- linear_reg() %>% - fit(mpg ~ ., data = mtcars %>% slice(11:32), engine = "lm") - + fit(mpg ~ ., data = mtcars %>% slice(11:32), engine = "lm")
    #> Error in if (object$engine == "spark" && !inherits(eval_env$data, "tbl_spark")) stop("spark objects can only be used with the formula interface to `fit` ", "with a spark data object.", call. = FALSE): missing value where TRUE/FALSE needed
    pred_cars <- mtcars %>% slice(1:10) %>% select(-mpg) -predict(lm_model, pred_cars)
    #> # A tibble: 10 x 1 -#> .pred -#> <dbl> -#>  1 23.4 -#>  2 23.3 -#>  3 27.6 -#>  4 21.5 -#>  5 17.6 -#>  6 21.6 -#>  7 13.9 -#>  8 21.7 -#>  9 25.6 -#> 10 17.1
    +predict(lm_model, pred_cars)
    #> Error in predict(lm_model, pred_cars): object 'lm_model' not found
    predict( lm_model, pred_cars, type = "conf_int", level = 0.90 -)
    #> # A tibble: 10 x 2 -#> .pred_lower .pred_upper -#> <dbl> <dbl> -#>  1 17.9 29.0 -#>  2 18.1 28.5 -#>  3 24.0 31.3 -#>  4 17.5 25.6 -#>  5 14.3 20.8 -#>  6 17.0 26.2 -#>  7 9.65 18.2 -#>  8 16.2 27.2 -#>  9 14.2 37.0 -#> 10 11.5 22.7
    +)
    #> Error in predict(lm_model, pred_cars, type = "conf_int", level = 0.9): object 'lm_model' not found
    predict( lm_model, pred_cars, type = "raw", opts = list(type = "terms") -)
    #> cyl disp hp drat wt qsec -#> 1 -0.001433177 -0.8113275 0.6303467 -0.06120265 2.4139815 -1.567729 -#> 2 -0.001433177 -0.8113275 0.6303467 -0.06120265 1.4488706 -0.736286 -#> 3 -0.009315653 -1.3336453 0.8557288 -0.05014798 3.5494061 1.624418 -#> 4 -0.001433177 0.1730406 0.6303467 0.12009386 0.1620561 2.856736 -#> 5 0.006449298 1.1975870 -0.2314083 0.10461733 -0.6895124 -0.736286 -#> 6 -0.001433177 -0.1584303 0.6966356 0.19084372 -0.7652074 4.014817 -#> 7 0.006449298 1.1975870 -1.1594522 0.09135173 -1.1815297 -2.488255 -#> 8 -0.009315653 -0.9449204 1.2667197 -0.01477305 0.2566748 3.688179 -#> 9 -0.009315653 -1.0041833 0.8292133 -0.06562451 0.4080647 7.993866 -#> 10 -0.001433177 -0.7349888 0.4579957 -0.06562451 -0.6895124 1.164155 -#> vs am gear carb -#> 1 0.2006406 2.88774 0.02512680 -0.2497240 -#> 2 0.2006406 2.88774 0.02512680 -0.2497240 -#> 3 -0.3511210 2.88774 0.02512680 0.4668753 -#> 4 -0.3511210 -2.40645 -0.06700481 0.4668753 -#> 5 0.2006406 -2.40645 -0.06700481 0.2280089 -#> 6 -0.3511210 -2.40645 -0.06700481 0.4668753 -#> 7 0.2006406 -2.40645 -0.06700481 -0.2497240 -#> 8 -0.3511210 -2.40645 0.02512680 0.2280089 -#> 9 -0.3511210 -2.40645 0.02512680 0.2280089 -#> 10 -0.3511210 -2.40645 0.02512680 -0.2497240 -#> attr(,"constant") -#> [1] 19.96364

    +)
    #> Error in predict(lm_model, pred_cars, type = "raw", opts = list(type = "terms")): object 'lm_model' not found
    rand_forest(mode = "unknown", mtry = NULL, trees = NULL,
    -  min_n = NULL, ...)
    +  min_n = NULL)
     
     # S3 method for rand_forest
     update(object, mtry = NULL, trees = NULL,
    -  min_n = NULL, fresh = FALSE, ...)
    + min_n = NULL, fresh = FALSE)

    Arguments

    @@ -190,14 +190,6 @@

    Arg

    - - - - @@ -217,9 +209,6 @@

    Details
  • R: "ranger" or "randomForest"

  • Spark: "spark"

  • -

    Main parameter arguments (and those in ...) can avoid -evaluation until the underlying function is executed by wrapping the -argument in rlang::expr() (e.g. mtry = expr(floor(sqrt(p)))).

    Note

    @@ -240,8 +229,7 @@

    ranger classification

    diff --git a/docs/reference/set_engine.html b/docs/reference/set_engine.html
    new file mode 100644
    index 000000000..c62ea625d
    --- /dev/null
    +++ b/docs/reference/set_engine.html
    @@ -0,0 +1,215 @@
    +
    +
    +
    +  
    +  
    +
    +
    +
    +Declare a computational engine and specific arguments — set_engine • parsnip
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +  
    +
    +  
    +    
    +
    + + + +
    + +
    +
    + + +
    + +

    set_engine is used to specify which package or system will be used +to fit the model, along with any arguments specific to that software.

    + +
    + +
    set_engine(object, engine, ...)
    + +

    Arguments

    +

    min_n

    An integer for the minimum number of data points in a node that are required for the node to be split further.

    ...

    Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. For the update function, the ellipses can -contain the primary arguments or any others.

    object
    + + + + + + + + + + + + + +
    object

    A model specification.

    engine

    A character string for the software that should +be used to fit the model. This is highly dependent on the type +of model (e.g. linear regression, random forest, etc.).

    ...

    Any optional arguments associated with the chosen computational +engine. These are captured as quosures and can be varying().

    + +

    Value

    + +

    An updated model specification.

    + + +

    Examples

    +
    # First, set general arguments using the standardized names +mod <- + logistic_reg(mixture = 1/3) %>% + # now say how you want to fit the model and any other options + set_engine("glmnet", nlambda = 10) +translate(mod, engine = "glmnet")
    #> Logistic Regression Model Specification (classification) +#> +#> Main Arguments: +#> mixture = 1/3 +#> +#> Engine-Specific Arguments: +#> nlambda = 10 +#> +#> Computational engine: glmnet +#> +#> Model fit template: +#> glmnet::glmnet(x = missing_arg(), y = missing_arg(), weights = missing_arg(), +#> alpha = 1/3, nlambda = 10, family = "binomial")
    + + + + +
    +
    +

    parsnip is a part of the tidymodels ecosystem, a collection of modeling packages designed with common APIs and a shared philosophy.

    +
    + +
    +

    + Developed by Max Kuhn. + Site built by pkgdown. +

    +
    +
    + + + + + + + diff --git a/docs/reference/surv_reg.html b/docs/reference/surv_reg.html index 96ad45f21..3a67f9111 100644 --- a/docs/reference/surv_reg.html +++ b/docs/reference/surv_reg.html @@ -38,7 +38,7 @@ dist: The probability distribution of the outcome. This argument is converted to its specific names at the time that the model is fit. Other options and argument can be -set using the ... slot. If left to its default +set using set_engine. If left to its default here (NULL), the value is taken from the underlying model functions. If parameters need to be modified, this function can be used @@ -146,7 +146,7 @@

    General Interface for Parametric Survival Models

  • dist: The probability distribution of the outcome.

  • This argument is converted to its specific names at the time that the model is fit. Other options and argument can be -set using the ... slot. If left to its default +set using set_engine. If left to its default here (NULL), the value is taken from the underlying model functions.

    If parameters need to be modified, this function can be used @@ -154,10 +154,10 @@

    General Interface for Parametric Survival Models

    -
    surv_reg(mode = "regression", dist = NULL, ...)
    +    
    surv_reg(mode = "regression", dist = NULL)
     
     # S3 method for surv_reg
    -update(object, dist = NULL, fresh = FALSE, ...)
    +update(object, dist = NULL, fresh = FALSE)

    Arguments

    @@ -171,14 +171,6 @@

    Arg

    - - - - @@ -202,11 +194,32 @@

    Details

    Also, for the flexsurv::flexsurvfit engine, the typical strata function cannot be used. To achieve the same effect, the extra parameter roles can be used (as described above).

    +

    For surv_reg, the mode will always be "regression".

    The model can be created using the fit() function using the following engines:

    • R: "flexsurv", "survreg"

    +

    Engine Details

    + + +

    Engines may have pre-set default arguments when executing the +model fit call. For this type of +model, the templates of the fit calls are:

    +

    flexsurv

    +

    +flexsurv::flexsurvreg(formula = missing_arg(), data = missing_arg(), 
    +    weights = missing_arg())
    +

    +

    survreg

    +

    +survival::survreg(formula = missing_arg(), data = missing_arg(), 
    +    weights = missing_arg(), model = TRUE)
    +

    +

    Note that model = TRUE is needed to produce quantile +predictions when there is a stratification variable and can be +overridden in other cases.

    +

    References

    Jackson, C. (2016). flexsurv: A Platform for Parametric Survival @@ -243,6 +256,8 @@

    Contents

  • Details
  • +
  • Engine Details
  • +
  • References
  • See also
  • diff --git a/docs/reference/varying_args.html b/docs/reference/varying_args.html index 5dd84b50e..63881571c 100644 --- a/docs/reference/varying_args.html +++ b/docs/reference/varying_args.html @@ -188,34 +188,16 @@

    Examp #> 2 trees FALSE one arg model_spec #> 3 min_n FALSE one arg model_spec
    rand_forest(others = list(sample.fraction = varying())) %>% - varying_args(id = "only others")
    #> # A tibble: 4 x 4 -#> name varying id type -#> <chr> <lgl> <chr> <chr> -#> 1 mtry FALSE only others model_spec -#> 2 trees FALSE only others model_spec -#> 3 min_n FALSE only others model_spec -#> 4 others TRUE only others model_spec
    + varying_args(id = "only others")
    #> Error in rand_forest(others = list(sample.fraction = varying())): unused argument (others = list(sample.fraction = varying()))
    rand_forest( others = list( strata = expr(Class), sampsize = c(varying(), varying()) ) ) %>% - varying_args(id = "add an expr")
    #> # A tibble: 4 x 4 -#> name varying id type -#> <chr> <lgl> <chr> <chr> -#> 1 mtry FALSE add an expr model_spec -#> 2 trees FALSE add an expr model_spec -#> 3 min_n FALSE add an expr model_spec -#> 4 others FALSE add an expr model_spec
    + varying_args(id = "add an expr")
    #> Error in rand_forest(others = list(strata = expr(Class), sampsize = c(varying(), varying()))): unused argument (others = list(strata = expr(Class), sampsize = c(varying(), varying())))
    rand_forest(others = list(classwt = c(class1 = 1, class2 = varying()))) %>% - varying_args(id = "list of values")
    #> # A tibble: 4 x 4 -#> name varying id type -#> <chr> <lgl> <chr> <chr> -#> 1 mtry FALSE list of values model_spec -#> 2 trees FALSE list of values model_spec -#> 3 min_n FALSE list of values model_spec -#> 4 others FALSE list of values model_spec
    + varying_args(id = "list of values")
    #> Error in rand_forest(others = list(classwt = c(class1 = 1, class2 = varying()))): unused argument (others = list(classwt = c(class1 = 1, class2 = varying())))
    @@ -175,17 +175,17 @@

    Classification Example

    #> # A tibble: 1 x 2#> .metric .estimate#> <chr> <dbl> -#> 1 roc_auc 0.825 +#> 1 roc_auc 0.822test_results %>% accuracy(truth = Status, estimate = `nnet class`)#> # A tibble: 1 x 2#> .metric .estimate#> <chr> <dbl> -#> 1 accuracy 0.801 +#> 1 accuracy 0.804test_results %>% conf_mat(truth = Status, estimate = `nnet class`)#> Truth#> Prediction bad good -#> bad 174 82 -#> good 139 718 +#> bad 176 81 +#> good 137 719 diff --git a/docs/articles/articles/Regression.html b/docs/articles/articles/Regression.html index 0ce7e3e3f..1b76a047e 100644 --- a/docs/articles/articles/Regression.html +++ b/docs/articles/articles/Regression.html @@ -40,7 +40,7 @@ parsnip
    part of tidymodels - 0.0.0.9004 + 0.0.0.9005
    diff --git a/docs/articles/articles/Scratch.html b/docs/articles/articles/Scratch.html index 007ae1264..e3397cf76 100644 --- a/docs/articles/articles/Scratch.html +++ b/docs/articles/articles/Scratch.html @@ -40,7 +40,7 @@ parsnip
    part of tidymodels - 0.0.0.9004 + 0.0.0.9005
    diff --git a/docs/articles/index.html b/docs/articles/index.html index cc27c2b37..4a511d2a9 100644 --- a/docs/articles/index.html +++ b/docs/articles/index.html @@ -69,7 +69,7 @@ parsnip
    part of tidymodels - 0.0.0.9004 + 0.0.0.9005
    diff --git a/docs/articles/parsnip_Intro.html b/docs/articles/parsnip_Intro.html index f9a75d439..c7b8fefef 100644 --- a/docs/articles/parsnip_Intro.html +++ b/docs/articles/parsnip_Intro.html @@ -40,7 +40,7 @@ parsnip
    part of tidymodels - 0.0.0.9004 + 0.0.0.9005
    diff --git a/docs/authors.html b/docs/authors.html index ef26bfe35..5a274292f 100644 --- a/docs/authors.html +++ b/docs/authors.html @@ -69,7 +69,7 @@ parsnip
    part of tidymodels - 0.0.0.9004 + 0.0.0.9005
    diff --git a/docs/index.html b/docs/index.html index dc0d8f38c..8c13f7d13 100644 --- a/docs/index.html +++ b/docs/index.html @@ -5,12 +5,12 @@ -A Common API to Modeling and analysis Functions • parsnip +A Common API to Modeling and Analysis Functions • parsnip - +
    dist

    A character string for the outcome distribution. "weibull" is the default.

    ...

    Other arguments to pass to the specific engine's -model fit function (see the Engine Details section below). This -should not include arguments defined by the main parameters to -this function. For the update function, the ellipses can -contain the primary arguments or any others.

    object