Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Package: parsnip
Version: 0.0.2.9000
Version: 0.0.3
Title: A Common API to Modeling and Analysis Functions
Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc).
Authors@R: c(
Expand Down
16 changes: 16 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ S3method(fit_xy,model_spec)
S3method(has_multi_predict,default)
S3method(has_multi_predict,model_fit)
S3method(has_multi_predict,workflow)
S3method(min_grid,boost_tree)
S3method(min_grid,linear_reg)
S3method(min_grid,logistic_reg)
S3method(min_grid,mars)
S3method(min_grid,multinom_reg)
S3method(min_grid,nearest_neighbor)
S3method(multi_predict,"_C5.0")
S3method(multi_predict,"_earth")
S3method(multi_predict,"_elnet")
Expand Down Expand Up @@ -50,8 +56,11 @@ S3method(print,svm_rbf)
S3method(translate,boost_tree)
S3method(translate,decision_tree)
S3method(translate,default)
S3method(translate,linear_reg)
S3method(translate,logistic_reg)
S3method(translate,mars)
S3method(translate,mlp)
S3method(translate,multinom_reg)
S3method(translate,nearest_neighbor)
S3method(translate,rand_forest)
S3method(translate,surv_reg)
Expand Down Expand Up @@ -104,6 +113,13 @@ export(linear_reg)
export(logistic_reg)
export(make_classes)
export(mars)
export(min_grid)
export(min_grid.boost_tree)
export(min_grid.linear_reg)
export(min_grid.logistic_reg)
export(min_grid.mars)
export(min_grid.multinom_reg)
export(min_grid.nearest_neighbor)
export(mlp)
export(model_printer)
export(multi_predict)
Expand Down
17 changes: 13 additions & 4 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
# parsnip 0.0.2.9000
# parsnip 0.0.3

Unplanned release based on CRAN requirements for Solaris.

## Breaking Changes

* The method that `parsnip` stores the model information has changed. Any custom models from previous versions will need to use the new method for registering models. The methods are detailed in `?get_model_env()` and the [package vignette for adding models](https://tidymodels.github.io/parsnip/articles/articles/Scratch.html).
* The mode need to be declared for models that can be used for more than one mode prior to fitting and/or translation).
* The method that `parsnip` stores the model information has changed. Any custom models from previous versions will need to use the new method for registering models. The methods are detailed in `?get_model_env` and the [package vignette for adding models](https://tidymodels.github.io/parsnip/articles/articles/Scratch.html).

* The mode needs to be declared for models that can be used for more than one mode prior to fitting and/or translation.

* For `surv_reg()`, the engine that uses the `survival` package is now called `survival` instead of `survreg`.

* For `glmnet` models, the full regularization path is always fit regardless of the value given to `penalty`. Previously, the model was fit with passing `penalty` to `glmnet`'s `lambda` argument and the model could only make predictions at those specific values. [(#195)](https://github.com/tidymodels/parsnip/issues/195)

## New Features

* `add_rowindex()` can add a column called `.row` to a data frame.

* If a computational engine is not explicitly set, a default will be used. Each default is documented on the corresponding model page. A warning is issued at fit time unless verbosity is zero.
* `nearest_neighbor` gained a `multi_predict` method. The `multi_predict()` documentation is a little better organized.

* `nearest_neighbor()` gained a `multi_predict` method. The `multi_predict()` documentation is a little better organized.

* A suite of internal functions was added to help with upcoming model tuning features.


# parsnip 0.0.2
Expand Down
94 changes: 93 additions & 1 deletion R/aaa.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,102 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) {
}

# ------------------------------------------------------------------------------
# min_grid generic - put here so that the generic shows up first in the man file

#' Determine the minimum set of model fits
#'
#' `min_grid` determines exactly what models should be fit in order to
#' evaluate the entire set of tuning parameter combinations. This is for
#' internal use only and the API may change in the near future.
#' @param x A model specification.
#' @param grid A tibble with tuning parameter combinations.
#' @param ... Not currently used.
#' @return A tibble with the minimum tuning parameters to fit and an additional
#' list column with the parameter combinations used for prediction.
#' @keywords internal
#' @export
min_grid <- function(x, grid, ...) {
  # S3 generic. Dispatch happens on the class of `x`, which is a `model_spec`
  # object from parsnip. `grid` is a tibble of tuning parameter values whose
  # column names match the parameter names; it is handed to the selected
  # method together with `...`.
  UseMethod("min_grid", x)
}

# As an example, if we fit a boosted tree model and tune over
# trees = 1:20 and min_n = c(20, 30)
# we should only have to fit two models:
#
# trees = 20 & min_n = 20
# trees = 20 & min_n = 30
#
# The logic related to how this "mini grid" gets made is model-specific.
#
# To get the full set of predictions, we need to know, for each of these two
# models, what values of trees to give to the multi_predict() function.
#
# The current idea is to have a list column of the extra models for prediction.
# For the example above:
#
# # A tibble: 2 x 3
# trees min_n .submodels
# <dbl> <dbl> <list>
# 1 20 20 <named list [1]>
# 2 20 30 <named list [1]>
#
# and the .submodels would both be
#
# list(trees = 1:19)
#
# There are a lot of other things to consider in future versions like grids
# where there are multiple columns with the same name (maybe the results of
# a recipe) and so on.

# ------------------------------------------------------------------------------
# helper functions

# Template for model results that do not have the sub-model feature.
# Adds a `.submodels` list column holding one empty list per row of `grid`.
blank_submodels <- function(grid) {
  # seq_len() (rather than 1:nrow(grid)) yields an empty sequence for a
  # zero-row grid, so the new column always has exactly nrow(grid) elements.
  grid %>%
    dplyr::mutate(.submodels = map(seq_len(nrow(grid)), ~ list()))
}

# Names of the parameters that do NOT have the sub-model feature.
#
# `info` is a data frame with a `name` column and a logical `has_submodel`
# column (the shape returned by `get_submodel_info()`). The result is the
# subset of `info$name` where `has_submodel` is FALSE; callers use these
# columns to group the grid.
get_fixed_args <- function(info) {
  # Return the value directly: ending the function on an assignment would
  # make the result invisible.
  info$name[!info$has_submodel]
}

# Build a two-column table (`name`, `has_submodel`) describing every column
# of `grid`.
#
# `spec` is a model specification; its first class and its `engine` element
# are used to look up the registered argument table. `grid` is a data frame
# of tuning parameter values.
get_submodel_info <- function(spec, grid) {
  # The argument table is stored under "<first class of spec>_args".
  args_key <- paste0(class(spec)[1], "_args")
  engine_args <- dplyr::filter(get_from_env(args_key), engine == spec$engine)
  param_info <- dplyr::select(engine_args, name = parsnip, has_submodel)

  # Grid columns that are not model parameters (e.g. from a recipe or other
  # activity) are appended and flagged as having no sub-models.
  grid_names <- names(grid)
  extra_names <- grid_names[!(grid_names %in% param_info$name)]
  if (length(extra_names) > 0) {
    param_info <- dplyr::bind_rows(
      param_info,
      tibble::tibble(name = extra_names, has_submodel = FALSE)
    )
  }

  # Keep only the parameters that actually appear in the grid.
  dplyr::filter(param_info, name %in% grid_names)
}


# ------------------------------------------------------------------------------
# nocov

#' @importFrom utils globalVariables
utils::globalVariables(
c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group',
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
"neighbors")
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
"max_terms", "max_tree", "name", "num_terms", "penalty", "trees",
"sub_neighbors")
)

