diff --git a/DESCRIPTION b/DESCRIPTION index 4d305e962..08001baaa 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,5 +1,5 @@ Package: parsnip -Version: 0.0.2.9000 +Version: 0.0.3 Title: A Common API to Modeling and Analysis Functions Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc). Authors@R: c( diff --git a/NAMESPACE b/NAMESPACE index 37745ade0..c4014df9a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -5,6 +5,12 @@ S3method(fit_xy,model_spec) S3method(has_multi_predict,default) S3method(has_multi_predict,model_fit) S3method(has_multi_predict,workflow) +S3method(min_grid,boost_tree) +S3method(min_grid,linear_reg) +S3method(min_grid,logistic_reg) +S3method(min_grid,mars) +S3method(min_grid,multinom_reg) +S3method(min_grid,nearest_neighbor) S3method(multi_predict,"_C5.0") S3method(multi_predict,"_earth") S3method(multi_predict,"_elnet") @@ -50,8 +56,11 @@ S3method(print,svm_rbf) S3method(translate,boost_tree) S3method(translate,decision_tree) S3method(translate,default) +S3method(translate,linear_reg) +S3method(translate,logistic_reg) S3method(translate,mars) S3method(translate,mlp) +S3method(translate,multinom_reg) S3method(translate,nearest_neighbor) S3method(translate,rand_forest) S3method(translate,surv_reg) @@ -104,6 +113,13 @@ export(linear_reg) export(logistic_reg) export(make_classes) export(mars) +export(min_grid) +export(min_grid.boost_tree) +export(min_grid.linear_reg) +export(min_grid.logistic_reg) +export(min_grid.mars) +export(min_grid.multinom_reg) +export(min_grid.nearest_neighbor) export(mlp) export(model_printer) export(multi_predict) diff --git a/NEWS.md b/NEWS.md index 53875047a..dad6ada58 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,17 +1,26 @@ -# parsnip 0.0.2.9000 +# parsnip 0.0.3 + +Unplanned release based on CRAN requirements for Solaris. ## Breaking Changes - * The method that `parsnip` stores the model information has changed. Any custom models from previous versions will need to use the new method for registering models. The methods are detailed in `?get_model_env()` and the [package vignette for adding models](https://tidymodels.github.io/parsnip/articles/articles/Scratch.html). - * The mode need to be declared for models that can be used for more than one mode prior to fitting and/or translation). + * The method that `parsnip` stores the model information has changed. Any custom models from previous versions will need to use the new method for registering models. The methods are detailed in `?get_model_env` and the [package vignette for adding models](https://tidymodels.github.io/parsnip/articles/articles/Scratch.html). + + * The mode needs to be declared for models that can be used for more than one mode prior to fitting and/or translation. + * For `surv_reg()`, the engine that uses the `survival` package is now called `survival` instead of `survreg`. + * For `glmnet` models, the full regularization path is always fit regardless of the value given to `penalty`. Previously, the model was fit with passing `penalty` to `glmnet`'s `lambda` argument and the model could only make predictions at those specific values. [(#195)](https://github.com/tidymodels/parsnip/issues/195) + ## New Features * `add_rowindex()` can create a column called `.row` to a data frame. * If a computational engine is not explicitly set, a default will be used. Each default is documented on the corresponding model page. 
A warning is issued at fit time unless verbosity is zero. - * `nearest_neighbor` gained a `multi_predict` method. The `multi_predict()` documentation is a little better organized. + + * `nearest_neighbor()` gained a `multi_predict` method. The `multi_predict()` documentation is a little better organized. + + * A suite of internal functions were added to help with upcoming model tuning features. # parsnip 0.0.2 diff --git a/R/aaa.R b/R/aaa.R index a49415299..303579c25 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -19,10 +19,102 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) { } # ------------------------------------------------------------------------------ +# min_grid generic - put here so that the generic shows up first in the man file + +#' Determine the minimum set of model fits +#' +#' `min_grid` determines exactly what models should be fit in order to +#' evaluate the entire set of tuning parameter combinations. This is for +#' internal use only and the API may change in the near future. +#' @param x A model specification. +#' @param grid A tibble with tuning parameter combinations. +#' @param ... Not currently used. +#' @return A tibble with the minimum tuning parameters to fit and an additional +#' list column with the parameter combinations used for prediction. +#' @keywords internal +#' @export +min_grid <- function(x, grid, ...) { + # x is a `model_spec` object from parsnip + # grid is a tibble of tuning parameter values with names + # matching the parameter names. + UseMethod("min_grid") +} + +# As an example, if we fit a boosted tree model and tune over +# trees = 1:20 and min_n = c(20, 30) +# we should only have to fit two models: +# +# trees = 20 & min_n = 20 +# trees = 20 & min_n = 30 +# +# The logic related to how this "mini grid" gets made is model-specific. +# +# To get the full set of predictions, we need to know, for each of these two +# models, what values of num_terms to give to the multi_predict() function. +# +# The current idea is to have a list column of the extra models for prediction. +# For the example above: +# +# # A tibble: 2 x 3 +# trees min_n .submodels +# +# 1 20 20 +# 2 20 30 +# +# and the .submodels would both be +# +# list(trees = 1:19) +# +# There are a lot of other things to consider in future versions like grids +# where there are multiple columns with the same name (maybe the results of +# a recipe) and so on. 
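The comment block above describes the intended result only in prose. As a concrete illustration, here is a small dplyr sketch (not part of the package code; column names and shapes are assumed from the comments above) of the tibble that the `min_grid()` methods aim to produce for the boosted-tree example with `trees = 1:20` and `min_n = c(20, 30)`:

```r
library(dplyr)
library(tidyr)

# Full tuning grid: 40 candidate combinations.
full_grid <- crossing(trees = 1:20, min_n = c(20, 30))

# The "mini grid": one fit per value of min_n, made at the largest number of
# trees, with the remaining trees values carried along in a .submodels list
# column so that multi_predict() can fill them in without refitting.
mini_grid <- full_grid %>%
  group_by(min_n) %>%
  summarize(trees = max(trees), .submodels = list(list(trees = 1:19))) %>%
  ungroup()

mini_grid
# A tibble with two rows (min_n = 20 and min_n = 30), trees = 20 in both,
# and a .submodels entry of list(trees = 1:19) for each row.
```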
+ +# ------------------------------------------------------------------------------ +# helper functions + +# Template for model results that do no have the sub-model feature +blank_submodels <- function(grid) { + grid %>% + dplyr::mutate(.submodels = map(1:nrow(grid), ~ list())) +} + +get_fixed_args <- function(info) { + # Get non-sub-model columns to iterate over + fixed_args <- info$name[!info$has_submodel] +} + +get_submodel_info <- function(spec, grid) { + param_info <- + get_from_env(paste0(class(spec)[1], "_args")) %>% + dplyr::filter(engine == spec$engine) %>% + dplyr::select(name = parsnip, has_submodel) + + # In case a recipe or other activity has grid parameter columns, + # add those to the results + grid_names <- names(grid) + is_mod_param <- grid_names %in% param_info$name + if (any(!is_mod_param)) { + param_info <- + param_info %>% + dplyr::bind_rows( + tibble::tibble(name = grid_names[!is_mod_param], + has_submodel = FALSE) + ) + } + param_info %>% dplyr::filter(name %in% grid_names) +} + + +# ------------------------------------------------------------------------------ +# nocov #' @importFrom utils globalVariables utils::globalVariables( c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group', 'lab', 'original', 'predicted_label', 'prediction', 'value', 'type', - "neighbors") + "neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty", + "max_terms", "max_tree", "name", "num_terms", "penalty", "trees", + "sub_neighbors") ) + +# nocov end diff --git a/R/boost_tree.R b/R/boost_tree.R index 2963de1a4..ad9be1f2d 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -514,3 +514,41 @@ C50_by_tree <- function(tree, object, new_data, type, ...) { pred[, c(".row", "trees", nms)] } +# ------------------------------------------------------------------------------ + +#' @export +#' @export min_grid.boost_tree +#' @rdname min_grid +min_grid.boost_tree <- function(x, grid, ...) { + grid_names <- names(grid) + param_info <- get_submodel_info(x, grid) + + # No ability to do submodels? Finish here: + if (!any(param_info$has_submodel)) { + return(blank_submodels(grid)) + } + + fixed_args <- get_fixed_args(param_info) + + # For boosted trees, fit the model with the most trees (conditional on the + # other parameters) so that you can do predictions on the smaller models. + fit_only <- + grid %>% + dplyr::group_by(!!!rlang::syms(fixed_args)) %>% + dplyr::summarize(trees = max(trees, na.rm = TRUE)) %>% + dplyr::ungroup() + + # Add a column .submodels that is a list with what should be predicted + # by `multi_predict()` (assuming `predict()` has already been executed + # on the original value of 'trees') + min_grid_df <- + dplyr::full_join(fit_only %>% rename(max_tree = trees), grid, by = fixed_args) %>% + dplyr::filter(trees != max_tree) %>% + dplyr::group_by(!!!rlang::syms(fixed_args)) %>% + dplyr::summarize(.submodels = list(list(trees = trees))) %>% + dplyr::ungroup() %>% + dplyr::full_join(fit_only, grid, by = fixed_args) + + min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels) +} + diff --git a/R/linear_reg.R b/R/linear_reg.R index f9e5a8e74..28ebc4935 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -68,9 +68,9 @@ #' #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")} #' -#' When using `glmnet` models, there is the option to pass -#' multiple values (or no values) to the `penalty` argument. This -#' can have an effect on the model object results. 
When using the +#' For `glmnet` models, the full regularization path is always fit regardless +#' of the value given to `penalty`. Also, there is the option to pass +#' multiple values (or no values) to the `penalty` argument. When using the #' `predict()` method in these cases, the return value depends on #' the value of `penalty`. When using `predict()`, only a single #' value of the penalty can be used. When predicting on multiple @@ -138,6 +138,23 @@ print.linear_reg <- function(x, ...) { invisible(x) } + +#' @export +translate.linear_reg <- function(x, engine = x$engine, ...) { + x <- translate.default(x, engine, ...) + + if (engine == "glmnet") { + # See discussion in https://github.com/tidymodels/parsnip/issues/195 + x$method$fit$args$lambda <- NULL + # Since the `fit` infomration is gone for the penalty, we need to have an + # evaludated value for the parameter. + x$args$penalty <- rlang::eval_tidy(x$args$penalty) + } + + x +} + + # ------------------------------------------------------------------------------ #' @inheritParams update.boost_tree @@ -274,6 +291,11 @@ predict._elnet <- if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + # See discussion in https://github.com/tidymodels/parsnip/issues/195 + if (is.null(penalty) & !is.null(object$spec$args$penalty)) { + penalty <- object$spec$args$penalty + } + object$spec$args$penalty <- check_penalty(penalty, object, multi) object$spec <- eval_args(object$spec) @@ -314,7 +336,12 @@ multi_predict._elnet <- object$spec <- eval_args(object$spec) if (is.null(penalty)) { - penalty <- object$fit$lambda + # See discussion in https://github.com/tidymodels/parsnip/issues/195 + if (!is.null(object$spec$args$penalty)) { + penalty <- object$spec$args$penalty + } else { + penalty <- object$fit$lambda + } } pred <- predict._elnet(object, new_data = new_data, type = "raw", @@ -332,3 +359,37 @@ multi_predict._elnet <- names(pred) <- NULL tibble(.pred = pred) } + + +# ------------------------------------------------------------------------------ + +#' @export +#' @export min_grid.linear_reg +#' @rdname min_grid +min_grid.linear_reg <- function(x, grid, ...) { + + grid_names <- names(grid) + param_info <- get_submodel_info(x, grid) + + if (!any(param_info$has_submodel)) { + return(blank_submodels(grid)) + } + + fixed_args <- get_fixed_args(param_info) + + fit_only <- + grid %>% + dplyr::group_by(!!!rlang::syms(fixed_args)) %>% + dplyr::summarize(penalty = max(penalty, na.rm = TRUE)) %>% + dplyr::ungroup() + + min_grid_df <- + dplyr::full_join(fit_only %>% rename(max_penalty = penalty), grid, by = fixed_args) %>% + dplyr::filter(penalty != max_penalty) %>% + dplyr::group_by(!!!rlang::syms(fixed_args)) %>% + dplyr::summarize(.submodels = list(list(penalty = penalty))) %>% + dplyr::ungroup() %>% + dplyr::full_join(fit_only, grid, by = fixed_args) + + min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels) +} diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 7264680aa..d9d3c87a5 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -66,9 +66,9 @@ #' #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")} #' -#' When using `glmnet` models, there is the option to pass -#' multiple values (or no values) to the `penalty` argument. This -#' can have an effect on the model object results. When using the +#' For `glmnet` models, the full regularization path is always fit regardless +#' of the value given to `penalty`. 
Also, there is the option to pass +#' multiple values (or no values) to the `penalty` argument. When using the #' `predict()` method in these cases, the return value depends on #' the value of `penalty`. When using `predict()`, only a single #' value of the penalty can be used. When predicting on multiple @@ -137,6 +137,9 @@ print.logistic_reg <- function(x, ...) { invisible(x) } +#' @export +translate.logistic_reg <- translate.linear_reg + # ------------------------------------------------------------------------------ #' @inheritParams update.boost_tree @@ -235,7 +238,7 @@ organize_glmnet_prob <- function(x, object) { } # ------------------------------------------------------------------------------ -# glmnet call stack for linear regression using `predict` when object has +# glmnet call stack for logistic regression using `predict` when object has # classes "_lognet" and "model_fit" (for class predictions): # # predict() @@ -247,7 +250,7 @@ organize_glmnet_prob <- function(x, object) { # predict.lognet() -# glmnet call stack for linear regression using `multi_predict` when object has +# glmnet call stack for logistic regression using `multi_predict` when object has # classes "_lognet" and "model_fit" (for class predictions): # # multi_predict() @@ -262,10 +265,15 @@ organize_glmnet_prob <- function(x, object) { # ------------------------------------------------------------------------------ #' @export -predict._lognet <- function (object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) { +predict._lognet <- function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) { if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + # See discussion in https://github.com/tidymodels/parsnip/issues/195 + if (is.null(penalty) & !is.null(object$spec$args$penalty)) { + penalty <- object$spec$args$penalty + } + object$spec$args$penalty <- check_penalty(penalty, object, multi) object$spec <- eval_args(object$spec) @@ -286,8 +294,16 @@ multi_predict._lognet <- penalty <- eval_tidy(penalty) dots <- list(...) - if (is.null(penalty)) - penalty <- eval_tidy(object$fit$lambda) + + if (is.null(penalty)) { + # See discussion in https://github.com/tidymodels/parsnip/issues/195 + if (!is.null(object$spec$args$penalty)) { + penalty <- object$spec$args$penalty + } else { + penalty <- object$fit$lambda + } + } + dots$s <- penalty if (is.null(type)) @@ -330,7 +346,7 @@ multi_predict._lognet <- #' @export -predict_class._lognet <- function (object, new_data, ...) { +predict_class._lognet <- function(object, new_data, ...) { if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) @@ -339,7 +355,7 @@ predict_class._lognet <- function (object, new_data, ...) { } #' @export -predict_classprob._lognet <- function (object, new_data, ...) { +predict_classprob._lognet <- function(object, new_data, ...) { if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) @@ -348,7 +364,7 @@ predict_classprob._lognet <- function (object, new_data, ...) { } #' @export -predict_raw._lognet <- function (object, new_data, opts = list(), ...) { +predict_raw._lognet <- function(object, new_data, opts = list(), ...) { if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. 
= FALSE) @@ -356,3 +372,10 @@ predict_raw._lognet <- function (object, new_data, opts = list(), ...) { predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) } + +# ------------------------------------------------------------------------------ + +#' @export +#' @export min_grid.logistic_reg +#' @rdname min_grid +min_grid.logistic_reg <- min_grid.linear_reg diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index d04f0100e..02b13f398 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -218,7 +218,7 @@ set_model_arg( parsnip = "penalty", original = "reg_param", func = list(pkg = "dials", fun = "penalty"), - has_submodel = TRUE + has_submodel = FALSE ) set_model_arg( diff --git a/R/mars.R b/R/mars.R index f83e56e29..22884852d 100644 --- a/R/mars.R +++ b/R/mars.R @@ -262,3 +262,36 @@ earth_by_terms <- function(num_terms, object, new_data, type, ...) { pred[[".row"]] <- 1:nrow(new_data) pred[, c(".row", "num_terms", nms)] } + +# ------------------------------------------------------------------------------ + +#' @export +#' @export min_grid.mars +#' @rdname min_grid +min_grid.mars <- function(x, grid, ...) { + + grid_names <- names(grid) + param_info <- get_submodel_info(x, grid) + + if (!any(param_info$has_submodel)) { + return(blank_submodels(grid)) + } + + fixed_args <- get_fixed_args(param_info) + + fit_only <- + grid %>% + dplyr::group_by(!!!rlang::syms(fixed_args)) %>% + dplyr::summarize(num_terms = max(num_terms, na.rm = TRUE)) %>% + dplyr::ungroup() + + min_grid_df <- + dplyr::full_join(fit_only %>% rename(max_terms = num_terms), grid, by = fixed_args) %>% + dplyr::filter(num_terms != max_terms) %>% + dplyr::group_by(!!!rlang::syms(fixed_args)) %>% + dplyr::summarize(.submodels = list(list(num_terms = num_terms))) %>% + dplyr::ungroup() %>% + dplyr::full_join(fit_only, grid, by = fixed_args) + + min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels) +} diff --git a/R/multinom_reg.R b/R/multinom_reg.R index b8bc0a479..9b6a6e768 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -57,9 +57,9 @@ #' #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")} #' -#' When using `glmnet` models, there is the option to pass -#' multiple values (or no values) to the `penalty` argument. This -#' can have an effect on the model object results. When using the +#' For `glmnet` models, the full regularization path is always fit regardless +#' of the value given to `penalty`. Also, there is the option to pass +#' multiple values (or no values) to the `penalty` argument. When using the #' `predict()` method in these cases, the return value depends on #' the value of `penalty`. When using `predict()`, only a single #' value of the penalty can be used. When predicting on multiple @@ -112,7 +112,7 @@ print.multinom_reg <- function(x, ...) { cat("Multinomial Regression Model Specification (", x$mode, ")\n\n", sep = "") model_printer(x, ...) - if(!is.null(x$method$fit$args)) { + if (!is.null(x$method$fit$args)) { cat("Model fit template:\n") print(show_call(x)) } @@ -120,6 +120,9 @@ print.multinom_reg <- function(x, ...) 
{ invisible(x) } +#' @export +translate.multinom_reg <- translate.linear_reg + # ------------------------------------------------------------------------------ #' @inheritParams update.boost_tree @@ -188,7 +191,7 @@ organize_multnet_prob <- function(x, object) { } # ------------------------------------------------------------------------------ -# glmnet call stack for linear regression using `predict` when object has +# glmnet call stack for multinomial regression using `predict` when object has # classes "_multnet" and "model_fit" (for class predictions): # # predict() @@ -199,7 +202,7 @@ organize_multnet_prob <- function(x, object) { # predict.multnet() -# glmnet call stack for linear regression using `multi_predict` when object has +# glmnet call stack for multinomial regression using `multi_predict` when object has # classes "_multnet" and "model_fit" (for class predictions): # # multi_predict() @@ -217,6 +220,11 @@ organize_multnet_prob <- function(x, object) { predict._multnet <- function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) { + # See discussion in https://github.com/tidymodels/parsnip/issues/195 + if (is.null(penalty) & !is.null(object$spec$args$penalty)) { + penalty <- object$spec$args$penalty + } + object$spec$args$penalty <- check_penalty(penalty, object, multi) object$spec <- eval_args(object$spec) @@ -242,14 +250,20 @@ multi_predict._multnet <- penalty <- eval_tidy(penalty) dots <- list(...) - if (is.null(penalty)) - penalty <- eval_tidy(object$fit$lambda) + if (is.null(penalty)) { + # See discussion in https://github.com/tidymodels/parsnip/issues/195 + if (!is.null(object$spec$args$penalty)) { + penalty <- object$spec$args$penalty + } else { + penalty <- object$fit$lambda + } + } dots$s <- penalty if (is.null(type)) type <- "class" if (!(type %in% c("class", "prob", "link", "raw"))) { - stop ("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE) + stop("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE) } if (type == "prob") dots$type <- "response" @@ -290,19 +304,19 @@ multi_predict._multnet <- } #' @export -predict_class._multnet <- function (object, new_data, ...) { +predict_class._multnet <- function(object, new_data, ...) { object$spec <- eval_args(object$spec) predict_class.model_fit(object, new_data = new_data, ...) } #' @export -predict_classprob._multnet <- function (object, new_data, ...) { +predict_classprob._multnet <- function(object, new_data, ...) { object$spec <- eval_args(object$spec) predict_classprob.model_fit(object, new_data = new_data, ...) } #' @export -predict_raw._multnet <- function (object, new_data, opts = list(), ...) { +predict_raw._multnet <- function(object, new_data, opts = list(), ...) { object$spec <- eval_args(object$spec) predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) 
} @@ -323,3 +337,10 @@ check_glmnet_lambda <- function(dat, object) { dat } + +# ------------------------------------------------------------------------------ + +#' @export +#' @export min_grid.multinom_reg +#' @rdname min_grid +min_grid.multinom_reg <- min_grid.linear_reg diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index 7b7b16b3b..777126a01 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -44,7 +44,7 @@ set_pred( mode = "classification", type = "class", value = list( - pre = check_glmnet_lambda, + pre = NULL, post = organize_multnet_class, func = c(fun = "predict"), args = @@ -104,7 +104,7 @@ set_model_arg( parsnip = "penalty", original = "reg_param", func = list(pkg = "dials", fun = "penalty"), - has_submodel = TRUE + has_submodel = FALSE ) set_model_arg( diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index b28a3df46..6e29cf45d 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -213,8 +213,43 @@ multi_predict._train.kknn <- } knn_by_k <- function(k, object, new_data, type, ...) { - object$fit$call$ks <- k + object$fit$best.parameters$k <- k + predict(object, new_data = new_data, type = type, ...) %>% dplyr::mutate(neighbors = k, .row = dplyr::row_number()) %>% dplyr::select(.row, neighbors, dplyr::starts_with(".pred")) } + +# ------------------------------------------------------------------------------ + +#' @export +#' @export min_grid.nearest_neighbor +#' @rdname min_grid +min_grid.nearest_neighbor <- function(x, grid, ...) { + + grid_names <- names(grid) + param_info <- get_submodel_info(x, grid) + + if (!any(param_info$has_submodel)) { + return(blank_submodels(grid)) + } + + fixed_args <- get_fixed_args(param_info) + + fit_only <- + grid %>% + dplyr::group_by(!!!rlang::syms(fixed_args)) %>% + dplyr::summarize(neighbors = max(neighbors, na.rm = TRUE)) %>% + dplyr::ungroup() + + min_grid_df <- + dplyr::full_join(fit_only %>% rename(max_neighbor = neighbors), grid, by = fixed_args) %>% + dplyr::filter(neighbors != max_neighbor) %>% + dplyr::rename(sub_neighbors = neighbors, neighbors = max_neighbor) %>% + dplyr::group_by(!!!rlang::syms(fixed_args)) %>% + dplyr::summarize(.submodels = list(list(neighbors = sub_neighbors))) %>% + dplyr::ungroup() %>% + dplyr::full_join(fit_only, grid, by = fixed_args) + + min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels) +} diff --git a/_pkgdown.yml b/_pkgdown.yml index 644889998..69092b9e4 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -81,6 +81,8 @@ navbar: href: articles/articles/Classification.html - text: Making a parsnip model from scratch href: articles/articles/Scratch.html + - text: Evaluating submodels with the same model object + href: articles/articles/Submodels.html - text: News href: news/index.html - text: Reference diff --git a/docs/404.html b/docs/404.html index b120a1d2f..5ec6cfca1 100644 --- a/docs/404.html +++ b/docs/404.html @@ -52,6 +52,7 @@ gtag('config', 'UA-115082821-1'); + @@ -71,7 +72,7 @@ parsnip
[The remainder of this chunk regenerates the built pkgdown site under docs/. On every rebuilt page the navbar version changes from 0.0.2 to 0.0.3, a link to the new article "Evaluating submodels with the same model object" is added, and the page footers are rebuilt. Pages updated here: docs/404.html, docs/articles/articles/Classification.html, docs/articles/articles/Models.html, and docs/reference/{boost_tree, check_empty_ellipse, check_times, decision_tree, descriptors, fit, fit_control}.html. The reference pages also pick up the corresponding help-file changes: the default engine is now called out in the engine lists (e.g. "xgboost" (the default), "rpart" (the default)), the spark decision-tree fit template no longer passes a type argument, and fit() documents that the model's default engine is used when no engine has been set, with a warning unless the verbosity of fit_control() is zero.]
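Since this part of the diff documents the new default-engine behavior, here is a brief usage sketch; it is not taken verbatim from the package, and it assumes the `verbosity` control described in the NEWS entry above:

```r
library(parsnip)

# With no engine set, the model's default engine is used (e.g. "lm" for
# linear_reg()) and a warning is issued when the model is fit.
default_fit <- linear_reg() %>%
  fit(mpg ~ wt + hp, data = mtcars)

# Setting the engine explicitly, or dropping the verbosity to zero,
# avoids the warning.
explicit_fit <- linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ wt + hp, data = mtcars)

quiet_fit <- linear_reg() %>%
  fit(mpg ~ wt + hp, data = mtcars, control = fit_control(verbosity = 0L))
```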
[Two new reference pages are added to the built site. docs/reference/get_model_env.html documents the developer functions get_model_env(), get_from_env(), set_in_env(), and set_env_val(), which read and write the environment where parsnip stores model-specification information, and its example lists the objects registered there (boost_tree_args, linear_reg_fit, and so on). docs/reference/has_multi_predict.html documents has_multi_predict(), which reports whether a fitted object can make predictions on submodels, and multi_predict_args(), which returns the argument names accepted by multi_predict() (or NA if there are none).]
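Following the example on the new has_multi_predict() page, the helpers can be used along these lines (a sketch assuming the kknn package is installed; the commented results follow that page's output):

```r
library(parsnip)
library(kknn)

# An lm fit has no submodel predictions available.
lm_fit <- linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = mtcars)

has_multi_predict(lm_fit)    #> FALSE
multi_predict_args(lm_fit)   #> NA

# A kknn fit can predict over several values of `neighbors` at once.
knn_fit <- nearest_neighbor(mode = "regression", neighbors = 5) %>%
  set_engine("kknn") %>%
  fit(mpg ~ ., data = mtcars)

multi_predict_args(knn_fit)  #> "neighbors"
multi_predict(knn_fit, mtcars[1, -1], neighbors = 1:4)$.pred
```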
    part of tidymodels - 0.0.2 + 0.0.3
    @@ -99,6 +100,9 @@
  • Making a parsnip model from scratch
  • +
  • + Evaluating submodels with the same model object +
  • @@ -115,6 +119,7 @@ +
    @@ -226,6 +231,12 @@

    add_rowindex()

    + +

    Add a column of row numbers to a data frame

    + +

    .cols() .preds() .obs() .lvls() .facts() .x() .y() .dat()

    @@ -269,7 +280,7 @@

    multi_predict()

    +

    multi_predict

    Model predictions across many sub-models

    @@ -329,6 +340,20 @@

    <

    Execution Time Data

    + + + +

    Developer Tools

    +

    + + + + + +

    set_new_model() set_model_mode() set_model_engine() set_model_arg() set_dependency() get_dependency() set_fit() get_fit() set_pred() get_pred_type() show_model_info() pred_value_template()

    + +

    Tools to Register Models

    +

    @@ -339,10 +364,12 @@

    Contents

  • Models
  • Infrastructure
  • Data
  • +
  • Developer Tools
  • +

    parsnip is a part of the tidymodels ecosystem, a collection of modeling packages designed with common APIs and a shared philosophy.

    @@ -354,11 +381,14 @@

    Contents

    Site built by pkgdown.

    +
    + + diff --git a/docs/reference/keras_mlp.html b/docs/reference/keras_mlp.html index 229785727..e6a8811ba 100644 --- a/docs/reference/keras_mlp.html +++ b/docs/reference/keras_mlp.html @@ -57,6 +57,7 @@ gtag('config', 'UA-115082821-1'); + @@ -76,7 +77,7 @@ parsnip
    part of tidymodels - 0.0.2 + 0.0.3
    @@ -104,6 +105,9 @@
  • Making a parsnip model from scratch
  • +
  • + Evaluating submodels with the same model object +
  • @@ -120,6 +124,7 @@ +
    @@ -139,7 +144,7 @@

    Simple interface to MLP models via keras

    keras_mlp(x, y, hidden_units = 5, decay = 0, dropout = 0,
    -  epochs = 20, act = "softmax", seeds = sample.int(10^5, size = 3),
    +  epochs = 20, act = "softmax", seeds = sample.int(10^5, size = 3),
       ...)

    Arguments

    @@ -203,6 +208,7 @@

    Contents

    +

    parsnip is a part of the tidymodels ecosystem, a collection of modeling packages designed with common APIs and a shared philosophy.

    @@ -214,11 +220,14 @@

    Contents

    Site built by pkgdown.

    +
    + + diff --git a/docs/reference/lending_club.html b/docs/reference/lending_club.html index 70120d370..b79f08723 100644 --- a/docs/reference/lending_club.html +++ b/docs/reference/lending_club.html @@ -55,6 +55,7 @@ gtag('config', 'UA-115082821-1'); + @@ -74,7 +75,7 @@ parsnip
    part of tidymodels - 0.0.2 + 0.0.3
    @@ -102,6 +103,9 @@
  • Making a parsnip model from scratch
  • +
  • + Evaluating submodels with the same model object +
  • @@ -118,6 +122,7 @@ +
    @@ -158,8 +163,8 @@

    Details

    Examples

    -
    data(lending_club) -str(lending_club)
    #> Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 9857 obs. of 23 variables: +
    data(lending_club) +str(lending_club)
    #> Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 9857 obs. of 23 variables: #> $ funded_amnt : int 16100 32000 10000 16800 3500 10000 11000 15000 6000 20000 ... #> $ term : Factor w/ 2 levels "term_36","term_60": 1 2 1 2 1 1 1 1 1 2 ... #> $ int_rate : num 13.99 11.99 16.29 13.67 7.39 ... @@ -200,6 +205,7 @@

    Contents

    +

    parsnip is a part of the tidymodels ecosystem, a collection of modeling packages designed with common APIs and a shared philosophy.

    @@ -211,11 +217,14 @@

    Contents

    Site built by pkgdown.

    +
    + + diff --git a/docs/reference/linear_reg.html b/docs/reference/linear_reg.html index c9e5dd46d..8f16085b0 100644 --- a/docs/reference/linear_reg.html +++ b/docs/reference/linear_reg.html @@ -68,6 +68,7 @@ gtag('config', 'UA-115082821-1'); + @@ -87,7 +88,7 @@ parsnip
    part of tidymodels - 0.0.2 + 0.0.3
    @@ -115,6 +116,9 @@
  • Making a parsnip model from scratch
  • +
  • + Evaluating submodels with the same model object +
  • @@ -131,6 +135,7 @@ +
    @@ -155,7 +160,7 @@

    General Interface for Linear Regression Models

    time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -163,7 +168,7 @@

    General Interface for Linear Regression Models

    linear_reg(mode = "regression", penalty = NULL, mixture = NULL)
     
     # S3 method for linear_reg
    -update(object, penalty = NULL, mixture = NULL,
    +update(object, penalty = NULL, mixture = NULL,
       fresh = FALSE, ...)

    Arguments

    @@ -200,7 +205,7 @@

    Arg ... -

    Not used for update().

    +

    Not used for update().

    @@ -209,9 +214,9 @@

    Details

    The data given to the function are not saved and are only used to determine the mode of the model. For linear_reg(), the mode will always be "regression".

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

      -
    • R: "lm" or "glmnet"

    • +
    • R: "lm" (the default) or "glmnet"

    • Stan: "stan"

    • Spark: "spark"

    • keras: "keras"

    • @@ -221,13 +226,13 @@

      Note

      For models created using the spark engine, there are several differences to consider. First, only the formula -interface to via fit() is available; using fit_xy() will +interface to via fit() is available; using fit_xy() will generate an error. Second, the predictions will always be in a spark table format. The names will be the same as documented but without the dots. Third, there is no equivalent to factor columns in spark tables so class predictions are returned as character columns. Fourth, to retain the model object for a new -R session (via save()), the model$fit element of the parsnip +R session (via save()), the model$fit element of the parsnip object should be serialized via ml_save(object$fit) and separately saved to disk. In a new session, the object can be reloaded and reattached to the parsnip object.

      @@ -262,11 +267,11 @@

      -

      When using glmnet models, there is the option to pass -multiple values (or no values) to the penalty argument. This -can have an effect on the model object results. When using the -predict() method in these cases, the return value depends on -the value of penalty. When using predict(), only a single +

      For glmnet models, the full regularization path is always fit regardless +of the value given to penalty. Also, there is the option to pass +multiple values (or no values) to the penalty argument. When using the +predict() method in these cases, the return value depends on +the value of penalty. When using predict(), only a single value of the penalty can be used. When predicting on multiple penalties, the multi_predict() function can be used. It returns a tibble with a list column called .pred that contains @@ -280,7 +285,7 @@

      See also

      - +

      Examples

      @@ -296,12 +301,12 @@

      Examp #> Main Arguments: #> penalty = 10 #> mixture = 0.1 -#>
      update(model, penalty = 1)
      #> Linear Regression Model Specification (regression) +#>
      update(model, penalty = 1)
      #> Linear Regression Model Specification (regression) #> #> Main Arguments: #> penalty = 1 #> mixture = 0.1 -#>
      update(model, penalty = 1, fresh = TRUE)
      #> Linear Regression Model Specification (regression) +#>
      update(model, penalty = 1, fresh = TRUE)
      #> Linear Regression Model Specification (regression) #> #> Main Arguments: #> penalty = 1 @@ -326,6 +331,7 @@

diff --git a/docs/reference/logistic_reg.html b/docs/reference/logistic_reg.html
    @@ -155,7 +160,7 @@

    General Interface for Logistic Regression Models

    time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -163,7 +168,7 @@

    General Interface for Logistic Regression Models

    logistic_reg(mode = "classification", penalty = NULL, mixture = NULL)
     
     # S3 method for logistic_reg
    -update(object, penalty = NULL, mixture = NULL,
    +update(object, penalty = NULL, mixture = NULL,
       fresh = FALSE, ...)

    Arguments

    @@ -200,16 +205,16 @@

    Arg ... -

    Not used for update().

    +

    Not used for update().

    Details

    For logistic_reg(), the mode will always be "classification".

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

      -
    • R: "glm" or "glmnet"

    • +
    • R: "glm" (the default) or "glmnet"

    • Stan: "stan"

    • Spark: "spark"

    • keras: "keras"

    • @@ -219,13 +224,13 @@

      Note

      For models created using the spark engine, there are several differences to consider. First, only the formula -interface to via fit() is available; using fit_xy() will +interface to via fit() is available; using fit_xy() will generate an error. Second, the predictions will always be in a spark table format. The names will be the same as documented but without the dots. Third, there is no equivalent to factor columns in spark tables so class predictions are returned as character columns. Fourth, to retain the model object for a new -R session (via save()), the model$fit element of the parsnip +R session (via save()), the model$fit element of the parsnip object should be serialized via ml_save(object$fit) and separately saved to disk. In a new session, the object can be reloaded and reattached to the parsnip object.

      @@ -261,11 +266,11 @@

      -

      When using glmnet models, there is the option to pass -multiple values (or no values) to the penalty argument. This -can have an effect on the model object results. When using the -predict() method in these cases, the return value depends on -the value of penalty. When using predict(), only a single +

      For glmnet models, the full regularization path is always fit regardless +of the value given to penalty. Also, there is the option to pass +multiple values (or no values) to the penalty argument. When using the +predict() method in these cases, the return value depends on +the value of penalty. When using predict(), only a single value of the penalty can be used. When predicting on multiple penalties, the multi_predict() function can be used. It returns a tibble with a list column called .pred that contains @@ -280,7 +285,7 @@

      See also

      - +

      Examples

      @@ -296,12 +301,12 @@

      Examp #> Main Arguments: #> penalty = 10 #> mixture = 0.1 -#>
      update(model, penalty = 1)
      #> Logistic Regression Model Specification (classification) +#>
      update(model, penalty = 1)
      #> Logistic Regression Model Specification (classification) #> #> Main Arguments: #> penalty = 1 #> mixture = 0.1 -#>
      update(model, penalty = 1, fresh = TRUE)
      #> Logistic Regression Model Specification (classification) +#>
      update(model, penalty = 1, fresh = TRUE)
      #> Logistic Regression Model Specification (classification) #> #> Main Arguments: #> penalty = 1 @@ -326,6 +331,7 @@

diff --git a/docs/reference/make_classes.html b/docs/reference/make_classes.html
diff --git a/docs/reference/mars.html b/docs/reference/mars.html
    @@ -161,7 +166,7 @@

    General Interface for MARS

    time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -170,7 +175,7 @@

    General Interface for MARS

    prune_method = NULL) # S3 method for mars -update(object, num_terms = NULL, prod_degree = NULL, +update(object, num_terms = NULL, prod_degree = NULL, prune_method = NULL, fresh = FALSE, ...)

    Arguments

    @@ -206,15 +211,15 @@

    Arg ... -

    Not used for update().

    +

    Not used for update().

    Details

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

      -
    • R: "earth"

    • +
    • R: "earth" (the default)

    Engine Details

    @@ -239,7 +244,7 @@

    See also

    - +

    Examples

    @@ -253,12 +258,12 @@

    Examp #> Main Arguments: #> num_terms = 10 #> prune_method = none -#>
    update(model, num_terms = 1)
    #> MARS Model Specification (unknown) +#>
    update(model, num_terms = 1)
    #> MARS Model Specification (unknown) #> #> Main Arguments: #> num_terms = 1 #> prune_method = none -#>
    update(model, num_terms = 1, fresh = TRUE)
    #> MARS Model Specification (unknown) +#>
    update(model, num_terms = 1, fresh = TRUE)
    #> MARS Model Specification (unknown) #> #> Main Arguments: #> num_terms = 1 @@ -281,6 +286,7 @@

diff --git a/docs/reference/min_grid.html b/docs/reference/min_grid.html
new file mode 100644
Determine the minimum set of model fits — min_grid • parsnip

    min_grid determines exactly what models should be fit in order to +evaluate the entire set of tuning parameter combinations. This is for +internal use only.

    + +
    + +
    min_grid(x, grid, ...)
    +
    +# S3 method for boost_tree
    +min_grid(x, grid, ...)
    +
    +# S3 method for linear_reg
    +min_grid(x, grid, ...)
    +
    +# S3 method for logistic_reg
    +min_grid(x, grid, ...)
    +
    +# S3 method for mars
    +min_grid(x, grid, ...)
    +
    +# S3 method for multinom_reg
    +min_grid(x, grid, ...)
    +
    +# S3 method for nearest_neighbor
    +min_grid(x, grid, ...)
    + +

    Arguments

    + + + + + + + + + + + + + + +
    x

    A model specification.

    grid

    A tibble with tuning parameter combinations.

    ...

    Not currently used.

    + +

    Value

    + +

    A tibble with the minimum tuning parameters to fit and an additional +list column with the parameter combinations used for prediction.
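A hypothetical sketch of calling the generic on a boosted tree grid; the grid columns follow the tuning parameter names, and the exact column names of the result (here guessed as a .submodels list column) may differ.

library(parsnip)
library(tibble)

bt_spec <- boost_tree(mode = "regression") %>% set_engine("xgboost")

bt_grid <- tibble(
  trees = c(5, 10, 20, 5, 10, 20),
  min_n = c(2, 2, 2, 10, 10, 10)
)

min_grid(bt_spec, bt_grid)
# Expect fewer rows than `bt_grid` (one fit per value of min_n at the largest
# `trees`), plus a list column describing the submodels to predict from.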

diff --git a/docs/reference/mlp.html b/docs/reference/mlp.html
    @@ -170,7 +175,7 @@

    General Interface for Single Layer Neural Network

    dropout = NULL, epochs = NULL, activation = NULL) # S3 method for mlp -update(object, hidden_units = NULL, penalty = NULL, +update(object, hidden_units = NULL, penalty = NULL, dropout = NULL, epochs = NULL, activation = NULL, fresh = FALSE, ...) @@ -220,7 +225,7 @@

    Arg ... -

    Not used for update().

    +

    Not used for update().

    @@ -230,15 +235,15 @@

    Details time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (see above), the values are taken from the underlying model -functions. One exception is hidden_units when nnet::nnet is used; that +functions. One exception is hidden_units when nnet::nnet is used; that function's size argument has no default so a value of 5 units will be used. Also, unless otherwise specified, the linout argument to -nnet::nnet() will be set to TRUE when a regression model is created. -If parameters need to be modified, update() can be used +nnet::nnet() will be set to TRUE when a regression model is created. +If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

      -
    • R: "nnet"

    • +
    • R: "nnet" (the default)

    • keras: "keras"

    An error is thrown if both penalty and dropout are specified for @@ -271,7 +276,7 @@

    See also

    - +

    Examples

    @@ -290,12 +295,12 @@

    Examp #> Main Arguments: #> hidden_units = 10 #> dropout = 0.3 -#>

    update(model, hidden_units = 2)
    #> Single Layer Neural Network Specification (unknown) +#>
    update(model, hidden_units = 2)
    #> Single Layer Neural Network Specification (unknown) #> #> Main Arguments: #> hidden_units = 2 #> dropout = 0.3 -#>
    update(model, hidden_units = 2, fresh = TRUE)
    #> Single Layer Neural Network Specification (unknown) +#>
    update(model, hidden_units = 2, fresh = TRUE)
    #> Single Layer Neural Network Specification (unknown) #> #> Main Arguments: #> hidden_units = 2 @@ -318,6 +323,7 @@

diff --git a/docs/reference/model_fit.html b/docs/reference/model_fit.html
    @@ -168,7 +173,7 @@

    Examp # Keep the `x` matrix if the data are not too big. spec_obj <- linear_reg() %>% - set_engine("lm", x = ifelse(.obs() < 500, TRUE, FALSE)) + set_engine("lm", x = ifelse(.obs() < 500, TRUE, FALSE)) spec_obj

    #> Linear Regression Model Specification (regression) #> #> Engine-Specific Arguments: @@ -176,7 +181,7 @@

    Examp #> #> Computational engine: lm #>

    -fit_obj <- fit(spec_obj, mpg ~ ., data = mtcars) +fit_obj <- fit(spec_obj, mpg ~ ., data = mtcars) fit_obj
    #> parsnip model object #> #> @@ -190,7 +195,7 @@

    Examp #> qsec vs am gear carb #> 0.82104 0.31776 2.52023 0.65541 -0.19942 #>

    -nrow(fit_obj$fit$x)
    #> [1] 32
    +nrow(fit_obj$fit$x)
    #> [1] 32
    +

diff --git a/docs/reference/model_printer.html b/docs/reference/model_printer.html
diff --git a/docs/reference/model_spec.html b/docs/reference/model_spec.html
diff --git a/docs/reference/multi_predict.html b/docs/reference/multi_predict.html
    @@ -137,7 +142,35 @@

    Model predictions across many sub-models

    multi_predict(object, ...)
     
     # S3 method for default
    -multi_predict(object, ...)
    +multi_predict(object, ...) + +# S3 method for _xgb.Booster +multi_predict(object, new_data, type = NULL, + trees = NULL, ...) + +# S3 method for _C5.0 +multi_predict(object, new_data, type = NULL, + trees = NULL, ...) + +# S3 method for _elnet +multi_predict(object, new_data, type = NULL, + penalty = NULL, ...) + +# S3 method for _lognet +multi_predict(object, new_data, type = NULL, + penalty = NULL, ...) + +# S3 method for _earth +multi_predict(object, new_data, type = NULL, + num_terms = NULL, ...) + +# S3 method for _multnet +multi_predict(object, new_data, type = NULL, + penalty = NULL, ...) + +# S3 method for _train.kknn +multi_predict(object, new_data, type = NULL, + neighbors = NULL, ...)

    Arguments

    @@ -151,6 +184,33 @@

    Arg

    + + + + + + + + + + + + + + + + + + + + + + + +

    Optional arguments to pass to predict.model_fit(type = "raw") such as type.

    new_data

    A rectangular data object, such as a data frame.

    type

    A single character value or NULL. Possible values +are "numeric", "class", "prob", "conf_int", "pred_int", "quantile", +or "raw". When NULL, predict() will choose an appropriate value +based on the model's mode.

    trees

    An integer vector for the number of trees in the ensemble.

    penalty

A numeric vector of penalty values.

    num_terms

    An integer vector for the number of MARS terms to retain.

    neighbors

    An integer vector for the number of nearest neighbors.
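A hedged sketch of one of the engine-specific methods listed above: a single xgboost fit scored at several ensemble sizes via `trees` (data and values are illustrative).

library(parsnip)

bt_fit <- boost_tree(mode = "regression", trees = 200) %>%
  set_engine("xgboost") %>%
  fit(mpg ~ ., data = mtcars)

# One fitted ensemble, predictions at three numbers of trees
multi_predict(bt_fit, new_data = mtcars[1:5, -1], trees = c(50, 100, 200))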

    Value

    @@ -172,6 +232,7 @@

diff --git a/docs/reference/multinom_reg.html b/docs/reference/multinom_reg.html
    @@ -155,7 +160,7 @@

    General Interface for Multinomial Regression Models

    time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -163,7 +168,7 @@

    General Interface for Multinomial Regression Models

    multinom_reg(mode = "classification", penalty = NULL, mixture = NULL)
     
     # S3 method for multinom_reg
    -update(object, penalty = NULL, mixture = NULL,
    +update(object, penalty = NULL, mixture = NULL,
       fresh = FALSE, ...)

    Arguments

    @@ -200,16 +205,16 @@

    Arg ... -

    Not used for update().

    +

    Not used for update().

    Details

    For multinom_reg(), the mode will always be "classification".

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

      -
    • R: "glmnet"

    • +
    • R: "glmnet" (the default)

    • Stan: "stan"

    • keras: "keras"

    @@ -218,13 +223,13 @@

    Note

    For models created using the spark engine, there are several differences to consider. First, only the formula -interface to via fit() is available; using fit_xy() will +interface to via fit() is available; using fit_xy() will generate an error. Second, the predictions will always be in a spark table format. The names will be the same as documented but without the dots. Third, there is no equivalent to factor columns in spark tables so class predictions are returned as character columns. Fourth, to retain the model object for a new -R session (via save()), the model$fit element of the parsnip +R session (via save()), the model$fit element of the parsnip object should be serialized via ml_save(object$fit) and separately saved to disk. In a new session, the object can be reloaded and reattached to the parsnip object.

    @@ -250,11 +255,11 @@

    -

    When using glmnet models, there is the option to pass -multiple values (or no values) to the penalty argument. This -can have an effect on the model object results. When using the -predict() method in these cases, the return value depends on -the value of penalty. When using predict(), only a single +

    For glmnet models, the full regularization path is always fit regardless +of the value given to penalty. Also, there is the option to pass +multiple values (or no values) to the penalty argument. When using the +predict() method in these cases, the return value depends on +the value of penalty. When using predict(), only a single value of the penalty can be used. When predicting on multiple penalties, the multi_predict() function can be used. It returns a tibble with a list column called .pred that contains @@ -262,7 +267,7 @@

    See also

    - +

    Examples

    @@ -278,12 +283,12 @@

    Examp #> Main Arguments: #> penalty = 10 #> mixture = 0.1 -#>
    update(model, penalty = 1)
    #> Multinomial Regression Model Specification (classification) +#>
    update(model, penalty = 1)
    #> Multinomial Regression Model Specification (classification) #> #> Main Arguments: #> penalty = 1 #> mixture = 0.1 -#>
    update(model, penalty = 1, fresh = TRUE)
    #> Multinomial Regression Model Specification (classification) +#>
    update(model, penalty = 1, fresh = TRUE)
    #> Multinomial Regression Model Specification (classification) #> #> Main Arguments: #> penalty = 1 @@ -308,6 +313,7 @@

diff --git a/docs/reference/nearest_neighbor.html b/docs/reference/nearest_neighbor.html
    @@ -161,7 +166,7 @@

    General Interface for K-Nearest Neighbor Models

    time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -181,7 +186,8 @@

    Arg neighbors

    A single integer for the number of neighbors -to consider (often called k).

    +to consider (often called k). For kknn, a value of 5 +is used if neighbors is not specified.

    weight_func @@ -199,9 +205,9 @@

    Arg

    Details

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

      -
    • R: "kknn"

    • +
    • R: "kknn" (the default)

    Note

    @@ -221,12 +227,12 @@

    kknn (classification or regression)

     kknn::train.kknn(formula = missing_arg(), data = missing_arg(), 
    -    kmax = missing_arg())
    +    ks = 5)
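A hedged sketch of the new kknn default (5 neighbors when none is given) and the multi_predict() method for this engine; the neighbor values are illustrative, and how far they may range from the fitted value is not guaranteed here.

library(parsnip)

knn_fit <- nearest_neighbor(mode = "regression", neighbors = 7) %>%
  set_engine("kknn") %>%
  fit(mpg ~ ., data = mtcars)

# Predictions at several neighbor values from one fit
multi_predict(knn_fit, new_data = mtcars[1:3, -1], neighbors = c(3, 5, 7))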
     

    See also

    - +

    Examples

    @@ -256,6 +262,7 @@

diff --git a/docs/reference/null_model.html b/docs/reference/null_model.html
    @@ -153,7 +158,7 @@

    Arg

    Details

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

    • R: "parsnip"

    @@ -175,7 +180,7 @@

    See also

    - +

    Examples

    @@ -199,6 +204,7 @@

diff --git a/docs/reference/nullmodel.html b/docs/reference/nullmodel.html
    @@ -142,10 +147,10 @@

    Fit a simple, non-informative model

    nullmodel(x = NULL, y, ...) # S3 method for nullmodel -print(x, ...) +print(x, ...) # S3 method for nullmodel -predict(object, new_data = NULL, type = NULL, ...) +predict(object, new_data = NULL, type = NULL, ...)

    Arguments

    @@ -207,13 +212,13 @@

    Details

    Examples

    -outcome <- factor(sample(letters[1:2], +outcome <- factor(sample(letters[1:2], size = 100, - prob = c(.1, .9), + prob = c(.1, .9), replace = TRUE)) useless <- nullmodel(y = outcome) useless
    #> Null Regression Model -#> Predicted Value: b
    predict(useless, matrix(NA, nrow = 5))
    #> [1] b b b b b +#> Predicted Value: b
    predict(useless, matrix(NA, nrow = 5))
    #> [1] b b b b b #> Levels: a b
    @@ -232,6 +237,7 @@

diff --git a/docs/reference/predict.model_fit.html b/docs/reference/predict.model_fit.html
    @@ -133,14 +138,14 @@

    Model predictions

    Apply a model to create different types of predictions. -predict() can be used for all types of models and used the +predict() can be used for all types of models and used the "type" argument for more specificity.

    # S3 method for model_fit
    -predict(object, new_data, type = NULL,
    -  opts = list(), ...)
    +predict(object, new_data, type = NULL, + opts = list(), ...)

    Arguments

  • @@ -157,7 +162,7 @@

    Arg

    @@ -213,16 +218,16 @@

    Value

    appear in names and 2) vectors are never returned but type-specific prediction functions.

    When the model fit failed and the error was captured, the -predict() function will return the same structure as above but +predict() function will return the same structure as above but filled with missing values. This does not currently work for multivariate models.

    Details

    -

    If "type" is not supplied to predict(), then a choice +

    If "type" is not supplied to predict(), then a choice is made (type = "numeric" for regression models and type = "class" for classification).

    -

    predict() is designed to provide a tidy result (see "Value" +

    predict() is designed to provide a tidy result (see "Value" section below) in a tibble output format.

    When using type = "conf_int" and type = "pred_int", the options level and std_error can be used. The latter is a logical for an @@ -230,54 +235,54 @@

    Details

    Examples

    -
    library(dplyr) +
    library(dplyr) lm_model <- linear_reg() %>% set_engine("lm") %>% - fit(mpg ~ ., data = mtcars %>% slice(11:32)) + fit(mpg ~ ., data = mtcars %>% slice(11:32)) pred_cars <- mtcars %>% slice(1:10) %>% select(-mpg) -predict(lm_model, pred_cars)
    #> # A tibble: 10 x 1 +predict(lm_model, pred_cars)
    #> # A tibble: 10 x 1 #> .pred -#> <dbl> -#> 1 23.4 -#> 2 23.3 -#> 3 27.6 -#> 4 21.5 -#> 5 17.6 -#> 6 21.6 -#> 7 13.9 -#> 8 21.7 -#> 9 25.6 -#> 10 17.1
    -predict( +#> <dbl> +#> 1 23.4 +#> 2 23.3 +#> 3 27.6 +#> 4 21.5 +#> 5 17.6 +#> 6 21.6 +#> 7 13.9 +#> 8 21.7 +#> 9 25.6 +#> 10 17.1
    +predict( lm_model, pred_cars, type = "conf_int", level = 0.90 -)
    #> # A tibble: 10 x 2 +)
    #> # A tibble: 10 x 2 #> .pred_lower .pred_upper -#> <dbl> <dbl> -#> 1 17.9 29.0 -#> 2 18.1 28.5 -#> 3 24.0 31.3 -#> 4 17.5 25.6 -#> 5 14.3 20.8 -#> 6 17.0 26.2 -#> 7 9.65 18.2 -#> 8 16.2 27.2 -#> 9 14.2 37.0 -#> 10 11.5 22.7
    -predict( +#> <dbl> <dbl> +#> 1 17.9 29.0 +#> 2 18.1 28.5 +#> 3 24.0 31.3 +#> 4 17.5 25.6 +#> 5 14.3 20.8 +#> 6 17.0 26.2 +#> 7 9.65 18.2 +#> 8 16.2 27.2 +#> 9 14.2 37.0 +#> 10 11.5 22.7
    +predict( lm_model, pred_cars, type = "raw", - opts = list(type = "terms") + opts = list(type = "terms") )
    #> cyl disp hp drat wt qsec #> 1 -0.001433177 -0.8113275 0.6303467 -0.06120265 2.4139815 -1.567729 #> 2 -0.001433177 -0.8113275 0.6303467 -0.06120265 1.4488706 -0.736286 @@ -318,6 +323,7 @@

diff --git a/docs/reference/rand_forest.html b/docs/reference/rand_forest.html
    @@ -157,7 +162,7 @@

    General Interface for Random Forest Models

    time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -166,7 +171,7 @@

    General Interface for Random Forest Models

    min_n = NULL) # S3 method for rand_forest -update(object, mtry = NULL, trees = NULL, +update(object, mtry = NULL, trees = NULL, min_n = NULL, fresh = FALSE, ...)
  • Arguments

    @@ -204,15 +209,15 @@

    Arg

    - +
    type

    A single character value or NULL. Possible values are "numeric", "class", "prob", "conf_int", "pred_int", "quantile", -or "raw". When NULL, predict() will choose an appropriate value +or "raw". When NULL, predict() will choose an appropriate value based on the model's mode.

    ...

    Not used for update().

    Not used for update().

    Details

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

      -
    • R: "ranger" or "randomForest"

    • +
    • R: "ranger" (the default) or "randomForest"

    • Spark: "spark"

    @@ -220,7 +225,7 @@

    Note

    For models created using the spark engine, there are several differences to consider. First, only the formula -interface to via fit() is available; using fit_xy() will +interface to via fit() is available; using fit_xy() will generate an error. Second, the predictions will always be in a spark table format. The names will be the same as documented but without the dots. Third, there is no equivalent to factor @@ -274,7 +279,7 @@

    See also

    - +

    Examples

    @@ -293,12 +298,12 @@

    Examp #> Main Arguments: #> mtry = 10 #> min_n = 3 -#>

    update(model, mtry = 1)
    #> Random Forest Model Specification (unknown) +#>
    update(model, mtry = 1)
    #> Random Forest Model Specification (unknown) #> #> Main Arguments: #> mtry = 1 #> min_n = 3 -#>
    update(model, mtry = 1, fresh = TRUE)
    #> Random Forest Model Specification (unknown) +#>
    update(model, mtry = 1, fresh = TRUE)
    #> Random Forest Model Specification (unknown) #> #> Main Arguments: #> mtry = 1 @@ -323,6 +328,7 @@

diff --git a/docs/reference/reexports.html b/docs/reference/reexports.html
    @@ -139,7 +144,7 @@

    Objects exported from other packages

    These objects are imported from other packages. Follow the links below to see their documentation.

    -
    generics

    fit, fit_xy, varying_args

    +
    generics

    fit, fit_xy, varying_args

    magrittr

    %>%

    @@ -157,6 +162,7 @@

diff --git a/docs/reference/rpart_train.html b/docs/reference/rpart_train.html
diff --git a/docs/reference/set_args.html b/docs/reference/set_args.html
diff --git a/docs/reference/set_engine.html b/docs/reference/set_engine.html
diff --git a/docs/reference/set_new_model.html b/docs/reference/set_new_model.html
new file mode 100644
Tools to Register Models — set_new_model • parsnip

    These functions are similar to constructors and can be used to validate +that there are no conflicts with the underlying model structures used by the +package.

    + +
    + +
    set_new_model(model)
    +
    +set_model_mode(model, mode)
    +
    +set_model_engine(model, mode, eng)
    +
    +set_model_arg(model, eng, parsnip, original, func, has_submodel)
    +
    +set_dependency(model, eng, pkg)
    +
    +get_dependency(model)
    +
    +set_fit(model, mode, eng, value)
    +
    +get_fit(model)
    +
    +set_pred(model, mode, eng, type, value)
    +
    +get_pred_type(model, type)
    +
    +show_model_info(model)
    +
    +pred_value_template(pre = NULL, post = NULL, func, ...)
    + +

    Arguments

    + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    model

    A single character string for the model type (e.g. +"rand_forest", etc).

    mode

    A single character string for the model mode (e.g. "regression").

    eng

    A single character string for the model engine.

    parsnip

    A single character string for the "harmonized" argument name +that parsnip exposes.

    original

A single character string for the argument name that +the underlying model function uses.

    func

    A named character vector that describes how to call +a function. func should have elements pkg and fun. The +former is optional but is recommended and the latter is +required. For example, c(pkg = "stats", fun = "lm") would be +used to invoke the usual linear regression function. In some +cases, it is helpful to use c(fun = "predict") when using a +package's predict method.

    has_submodel

A single logical for whether the argument +can make predictions on multiple submodels at once.

    pkg

An optional character string for a package name.

    value

    A list that conforms to the fit_obj or pred_obj description +above, depending on context.

    type

    A single character value for the type of prediction. Possible +values are: class, conf_int, numeric, pred_int, prob, quantile, +and raw.

    pre, post

    Optional functions for pre- and post-processing of prediction +results.

    ...

    Optional arguments that should be passed into the args slot for +prediction objects.

    arg

    A single character string for the model argument name.

    fit_obj

    A list with elements interface, protect, +func and defaults. See the package vignette "Making a +parsnip model from scratch".

    pred_obj

    A list with elements pre, post, func, and args. +See the package vignette "Making a parsnip model from scratch".

    + +

    Details

    + +

These functions are available for users to add their +own models or engines (in a package or otherwise) so that they can +be accessed using parsnip. These are more thoroughly documented +on the package web site (see references below).

    +

In short, parsnip stores an environment object that contains +all of the information and code about how models are used (e.g. +fitting, predicting, etc). These functions can be used to add +models to that environment as well as helper functions that can +be used to make sure that the model data is in the right +format.

    +

    check_model_exists() checks the model value and ensures that the model has +already been registered. check_model_doesnt_exist() checks the model value +and also checks to see if it is novel in the environment.
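A hypothetical registration sketch built only from the signatures shown above; "shallow_learning_model" echoes the commented example further down and is not a real parsnip model, and the choice of the stats engine/package is arbitrary.

library(parsnip)

set_new_model("shallow_learning_model")
set_model_mode("shallow_learning_model", mode = "regression")
set_model_engine("shallow_learning_model", mode = "regression", eng = "stats")
set_dependency("shallow_learning_model", eng = "stats", pkg = "stats")

# Inspect what has been registered so far
show_model_info("shallow_learning_model")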

    + +

    References

    + +

    "Making a parsnip model from scratch" +https://tidymodels.github.io/parsnip/articles/articles/Scratch.html

    + + +

    Examples

    +
    # set_new_model("shallow_learning_model") + +# Show the information about a model: +show_model_info("rand_forest")
    #> Information for `rand_forest` +#> modes: unknown, classification, regression +#> +#> engines: +#> classification: randomForest, ranger, spark +#> regression: randomForest, ranger, spark +#> +#> arguments: +#> ranger: +#> mtry --> mtry +#> trees --> num.trees +#> min_n --> min.node.size +#> randomForest: +#> mtry --> mtry +#> trees --> ntree +#> min_n --> nodesize +#> spark: +#> mtry --> feature_subset_strategy +#> trees --> num_trees +#> min_n --> min_instances_per_node +#> +#> fit modules: +#> engine mode +#> ranger classification +#> ranger regression +#> randomForest classification +#> randomForest regression +#> spark classification +#> spark regression +#> +#> prediction modules: +#> mode engine methods +#> classification randomForest class, prob, raw +#> classification ranger class, conf_int, prob, raw +#> classification spark class, prob +#> regression randomForest numeric, raw +#> regression ranger conf_int, numeric, raw +#> regression spark numeric +#>
diff --git a/docs/reference/show_call.html b/docs/reference/show_call.html
diff --git a/docs/reference/surv_reg.html b/docs/reference/surv_reg.html
    @@ -159,7 +164,7 @@

    General Interface for Parametric Survival Models

    surv_reg(mode = "regression", dist = NULL)
     
     # S3 method for surv_reg
    -update(object, dist = NULL, fresh = FALSE, ...)
    +update(object, dist = NULL, fresh = FALSE, ...)

    Arguments

    @@ -185,7 +190,7 @@

    Arg

    - +
    ...

    Not used for update().

    Not used for update().

    @@ -195,15 +200,15 @@

Details to determine the mode of the model. For surv_reg(), the mode will always be "regression".

    Since survival models typically involve censoring (and require the use of -survival::Surv() objects), the fit() function will require that the +survival::Surv() objects), the fit() function will require that the survival model be specified via the formula interface.

    Also, for the flexsurv::flexsurvfit engine, the typical strata function cannot be used. To achieve the same effect, the extra parameter roles can be used (as described above).

    For surv_reg(), the mode will always be "regression".

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

      -
    • R: "flexsurv", "survreg"

    • +
    • R: "flexsurv", "survival" (the default)

    Engine Details

    @@ -217,7 +222,7 @@

    survreg

    +

    survival

     survival::survreg(formula = missing_arg(), data = missing_arg(), 
         weights = missing_arg(), model = TRUE)
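A hedged sketch of requesting the renamed engine ("survival", formerly "survreg"); the survival package's lung data and the predictors used are for illustration only.

library(parsnip)
library(survival)

surv_spec <- surv_reg(dist = "weibull") %>%
  set_engine("survival")   # previously requested as "survreg"

# Censored outcomes must be supplied through the formula interface via Surv()
surv_fit <- fit(surv_spec, Surv(time, status) ~ age + sex, data = lung)
surv_fit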
    @@ -233,7 +238,7 @@ 

    R

    See also

    - +

    Examples

    @@ -249,7 +254,7 @@

    Examp #> #> Main Arguments: #> dist = weibull -#>

    update(model, dist = "lnorm")
    #> Parametric Survival Regression Model Specification (regression) +#>
    update(model, dist = "lnorm")
    #> Parametric Survival Regression Model Specification (regression) #> #> Main Arguments: #> dist = lnorm @@ -274,6 +279,7 @@

diff --git a/docs/reference/svm_poly.html b/docs/reference/svm_poly.html
    @@ -159,7 +164,7 @@

    General interface for polynomial support vector machines

    time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -168,7 +173,7 @@

    General interface for polynomial support vector machines

    scale_factor = NULL, margin = NULL) # S3 method for svm_poly -update(object, cost = NULL, degree = NULL, +update(object, cost = NULL, degree = NULL, scale_factor = NULL, margin = NULL, fresh = FALSE, ...)

    Arguments

    @@ -209,15 +214,15 @@

    Arg ... -

    Not used for update().

    +

    Not used for update().

    Details

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

      -
    • R: "kernlab"

    • +
    • R: "kernlab" (the default)

    Engine Details

    @@ -238,7 +243,7 @@

    See also

    - +

    Examples

    @@ -257,12 +262,12 @@

    Examp #> Main Arguments: #> cost = 10 #> scale_factor = 0.1 -#>
    update(model, cost = 1)
    #> Polynomial Support Vector Machine Specification (unknown) +#>
    update(model, cost = 1)
    #> Polynomial Support Vector Machine Specification (unknown) #> #> Main Arguments: #> cost = 1 #> scale_factor = 0.1 -#>
    update(model, cost = 1, fresh = TRUE)
    #> Polynomial Support Vector Machine Specification (unknown) +#>
    update(model, cost = 1, fresh = TRUE)
    #> Polynomial Support Vector Machine Specification (unknown) #> #> Main Arguments: #> cost = 1 @@ -285,6 +290,7 @@

diff --git a/docs/reference/svm_rbf.html b/docs/reference/svm_rbf.html
    @@ -159,7 +164,7 @@

    General interface for radial basis function support vector machines

    time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -168,7 +173,7 @@

    General interface for radial basis function support vector machines

    margin = NULL) # S3 method for svm_rbf -update(object, cost = NULL, rbf_sigma = NULL, +update(object, cost = NULL, rbf_sigma = NULL, margin = NULL, fresh = FALSE, ...)

    Arguments

    @@ -205,15 +210,15 @@

    Arg ... -

    Not used for update().

    +

    Not used for update().

    Details

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

      -
    • R: "kernlab"

    • +
    • R: "kernlab" (the default)

    Engine Details

    @@ -234,7 +239,7 @@

    See also

    - +

    Examples

    @@ -253,12 +258,12 @@

    Examp #> Main Arguments: #> cost = 10 #> rbf_sigma = 0.1 -#>
    update(model, cost = 1)
    #> Radial Basis Function Support Vector Machine Specification (unknown) +#>
    update(model, cost = 1)
    #> Radial Basis Function Support Vector Machine Specification (unknown) #> #> Main Arguments: #> cost = 1 #> rbf_sigma = 0.1 -#>
    update(model, cost = 1, fresh = TRUE)
    #> Radial Basis Function Support Vector Machine Specification (unknown) +#>
    update(model, cost = 1, fresh = TRUE)
    #> Radial Basis Function Support Vector Machine Specification (unknown) #> #> Main Arguments: #> cost = 1 @@ -281,6 +286,7 @@

diff --git a/docs/reference/tidy.model_fit.html b/docs/reference/tidy.model_fit.html
diff --git a/docs/reference/translate.html b/docs/reference/translate.html
    @@ -157,7 +162,7 @@

    Details

    translate() produces a template call that lacks the specific argument values (such as data, etc). These are filled in once -fit() is called with the specifics of the data for the model. +fit() is called with the specifics of the data for the model. The call may also include varying arguments if these are in the specification.

    It does contain the resolved argument names that are specific to @@ -183,7 +188,7 @@

    Examp #> #> Model fit template: #> glmnet::glmnet(x = missing_arg(), y = missing_arg(), weights = missing_arg(), -#> lambda = 0.01, family = "gaussian")

    +#> family = "gaussian")
    # `penalty` not applicable for this model. translate(lm_spec, engine = "lm")
    #> Linear Regression Model Specification (regression) #> @@ -231,6 +236,7 @@

diff --git a/docs/reference/type_sum.model_spec.html b/docs/reference/type_sum.model_spec.html
diff --git a/docs/reference/varying.html b/docs/reference/varying.html
diff --git a/docs/reference/varying_args.model_spec.html b/docs/reference/varying_args.model_spec.html
    @@ -132,20 +137,20 @@

    Determine varying arguments

    -

    varying_args() takes a model specification or a recipe and returns a tibble +

    varying_args() takes a model specification or a recipe and returns a tibble of information on all possible varying arguments and whether or not they are actually varying.

    # S3 method for model_spec
    -varying_args(object, full = TRUE, ...)
    +varying_args(object, full = TRUE, ...)
     
     # S3 method for recipe
    -varying_args(object, full = TRUE, ...)
    +varying_args(object, full = TRUE, ...)
     
     # S3 method for step
    -varying_args(object, full = TRUE, ...)
    +varying_args(object, full = TRUE, ...)

    Arguments

    @@ -182,50 +187,50 @@

    Details

    Examples

    # List all possible varying args for the random forest spec -rand_forest() %>% varying_args()
    #> # A tibble: 3 x 4 +rand_forest() %>% varying_args()
    #> # A tibble: 3 x 4 #> name varying id type -#> <chr> <lgl> <chr> <chr> -#> 1 mtry FALSE rand_forest model_spec -#> 2 trees FALSE rand_forest model_spec -#> 3 min_n FALSE rand_forest model_spec
    +#> <chr> <lgl> <chr> <chr> +#> 1 mtry FALSE rand_forest model_spec +#> 2 trees FALSE rand_forest model_spec +#> 3 min_n FALSE rand_forest model_spec
    # mtry is now recognized as varying -rand_forest(mtry = varying()) %>% varying_args()
    #> # A tibble: 3 x 4 +rand_forest(mtry = varying()) %>% varying_args()
    #> # A tibble: 3 x 4 #> name varying id type -#> <chr> <lgl> <chr> <chr> -#> 1 mtry TRUE rand_forest model_spec -#> 2 trees FALSE rand_forest model_spec -#> 3 min_n FALSE rand_forest model_spec
    +#> <chr> <lgl> <chr> <chr> +#> 1 mtry TRUE rand_forest model_spec +#> 2 trees FALSE rand_forest model_spec +#> 3 min_n FALSE rand_forest model_spec
    # Even engine specific arguments can vary rand_forest() %>% set_engine("ranger", sample.fraction = varying()) %>% - varying_args()
    #> # A tibble: 4 x 4 + varying_args()
    #> # A tibble: 4 x 4 #> name varying id type -#> <chr> <lgl> <chr> <chr> -#> 1 mtry FALSE rand_forest model_spec -#> 2 trees FALSE rand_forest model_spec -#> 3 min_n FALSE rand_forest model_spec -#> 4 sample.fraction TRUE rand_forest model_spec
    +#> <chr> <lgl> <chr> <chr> +#> 1 mtry FALSE rand_forest model_spec +#> 2 trees FALSE rand_forest model_spec +#> 3 min_n FALSE rand_forest model_spec +#> 4 sample.fraction TRUE rand_forest model_spec
    # List only the arguments that actually vary rand_forest() %>% set_engine("ranger", sample.fraction = varying()) %>% - varying_args(full = FALSE)
    #> # A tibble: 1 x 4 + varying_args(full = FALSE)
    #> # A tibble: 1 x 4 #> name varying id type -#> <chr> <lgl> <chr> <chr> -#> 1 sample.fraction TRUE rand_forest model_spec
    +#> <chr> <lgl> <chr> <chr> +#> 1 sample.fraction TRUE rand_forest model_spec
    rand_forest() %>% set_engine( "randomForest", strata = Class, sampsize = varying() ) %>% - varying_args()
    #> # A tibble: 5 x 4 + varying_args()
    #> # A tibble: 5 x 4 #> name varying id type -#> <chr> <lgl> <chr> <chr> -#> 1 mtry FALSE rand_forest model_spec -#> 2 trees FALSE rand_forest model_spec -#> 3 min_n FALSE rand_forest model_spec -#> 4 strata FALSE rand_forest model_spec -#> 5 sampsize TRUE rand_forest model_spec
    +#> <chr> <lgl> <chr> <chr> +#> 1 mtry FALSE rand_forest model_spec +#> 2 trees FALSE rand_forest model_spec +#> 3 min_n FALSE rand_forest model_spec +#> 4 strata FALSE rand_forest model_spec +#> 5 sampsize TRUE rand_forest model_spec
    +

    parsnip is a part of the tidymodels ecosystem, a collection of modeling packages designed with common APIs and a shared philosophy.

    @@ -254,11 +260,14 @@

    Contents

    Site built by pkgdown.

    +
    + + diff --git a/docs/reference/wa_churn.html b/docs/reference/wa_churn.html index 85a0e4832..cfb1ac282 100644 --- a/docs/reference/wa_churn.html +++ b/docs/reference/wa_churn.html @@ -55,6 +55,7 @@ gtag('config', 'UA-115082821-1'); + @@ -74,7 +75,7 @@ parsnip
    part of tidymodels - 0.0.2 + 0.0.3
    @@ -102,6 +103,9 @@
  • Making a parsnip model from scratch
  • +
  • + Evaluating submodels with the same model object +
  • @@ -118,6 +122,7 @@ +
    @@ -158,8 +163,8 @@

    Details

    Examples

    -
    data(wa_churn) -str(wa_churn)
    #> Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 7043 obs. of 20 variables: +
    data(wa_churn) +str(wa_churn)
    #> Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 7043 obs. of 20 variables: #> $ churn : Factor w/ 2 levels "Yes","No": 2 2 1 2 1 1 2 2 1 2 ... #> $ female : num 1 0 0 0 1 1 0 1 1 0 ... #> $ senior_citizen : int 0 0 0 0 0 0 0 0 0 0 ... @@ -197,6 +202,7 @@

    Contents

    +

    parsnip is a part of the tidymodels ecosystem, a collection of modeling packages designed with common APIs and a shared philosophy.

    @@ -208,11 +214,14 @@

    Contents

    Site built by pkgdown.

    +
    + + diff --git a/docs/reference/xgb_train.html b/docs/reference/xgb_train.html index a0fa57a55..454ea325a 100644 --- a/docs/reference/xgb_train.html +++ b/docs/reference/xgb_train.html @@ -56,6 +56,7 @@ gtag('config', 'UA-115082821-1'); + @@ -75,7 +76,7 @@ parsnip
    part of tidymodels - 0.0.2 + 0.0.3
    @@ -103,6 +104,9 @@
  • Making a parsnip model from scratch
  • +
  • + Evaluating submodels with the same model object +
  • @@ -119,6 +123,7 @@ +
    @@ -204,6 +209,7 @@

    Contents

    +

    parsnip is a part of the tidymodels ecosystem, a collection of modeling packages designed with common APIs and a shared philosophy.

    @@ -215,11 +221,14 @@

    Contents

    Site built by pkgdown.

    +
    + + diff --git a/man/linear_reg.Rd b/man/linear_reg.Rd index b58c1d631..791ffb650 100644 --- a/man/linear_reg.Rd +++ b/man/linear_reg.Rd @@ -105,9 +105,9 @@ model, the template of the fit calls are: \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")} -When using \code{glmnet} models, there is the option to pass -multiple values (or no values) to the \code{penalty} argument. This -can have an effect on the model object results. When using the +For \code{glmnet} models, the full regularization path is always fit regardless +of the value given to \code{penalty}. Also, there is the option to pass +multiple values (or no values) to the \code{penalty} argument. When using the \code{predict()} method in these cases, the return value depends on the value of \code{penalty}. When using \code{predict()}, only a single value of the penalty can be used. When predicting on multiple diff --git a/man/logistic_reg.Rd b/man/logistic_reg.Rd index 43aa599e9..137da1094 100644 --- a/man/logistic_reg.Rd +++ b/man/logistic_reg.Rd @@ -103,9 +103,9 @@ model, the template of the fit calls are: \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")} -When using \code{glmnet} models, there is the option to pass -multiple values (or no values) to the \code{penalty} argument. This -can have an effect on the model object results. When using the +For \code{glmnet} models, the full regularization path is always fit regardless +of the value given to \code{penalty}. Also, there is the option to pass +multiple values (or no values) to the \code{penalty} argument. When using the \code{predict()} method in these cases, the return value depends on the value of \code{penalty}. When using \code{predict()}, only a single value of the penalty can be used. When predicting on multiple diff --git a/man/min_grid.Rd b/man/min_grid.Rd new file mode 100644 index 000000000..e79250a2b --- /dev/null +++ b/man/min_grid.Rd @@ -0,0 +1,44 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/aaa.R, R/boost_tree.R, R/linear_reg.R, +% R/logistic_reg.R, R/mars.R, R/multinom_reg.R, R/nearest_neighbor.R +\name{min_grid} +\alias{min_grid} +\alias{min_grid.boost_tree} +\alias{min_grid.linear_reg} +\alias{min_grid.logistic_reg} +\alias{min_grid.mars} +\alias{min_grid.multinom_reg} +\alias{min_grid.nearest_neighbor} +\title{Determine the minimum set of model fits} +\usage{ +min_grid(x, grid, ...) + +\method{min_grid}{boost_tree}(x, grid, ...) + +\method{min_grid}{linear_reg}(x, grid, ...) + +\method{min_grid}{logistic_reg}(x, grid, ...) + +\method{min_grid}{mars}(x, grid, ...) + +\method{min_grid}{multinom_reg}(x, grid, ...) + +\method{min_grid}{nearest_neighbor}(x, grid, ...) +} +\arguments{ +\item{x}{A model specification.} + +\item{grid}{A tibble with tuning parameter combinations.} + +\item{...}{Not currently used.} +} +\value{ +A tibble with the minimum tuning parameters to fit and an additional +list column with the parameter combinations used for prediction. +} +\description{ +\code{min_grid} determines exactly what models should be fit in order to +evaluate the entire set of tuning parameter combinations. This is for +internal use only and the API may change in the near future. 
+} +\keyword{internal} diff --git a/man/multinom_reg.Rd b/man/multinom_reg.Rd index 6f2b4af05..6b711058a 100644 --- a/man/multinom_reg.Rd +++ b/man/multinom_reg.Rd @@ -94,9 +94,9 @@ model, the template of the fit calls are: \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")} -When using \code{glmnet} models, there is the option to pass -multiple values (or no values) to the \code{penalty} argument. This -can have an effect on the model object results. When using the +For \code{glmnet} models, the full regularization path is always fit regardless +of the value given to \code{penalty}. Also, there is the option to pass +multiple values (or no values) to the \code{penalty} argument. When using the \code{predict()} method in these cases, the return value depends on the value of \code{penalty}. When using \code{predict()}, only a single value of the penalty can be used. When predicting on multiple diff --git a/tests/testthat/test_boost_tree_xgboost.R b/tests/testthat/test_boost_tree_xgboost.R index 27f15ee9b..7cbe52750 100644 --- a/tests/testthat/test_boost_tree_xgboost.R +++ b/tests/testthat/test_boost_tree_xgboost.R @@ -207,3 +207,36 @@ test_that('default engine', { ) expect_true(inherits(fit$fit, "xgb.Booster")) }) + +test_that('boosted tree grid reduction', { + reg_grid <- expand.grid(trees = 1:3, learn_rate = (1:5)/5) + reg_grid_smol <- min_grid(boost_tree() %>% set_engine("xgboost"), reg_grid) + + expect_equal(reg_grid_smol$trees, rep(3, 5)) + expect_equal(reg_grid_smol$learn_rate, (1:5)/5) + for (i in 1:nrow(reg_grid_smol)) { + expect_equal(reg_grid_smol$.submodels[[i]], list(trees = 1:2)) + } + + reg_ish_grid <- expand.grid(trees = 1:3, learn_rate = (1:5)/5)[-3,] + reg_ish_grid_smol <- min_grid(boost_tree() %>% set_engine("xgboost"), reg_ish_grid) + + expect_equal(reg_ish_grid_smol$trees, c(2, rep(3, 4))) + expect_equal(reg_ish_grid_smol$learn_rate, (1:5)/5) + expect_equal(reg_ish_grid_smol$.submodels[[1]], list(trees = 1)) + for (i in 2:nrow(reg_ish_grid_smol)) { + expect_equal(reg_ish_grid_smol$.submodels[[i]], list(trees = 1:2)) + } + + reg_grid_extra <- expand.grid(trees = 1:3, learn_rate = (1:5)/5, blah = 10:12) + reg_grid_extra_smol <- min_grid(boost_tree() %>% set_engine("xgboost"), reg_grid_extra) + + expect_equal(reg_grid_extra_smol$trees, rep(3, 15)) + expect_equal(reg_grid_extra_smol$learn_rate, rep((1:5)/5, each = 3)) + expect_equal(reg_grid_extra_smol$blah, rep(10:12, 5)) + for (i in 1:nrow(reg_grid_extra_smol)) { + expect_equal(reg_grid_extra_smol$.submodels[[i]], list(trees = 1:2)) + } + +}) + diff --git a/tests/testthat/test_linear_reg.R b/tests/testthat/test_linear_reg.R index e2830d96a..20416717f 100644 --- a/tests/testthat/test_linear_reg.R +++ b/tests/testthat/test_linear_reg.R @@ -76,7 +76,6 @@ test_that('primary arguments', { x = expr(missing_arg()), y = expr(missing_arg()), weights = expr(missing_arg()), - lambda = new_empty_quosure(1), family = "gaussian" ) ) diff --git a/tests/testthat/test_linear_reg_glmnet.R b/tests/testthat/test_linear_reg_glmnet.R index 46ff3658d..fb2365a13 100644 --- a/tests/testthat/test_linear_reg_glmnet.R +++ b/tests/testthat/test_linear_reg_glmnet.R @@ -67,10 +67,10 @@ test_that('glmnet prediction, single lambda', { y = iris$Sepal.Length ) - uni_pred <- c(5.05124049139868, 4.87103404621362, 4.91028250633598, 4.9399094532023, - 5.08728178043569) + uni_pred <- c(5.05125589060219, 4.86977761622526, 4.90912345599309, 4.93931874108359, + 5.08755154547758) - expect_equal(uni_pred, predict(res_xy, iris[1:5, 
num_pred])$.pred) + expect_equal(uni_pred, predict(res_xy, iris[1:5, num_pred])$.pred, tolerance = 0.0001) res_form <- fit( iris_basic, @@ -79,10 +79,10 @@ test_that('glmnet prediction, single lambda', { control = ctrl ) - form_pred <- c(5.24228948237804, 5.09448280355765, 5.15636527125752, 5.12592317615935, - 5.26930099973607) + form_pred <- c(5.23960117346944, 5.08769210344022, 5.15129212608077, 5.12000510716518, + 5.26736239856889) - expect_equal(form_pred, predict(res_form, iris[1:5,])$.pred) + expect_equal(form_pred, predict(res_form, iris[1:5,])$.pred, tolerance = 0.0001) }) @@ -132,7 +132,8 @@ test_that('glmnet prediction, multiple lambda', { as.data.frame(mult_pred), multi_predict(res_xy, new_data = iris[1:5, num_pred], lambda = lams) %>% unnest() %>% - as.data.frame() + as.data.frame(), + tolerance = 0.0001 ) res_form <- fit( @@ -176,7 +177,8 @@ test_that('glmnet prediction, multiple lambda', { as.data.frame(form_pred), multi_predict(res_form, new_data = iris[1:5, ], lambda = lams) %>% unnest() %>% - as.data.frame() + as.data.frame(), + tolerance = 0.0001 ) }) @@ -249,7 +251,7 @@ test_that('submodel prediction', { ) reg_fit <- - linear_reg(penalty = c(0, 0.01, 0.1)) %>% + linear_reg() %>% set_engine("glmnet") %>% fit(mpg ~ ., data = mtcars[-(1:4), ]) @@ -274,12 +276,6 @@ test_that('error traps', { skip_if_not_installed("glmnet") - expect_error( - linear_reg(penalty = .1) %>% - set_engine("glmnet") %>% - fit(mpg ~ ., data = mtcars[-(1:4), ]) %>% - predict(mtcars[-(1:4), ], penalty = .2) - ) expect_error( linear_reg() %>% set_engine("glmnet") %>% @@ -295,3 +291,36 @@ test_that('error traps', { }) + +test_that('glmnet grid reduction', { + reg_grid <- expand.grid(penalty = 1:3, mixture = (1:5)/5) + reg_grid_smol <- min_grid(linear_reg() %>% set_engine("glmnet"), reg_grid) + + expect_equal(reg_grid_smol$penalty, rep(3, 5)) + expect_equal(reg_grid_smol$mixture, (1:5)/5) + for (i in 1:nrow(reg_grid_smol)) { + expect_equal(reg_grid_smol$.submodels[[i]], list(penalty = 1:2)) + } + + reg_ish_grid <- expand.grid(penalty = 1:3, mixture = (1:5)/5)[-3,] + reg_ish_grid_smol <- min_grid(linear_reg() %>% set_engine("glmnet"), reg_ish_grid) + + expect_equal(reg_ish_grid_smol$penalty, c(2, rep(3, 4))) + expect_equal(reg_ish_grid_smol$mixture, (1:5)/5) + expect_equal(reg_ish_grid_smol$.submodels[[1]], list(penalty = 1)) + for (i in 2:nrow(reg_ish_grid_smol)) { + expect_equal(reg_ish_grid_smol$.submodels[[i]], list(penalty = 1:2)) + } + + reg_grid_extra <- expand.grid(penalty = 1:3, mixture = (1:5)/5, blah = 10:12) + reg_grid_extra_smol <- min_grid(linear_reg() %>% set_engine("glmnet"), reg_grid_extra) + + expect_equal(reg_grid_extra_smol$penalty, rep(3, 15)) + expect_equal(reg_grid_extra_smol$mixture, rep((1:5)/5, each = 3)) + expect_equal(reg_grid_extra_smol$blah, rep(10:12, 5)) + for (i in 1:nrow(reg_grid_extra_smol)) { + expect_equal(reg_grid_extra_smol$.submodels[[i]], list(penalty = 1:2)) + } + +}) + diff --git a/tests/testthat/test_logistic_reg.R b/tests/testthat/test_logistic_reg.R index c287ccbf6..236d93bd7 100644 --- a/tests/testthat/test_logistic_reg.R +++ b/tests/testthat/test_logistic_reg.R @@ -79,7 +79,6 @@ test_that('primary arguments', { x = expr(missing_arg()), y = expr(missing_arg()), weights = expr(missing_arg()), - lambda = new_empty_quosure(1), family = "binomial" ) ) diff --git a/tests/testthat/test_logistic_reg_glmnet.R b/tests/testthat/test_logistic_reg_glmnet.R index d2db3d5ed..f08392732 100644 --- a/tests/testthat/test_logistic_reg_glmnet.R +++ 
b/tests/testthat/test_logistic_reg_glmnet.R @@ -62,8 +62,7 @@ test_that('glmnet prediction, one lambda', { uni_pred <- predict(xy_fit$fit, newx = as.matrix(lending_club[1:7, num_pred]), - s = 0.1, type = "response") - uni_pred <- predict(xy_fit$fit, newx = as.matrix(lending_club[1:7, num_pred]), type = "response") + s = 0.1, type = "response")[,1] uni_pred <- ifelse(uni_pred >= 0.5, "good", "bad") uni_pred <- factor(uni_pred, levels = levels(lending_club$Class)) uni_pred <- unname(uni_pred) @@ -83,7 +82,7 @@ test_that('glmnet prediction, one lambda', { form_pred <- predict(res_form$fit, newx = form_mat, - s = 0.1) + s = 0.1, type = "response")[,1] form_pred <- ifelse(form_pred >= 0.5, "good", "bad") form_pred <- factor(form_pred, levels = levels(lending_club$Class)) form_pred <- unname(form_pred) @@ -357,7 +356,7 @@ test_that('glmnet probabilities, no lambda', { expect_equal( mult_pred, multi_predict(xy_fit, lending_club[1:7, num_pred], type = "prob") %>% unnest() - ) + ) res_form <- fit( logistic_reg() %>% set_engine("glmnet"), @@ -419,4 +418,35 @@ test_that('submodel prediction', { }) +test_that('glmnet grid reduction', { + reg_grid <- expand.grid(penalty = 1:3, mixture = (1:5)/5) + reg_grid_smol <- min_grid(logistic_reg() %>% set_engine("glmnet"), reg_grid) + + expect_equal(reg_grid_smol$penalty, rep(3, 5)) + expect_equal(reg_grid_smol$mixture, (1:5)/5) + for (i in 1:nrow(reg_grid_smol)) { + expect_equal(reg_grid_smol$.submodels[[i]], list(penalty = 1:2)) + } + + reg_ish_grid <- expand.grid(penalty = 1:3, mixture = (1:5)/5)[-3,] + reg_ish_grid_smol <- min_grid(logistic_reg() %>% set_engine("glmnet"), reg_ish_grid) + + expect_equal(reg_ish_grid_smol$penalty, c(2, rep(3, 4))) + expect_equal(reg_ish_grid_smol$mixture, (1:5)/5) + expect_equal(reg_ish_grid_smol$.submodels[[1]], list(penalty = 1)) + for (i in 2:nrow(reg_ish_grid_smol)) { + expect_equal(reg_ish_grid_smol$.submodels[[i]], list(penalty = 1:2)) + } + + reg_grid_extra <- expand.grid(penalty = 1:3, mixture = (1:5)/5, blah = 10:12) + reg_grid_extra_smol <- min_grid(logistic_reg() %>% set_engine("glmnet"), reg_grid_extra) + + expect_equal(reg_grid_extra_smol$penalty, rep(3, 15)) + expect_equal(reg_grid_extra_smol$mixture, rep((1:5)/5, each = 3)) + expect_equal(reg_grid_extra_smol$blah, rep(10:12, 5)) + for (i in 1:nrow(reg_grid_extra_smol)) { + expect_equal(reg_grid_extra_smol$.submodels[[i]], list(penalty = 1:2)) + } + +}) diff --git a/tests/testthat/test_logistic_reg_spark.R b/tests/testthat/test_logistic_reg_spark.R index 6c47487c8..a225a3df7 100644 --- a/tests/testthat/test_logistic_reg_spark.R +++ b/tests/testthat/test_logistic_reg_spark.R @@ -87,3 +87,14 @@ test_that('spark execution', { }) + +test_that('spark grid reduction', { + reg_grid <- expand.grid(penalty = 1:3, mixture = (1:5)/5) + reg_grid_smol <- min_grid(logistic_reg() %>% set_engine("spark"), reg_grid) + + expect_equal(reg_grid_smol$penalty, reg_grid$penalty) + expect_equal(reg_grid_smol$mixture, reg_grid$mixture) + for (i in 1:nrow(reg_grid_smol)) { + expect_equal(reg_grid_smol$.submodels[[i]], list()) + } +}) diff --git a/tests/testthat/test_mars.R b/tests/testthat/test_mars.R index b17d46ac7..6c9b88f6c 100644 --- a/tests/testthat/test_mars.R +++ b/tests/testthat/test_mars.R @@ -286,3 +286,34 @@ test_that('classification', { expect_equal(parsnip_pred$.pred_good, earth_pred) }) +test_that('earth grid reduction', { + reg_grid <- expand.grid(num_terms = 1:3, prod_degree = 1:2) + reg_grid_smol <- min_grid(mars() %>% set_engine("earth"), reg_grid) + + 
expect_equal(reg_grid_smol$num_terms, rep(3, 2)) + expect_equal(reg_grid_smol$prod_degree, 1:2) + for (i in 1:nrow(reg_grid_smol)) { + expect_equal(reg_grid_smol$.submodels[[i]], list(num_terms = 1:2)) + } + + reg_ish_grid <- expand.grid(num_terms = 1:3, prod_degree = 1:2)[-3,] + reg_ish_grid_smol <- min_grid(mars() %>% set_engine("earth"), reg_ish_grid) + + expect_equal(reg_ish_grid_smol$num_terms, 2:3) + expect_equal(reg_ish_grid_smol$prod_degree, 1:2) + expect_equal(reg_ish_grid_smol$.submodels[[1]], list(num_terms = 1)) + for (i in 2:nrow(reg_ish_grid_smol)) { + expect_equal(reg_ish_grid_smol$.submodels[[i]], list(num_terms = 1:2)) + } + + reg_grid_extra <- expand.grid(num_terms = 1:3, prod_degree = 1:2, blah = 10:12) + reg_grid_extra_smol <- min_grid(mars() %>% set_engine("earth"), reg_grid_extra) + + expect_equal(reg_grid_extra_smol$num_terms, rep(3, 6)) + expect_equal(reg_grid_extra_smol$prod_degree, rep(1:2, each = 3)) + expect_equal(reg_grid_extra_smol$blah, rep(10:12, 2)) + for (i in 1:nrow(reg_grid_extra_smol)) { + expect_equal(reg_grid_extra_smol$.submodels[[i]], list(num_terms = 1:2)) + } + +}) diff --git a/tests/testthat/test_multinom_reg.R b/tests/testthat/test_multinom_reg.R index 31b8c72bb..eba11a0b6 100644 --- a/tests/testthat/test_multinom_reg.R +++ b/tests/testthat/test_multinom_reg.R @@ -40,7 +40,6 @@ test_that('primary arguments', { x = expr(missing_arg()), y = expr(missing_arg()), weights = expr(missing_arg()), - lambda = new_empty_quosure(1), family = "multinomial" ) ) diff --git a/tests/testthat/test_multinom_reg_glmnet.R b/tests/testthat/test_multinom_reg_glmnet.R index 4115abbda..b612b44ef 100644 --- a/tests/testthat/test_multinom_reg_glmnet.R +++ b/tests/testthat/test_multinom_reg_glmnet.R @@ -61,7 +61,6 @@ test_that('glmnet prediction, one lambda', { uni_pred <- factor(uni_pred[,1], levels = levels(iris$Species)) uni_pred <- unname(uni_pred) - expect_equal(uni_pred, parsnip:::predict_class.model_fit(xy_fit, iris[rows, 1:4])) expect_equal(uni_pred, predict(xy_fit, iris[rows, 1:4], type = "class")$.pred_class) res_form <- fit( @@ -151,4 +150,38 @@ test_that('glmnet probabilities, mulitiple lambda', { }) +# ------------------------------------------------------------------------------ + +test_that('glmnet grid reduction', { + reg_grid <- expand.grid(penalty = 1:3, mixture = (1:5)/5) + reg_grid_smol <- min_grid(multinom_reg() %>% set_engine("glmnet"), reg_grid) + + expect_equal(reg_grid_smol$penalty, rep(3, 5)) + expect_equal(reg_grid_smol$mixture, (1:5)/5) + for (i in 1:nrow(reg_grid_smol)) { + expect_equal(reg_grid_smol$.submodels[[i]], list(penalty = 1:2)) + } + + reg_ish_grid <- expand.grid(penalty = 1:3, mixture = (1:5)/5)[-3,] + reg_ish_grid_smol <- min_grid(multinom_reg() %>% set_engine("glmnet"), reg_ish_grid) + + expect_equal(reg_ish_grid_smol$penalty, c(2, rep(3, 4))) + expect_equal(reg_ish_grid_smol$mixture, (1:5)/5) + expect_equal(reg_ish_grid_smol$.submodels[[1]], list(penalty = 1)) + for (i in 2:nrow(reg_ish_grid_smol)) { + expect_equal(reg_ish_grid_smol$.submodels[[i]], list(penalty = 1:2)) + } + + reg_grid_extra <- expand.grid(penalty = 1:3, mixture = (1:5)/5, blah = 10:12) + reg_grid_extra_smol <- min_grid(multinom_reg() %>% set_engine("glmnet"), reg_grid_extra) + + expect_equal(reg_grid_extra_smol$penalty, rep(3, 15)) + expect_equal(reg_grid_extra_smol$mixture, rep((1:5)/5, each = 3)) + expect_equal(reg_grid_extra_smol$blah, rep(10:12, 5)) + for (i in 1:nrow(reg_grid_extra_smol)) { + expect_equal(reg_grid_extra_smol$.submodels[[i]], list(penalty = 
1:2)) + } + +}) + diff --git a/tests/testthat/test_nearest_neighbor_kknn.R b/tests/testthat/test_nearest_neighbor_kknn.R index 3a0039b70..9971b1c64 100644 --- a/tests/testthat/test_nearest_neighbor_kknn.R +++ b/tests/testthat/test_nearest_neighbor_kknn.R @@ -191,3 +191,35 @@ test_that('kknn multi-predict', { dplyr::select(.pred) expect_equal(pred_uni, pred_uni_obs) }) + +test_that('kknn grid reduction', { + reg_grid <- expand.grid(neighbors = 1:3, prod_degree = 1:2) + reg_grid_smol <- min_grid(nearest_neighbor() %>% set_engine("kknn"), reg_grid) + + expect_equal(reg_grid_smol$neighbors, rep(3, 2)) + expect_equal(reg_grid_smol$prod_degree, 1:2) + for (i in 1:nrow(reg_grid_smol)) { + expect_equal(reg_grid_smol$.submodels[[i]], list(neighbors = 1:2)) + } + + reg_ish_grid <- expand.grid(neighbors = 1:3, prod_degree = 1:2)[-3,] + reg_ish_grid_smol <- min_grid(nearest_neighbor() %>% set_engine("kknn"), reg_ish_grid) + + expect_equal(reg_ish_grid_smol$neighbors, 2:3) + expect_equal(reg_ish_grid_smol$prod_degree, 1:2) + expect_equal(reg_ish_grid_smol$.submodels[[1]], list(neighbors = 1)) + for (i in 2:nrow(reg_ish_grid_smol)) { + expect_equal(reg_ish_grid_smol$.submodels[[i]], list(neighbors = 1:2)) + } + + reg_grid_extra <- expand.grid(neighbors = 1:3, prod_degree = 1:2, blah = 10:12) + reg_grid_extra_smol <- min_grid(nearest_neighbor() %>% set_engine("kknn"), reg_grid_extra) + + expect_equal(reg_grid_extra_smol$neighbors, rep(3, 6)) + expect_equal(reg_grid_extra_smol$prod_degree, rep(1:2, each = 3)) + expect_equal(reg_grid_extra_smol$blah, rep(10:12, 2)) + for (i in 1:nrow(reg_grid_extra_smol)) { + expect_equal(reg_grid_extra_smol$.submodels[[i]], list(neighbors = 1:2)) + } + +}) diff --git a/vignettes/articles/Submodels.Rmd b/vignettes/articles/Submodels.Rmd new file mode 100644 index 000000000..aeb22a468 --- /dev/null +++ b/vignettes/articles/Submodels.Rmd @@ -0,0 +1,118 @@ +--- +title: "Evaluating Submodels with the Same Model Object" +vignette: > + %\VignetteEngine{knitr::rmarkdown} + %\VignetteIndexEntry{Evaluating Submodels with the Same Model Object} +output: + knitr:::html_vignette: + toc: yes +--- + +```{r startup, include = FALSE} +library(utils) +library(ggplot2) +theme_set(theme_bw()) +``` + +Some R packages can create predictions from models that are different from the one that was fit. For example, if a boosted tree is fit with 10 iterations of boosting, the model can usually make predictions on _submodels_ that have fewer than 10 trees (all other parameters being static). This is helpful for model tuning since you can cheaply evaluate tuning parameter combinations, which often results in a large speed-up in the computations. + +In `parsnip`, there is a method called `multi_predict()` that can do this.
Its current methods are: + +```{r methods} +library(parsnip) +methods("multi_predict") +``` + +We'll use the attrition data in `rsample` to illustrate: + +```{r} +library(tidymodels) +data(attrition) + +set.seed(4595) +data_split <- initial_split(attrition, strata = "Attrition") +attrition_train <- training(data_split) +attrition_test <- testing(data_split) +``` + +A boosted classification tree is one of the more low-maintenance approaches that we could take with these data: + +```{r boost-model} +# requires the C50 package +attrition_boost <- + boost_tree(mode = "classification", trees = 100) %>% + set_engine("C5.0") +``` + +Suppose that 10-fold cross-validation was being used to tune the model over the number of trees: + +```{r folds} +set.seed(616) +folds <- vfold_cv(attrition_train) +``` + +The process would fit a model on 90% of the data and predict on the remaining 10%. Using `rsample`: + +```{r fold-1} +model_data <- analysis(folds$splits[[1]]) +pred_data <- assessment(folds$splits[[1]]) + +fold_1_model <- + attrition_boost %>% + fit_xy(x = model_data %>% dplyr::select(-Attrition), y = model_data$Attrition) + +``` + +For `multi_predict()`, the same semantics as `predict()` are used but, for this model, there is an extra argument called `trees`. Candidate submodel values can be passed in with `trees`: + +```{r fold-1-pred} + +fold_1_pred <- + multi_predict( + fold_1_model, + new_data = pred_data %>% dplyr::select(-Attrition), + trees = 1:100, + type = "prob" + ) +fold_1_pred +``` + +The result is a tibble that has as many rows as the data being predicted (_n_ = `r nrow(pred_data)`). The `.pred` column contains a list of tibbles and each has the predictions across the different numbers of trees: + +```{r obs-1-pred} +fold_1_pred$.pred[[1]] +``` + +To get this into a more usable format, we can use `tidyr::unnest()`, but first we add row numbers so that we can track the predictions by test sample as well as the actual classes: + +```{r unnest} +fold_1_df <- + fold_1_pred %>% + bind_cols(pred_data %>% dplyr::select(Attrition)) %>% + add_rowindex() %>% + unnest(.pred) +fold_1_df +``` + +For two samples, what do these look like over trees? + +```{r prob-plot} +fold_1_df %>% + dplyr::filter(.row %in% c(1, 88)) %>% + ggplot(aes(x = trees, y = .pred_No, col = Attrition, group = .row)) + + geom_step() + + ylim(0:1) + + theme(legend.position = "top") +``` + +What does performance look like over trees (using the area under the ROC curve)? + +```{r auc-plot} +fold_1_df %>% + group_by(trees) %>% + roc_auc(truth = Attrition, .pred_No) %>% + ggplot(aes(x = trees, y = .estimate)) + + geom_step() +``` + +
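Taken together, the pieces in this patch are two halves of the same trick: `min_grid()` works out which models actually have to be fit for a tuning grid, and `multi_predict()` recovers the remaining submodels from each of those fits. The sketch below is a minimal illustration of that hand-off, not parsnip's tuning code: the grid matches the one used in the new xgboost tests above, the `multi_predict()` call reuses `fold_1_model` and `pred_data` from the vignette, and the `sub_trees` values are made up for illustration. Since `min_grid()` is documented as internal, the exact shape of its output may change.

```r
library(parsnip)
library(dplyr)

# 15 candidate parameter combinations (3 values of `trees` x 5 of `learn_rate`)...
grid <- expand.grid(trees = 1:3, learn_rate = (1:5) / 5)

# ...but only 5 models need to be fit: one per `learn_rate`, each at the largest
# value of `trees`. The smaller `trees` values end up in the `.submodels` column.
fit_grid <- min_grid(boost_tree() %>% set_engine("xgboost"), grid)
fit_grid

# For a boosted tree that is already fit, the submodel values are what get passed
# to `multi_predict()`. `sub_trees` is an illustrative set of values; `fold_1_model`
# was fit with trees = 100, so anything up to 100 is available.
sub_trees <- c(25, 50, 75)
multi_predict(
  fold_1_model,
  new_data = pred_data %>% dplyr::select(-Attrition),
  trees = sub_trees,
  type = "prob"
)
```

The expectations in `test_boost_tree_xgboost.R` above spell out what `fit_grid` should contain for this grid: five rows with `trees = 3`, one for each `learn_rate`, and a `.submodels` entry of `list(trees = 1:2)` in every row.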