# nocov end
38 changes: 38 additions & 0 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -514,3 +514,41 @@ C50_by_tree <- function(tree, object, new_data, type, ...) {
pred[, c(".row", "trees", nms)]
}

# ------------------------------------------------------------------------------

#' @export
#' @export min_grid.boost_tree
#' @rdname min_grid
min_grid.boost_tree <- function(x, grid, ...) {
  # `x` is a boosted tree model specification and `grid` a tibble of tuning
  # parameter combinations. The result has the columns of `grid` plus a
  # `.submodels` list column, with one row per model that has to be fit.
  grid_names <- names(grid)
  param_info <- get_submodel_info(x, grid)

  # No ability to do submodels? Finish here:
  if (!any(param_info$has_submodel)) {
    return(blank_submodels(grid))
  }

  fixed_args <- get_fixed_args(param_info)

  # For boosted trees, fit the model with the most trees (conditional on the
  # other parameters) so that you can do predictions on the smaller models.
  fit_only <-
    grid %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(trees = max(trees, na.rm = TRUE)) %>%
    dplyr::ungroup()

  # Add a column .submodels that is a list with what should be predicted
  # by `multi_predict()` (assuming `predict()` has already been executed
  # on the original value of 'trees').
  min_grid_df <-
    dplyr::full_join(
      fit_only %>% dplyr::rename(max_tree = trees),
      grid,
      by = fixed_args
    ) %>%
    dplyr::filter(trees != max_tree) %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(.submodels = list(list(trees = trees))) %>%
    dplyr::ungroup() %>%
    # Join back onto the models to fit; `fit_only` is the only table needed
    # here, so no further positional argument is supplied.
    dplyr::full_join(fit_only, by = fixed_args)

  # Combinations whose only `trees` value is the maximum were dropped by the
  # filter above, so the join leaves NULL in `.submodels` for them. Use an
  # empty list instead, matching what `blank_submodels()` produces.
  no_submodels <- vapply(min_grid_df$.submodels, is.null, logical(1))
  min_grid_df$.submodels[no_submodels] <- list(list())

  min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels)
}

69 changes: 65 additions & 4 deletions R/linear_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")}
#'
#' When using `glmnet` models, there is the option to pass
#' multiple values (or no values) to the `penalty` argument. This
#' can have an effect on the model object results. When using the
#' For `glmnet` models, the full regularization path is always fit regardless
#' of the value given to `penalty`. Also, there is the option to pass
#' multiple values (or no values) to the `penalty` argument. When using the
#' `predict()` method in these cases, the return value depends on
#' the value of `penalty`. When using `predict()`, only a single
#' value of the penalty can be used. When predicting on multiple
Expand Down Expand Up @@ -138,6 +138,23 @@ print.linear_reg <- function(x, ...) {
invisible(x)
}


#' @export
translate.linear_reg <- function(x, engine = x$engine, ...) {
  # Start from the default translation; only glmnet needs extra handling.
  x <- translate.default(x, engine, ...)

  if (engine != "glmnet") {
    return(x)
  }

  # See discussion in https://github.com/tidymodels/parsnip/issues/195
  # Drop `lambda` from the fit call arguments.
  x$method$fit$args$lambda <- NULL
  # Since the `fit` information is gone for the penalty, we need to have an
  # evaluated value for the parameter.
  x$args$penalty <- rlang::eval_tidy(x$args$penalty)

  x
}


# ------------------------------------------------------------------------------

#' @inheritParams update.boost_tree
Expand Down Expand Up @@ -274,6 +291,11 @@ predict._elnet <-
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

# See discussion in https://github.com/tidymodels/parsnip/issues/195
if (is.null(penalty) & !is.null(object$spec$args$penalty)) {
penalty <- object$spec$args$penalty
}

object$spec$args$penalty <- check_penalty(penalty, object, multi)

object$spec <- eval_args(object$spec)
Expand Down Expand Up @@ -314,7 +336,12 @@ multi_predict._elnet <-
object$spec <- eval_args(object$spec)

if (is.null(penalty)) {
penalty <- object$fit$lambda
# See discussion in https://github.com/tidymodels/parsnip/issues/195
if (!is.null(object$spec$args$penalty)) {
penalty <- object$spec$args$penalty
} else {
penalty <- object$fit$lambda
}
}

pred <- predict._elnet(object, new_data = new_data, type = "raw",
Expand All @@ -332,3 +359,37 @@ multi_predict._elnet <-
names(pred) <- NULL
tibble(.pred = pred)
}


# ------------------------------------------------------------------------------

#' @export
#' @export min_grid.linear_reg
#' @rdname min_grid
min_grid.linear_reg <- function(x, grid, ...) {
  # `x` is a linear regression model specification and `grid` a tibble of
  # tuning parameter combinations. The result has the columns of `grid` plus
  # a `.submodels` list column, with one row per model that has to be fit.
  grid_names <- names(grid)
  param_info <- get_submodel_info(x, grid)

  # No ability to do submodels? Finish here:
  if (!any(param_info$has_submodel)) {
    return(blank_submodels(grid))
  }

  fixed_args <- get_fixed_args(param_info)

  # Fit only the largest penalty for each combination of the other
  # parameters; the remaining penalty values are listed in `.submodels`.
  fit_only <-
    grid %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(penalty = max(penalty, na.rm = TRUE)) %>%
    dplyr::ungroup()

  # Collect, per combination of the other parameters, every penalty value
  # other than the one being fit.
  min_grid_df <-
    dplyr::full_join(
      fit_only %>% dplyr::rename(max_penalty = penalty),
      grid,
      by = fixed_args
    ) %>%
    dplyr::filter(penalty != max_penalty) %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(.submodels = list(list(penalty = penalty))) %>%
    dplyr::ungroup() %>%
    # Join back onto the models to fit; `fit_only` is the only table needed
    # here, so no further positional argument is supplied.
    dplyr::full_join(fit_only, by = fixed_args)

  # Combinations whose only `penalty` value is the maximum were dropped by
  # the filter above, so the join leaves NULL in `.submodels` for them. Use
  # an empty list instead, matching what `blank_submodels()` produces.
  no_submodels <- vapply(min_grid_df$.submodels, is.null, logical(1))
  min_grid_df$.submodels[no_submodels] <- list(list())

  min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels)
}
Loading