diff --git a/DESCRIPTION b/DESCRIPTION index e81c4753a..2b5e812cb 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,5 +1,5 @@ Package: parsnip -Version: 0.0.3.9001 +Version: 0.0.3.9002 Title: A Common API to Modeling and Analysis Functions Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc). Authors@R: c( @@ -7,7 +7,7 @@ Authors@R: c( person(given = "Davis", family = "Vaughan", email = "davis@rstudio.com", role = c("aut")), person("RStudio", role = "cph")) Maintainer: Max Kuhn -URL: https://tidymodels.github.io/parsnip +URL: https://tidymodels.github.io/parsnip, https://github.com/tidymodels/parsnip BugReports: https://github.com/tidymodels/parsnip/issues License: GPL-2 Encoding: UTF-8 diff --git a/NEWS.md b/NEWS.md index 2f17164a6..696d1518b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,8 +6,10 @@ * A bug was fixed related to the column names generated by `multi_predict()`. The top-level tibble will always have a column named `.pred` and this list column contains tibbles across sub-models. The column names for these sub-model tibbles will have names consistent with `predict()` (which was previously incorrect). See [43c15db](https://github.com/tidymodels/parsnip/commit/43c15db377ea9ef27483ff209f6bd0e98cb830d2). -# [A bug](https://github.com/tidymodels/parsnip/issues/174) was fixed -standardizing the column names of `nnet` class probability predictions. +* The model `udpate()` methods gained a `parameters` argument for cases when the parameters are contained in a tibble or list. + +# [A bug](https://github.com/tidymodels/parsnip/issues/174) was fixed standardizing the column names of `nnet` class probability predictions. + # parsnip 0.0.3.1 diff --git a/R/boost_tree.R b/R/boost_tree.R index ca970c9cc..291700ef6 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -148,6 +148,10 @@ print.boost_tree <- function(x, ...) { #' @export #' @param object A boosted tree model specification. +#' @param parameters A 1-row tibble or named list with _main_ +#' parameters to update. If the individual arguments are used, +#' these will supersede the values in `parameters`. Also, using +#' engine arguments in this object will result in an error. #' @param ... Not used for `update()`. #' @param fresh A logical for whether the arguments should be #' modified in-place of or replaced wholesale. @@ -157,17 +161,31 @@ print.boost_tree <- function(x, ...) { #' model #' update(model, mtry = 1) #' update(model, mtry = 1, fresh = TRUE) +#' +#' param_values <- tibble::tibble(mtry = 10, tree_depth = 5) +#' +#' model %>% update(param_values) +#' model %>% update(param_values, mtry = 3) +#' +#' param_values$verbose <- 0 +#' # Fails due to engine argument +#' # model %>% update(param_values) #' @method update boost_tree #' @rdname boost_tree #' @export update.boost_tree <- function(object, + parameters = NULL, mtry = NULL, trees = NULL, min_n = NULL, tree_depth = NULL, learn_rate = NULL, loss_reduction = NULL, sample_size = NULL, fresh = FALSE, ...) { update_dot_check(...) + if (!is.null(parameters)) { + parameters <- check_final_param(parameters) + } + args <- list( mtry = enquo(mtry), trees = enquo(trees), @@ -178,6 +196,8 @@ update.boost_tree <- sample_size = enquo(sample_size) ) + args <- update_main_parameters(args, parameters) + # TODO make these blocks into a function and document well if (fresh) { object$args <- args diff --git a/R/decision_tree.R b/R/decision_tree.R index eaf940407..cbcf79ec7 100644 --- a/R/decision_tree.R +++ b/R/decision_tree.R @@ -137,15 +137,22 @@ print.decision_tree <- function(x, ...) { #' @export update.decision_tree <- function(object, + parameters = NULL, cost_complexity = NULL, tree_depth = NULL, min_n = NULL, fresh = FALSE, ...) { update_dot_check(...) + + if (!is.null(parameters)) { + parameters <- check_final_param(parameters) + } args <- list( cost_complexity = enquo(cost_complexity), tree_depth = enquo(tree_depth), min_n = enquo(min_n) ) + args <- update_main_parameters(args, parameters) + if (fresh) { object$args <- args } else { diff --git a/R/linear_reg.R b/R/linear_reg.R index 9d7bd4559..eb053e136 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -169,14 +169,21 @@ translate.linear_reg <- function(x, engine = x$engine, ...) { #' @export update.linear_reg <- function(object, + parameters = NULL, penalty = NULL, mixture = NULL, fresh = FALSE, ...) { update_dot_check(...) + + if (!is.null(parameters)) { + parameters <- check_final_param(parameters) + } args <- list( penalty = enquo(penalty), mixture = enquo(mixture) ) + args <- update_main_parameters(args, parameters) + if (fresh) { object$args <- args } else { diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 4fd6f9659..00dc2bfe6 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -154,14 +154,21 @@ translate.logistic_reg <- translate.linear_reg #' @export update.logistic_reg <- function(object, + parameters = NULL, penalty = NULL, mixture = NULL, fresh = FALSE, ...) { update_dot_check(...) + + if (!is.null(parameters)) { + parameters <- check_final_param(parameters) + } args <- list( penalty = enquo(penalty), mixture = enquo(mixture) ) + args <- update_main_parameters(args, parameters) + if (fresh) { object$args <- args } else { diff --git a/R/mars.R b/R/mars.R index 202ecf822..a98e9b706 100644 --- a/R/mars.R +++ b/R/mars.R @@ -106,16 +106,23 @@ print.mars <- function(x, ...) { #' @export update.mars <- function(object, + parameters = NULL, num_terms = NULL, prod_degree = NULL, prune_method = NULL, fresh = FALSE, ...) { update_dot_check(...) + if (!is.null(parameters)) { + parameters <- check_final_param(parameters) + } + args <- list( num_terms = enquo(num_terms), prod_degree = enquo(prod_degree), prune_method = enquo(prune_method) ) + args <- update_main_parameters(args, parameters) + if (fresh) { object$args <- args } else { diff --git a/R/misc.R b/R/misc.R index 155bb96e3..754bcebc3 100644 --- a/R/misc.R +++ b/R/misc.R @@ -232,3 +232,51 @@ terms_y <- function(x) { y_expr <- att$predvars[[resp_ind + 1]] all.vars(y_expr) } + + +# ------------------------------------------------------------------------------ + +check_final_param <- function(x) { + if (is.null(x)) { + return(invisible(x)) + } + if (!is.list(x) & !tibble::is_tibble(x)) { + rlang::abort("The parameter object should be a list or tibble") + } + if (tibble::is_tibble(x) && nrow(x) > 1) { + rlang::abort("The parameter tibble should have a single row.") + } + if (tibble::is_tibble(x)) { + x <- as.list(x) + } + if (length(names) == 0 || any(names(x) == "")) { + rlang::abort("All values in `parameters` should have a name.") + } + + invisible(x) +} + +update_main_parameters <- function(args, param) { + + if (length(param) == 0) { + return(args) + } + if (length(args) == 0) { + return(param) + } + + # In case an engine argument is included: + has_extra_args <- !(names(param) %in% names(args)) + extra_args <- names(param)[has_extra_args] + if (any(has_extra_args)) { + rlang::abort( + paste("At least one argument is not a main argument:", + paste0("`", extra_args, "`", collapse = ", ")) + ) + } + param <- param[!has_extra_args] + + + + args <- utils::modifyList(args, param) +} diff --git a/R/mlp.R b/R/mlp.R index eb1fb6973..ac7795f54 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -139,10 +139,16 @@ print.mlp <- function(x, ...) { #' @export update.mlp <- function(object, + parameters = NULL, hidden_units = NULL, penalty = NULL, dropout = NULL, epochs = NULL, activation = NULL, fresh = FALSE, ...) { update_dot_check(...) + + if (!is.null(parameters)) { + parameters <- check_final_param(parameters) + } + args <- list( hidden_units = enquo(hidden_units), penalty = enquo(penalty), @@ -151,6 +157,8 @@ update.mlp <- activation = enquo(activation) ) + args <- update_main_parameters(args, parameters) + # TODO make these blocks into a function and document well if (fresh) { object$args <- args diff --git a/R/multinom_reg.R b/R/multinom_reg.R index ea321b723..10175111d 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -137,14 +137,21 @@ translate.multinom_reg <- translate.linear_reg #' @export update.multinom_reg <- function(object, + parameters = NULL, penalty = NULL, mixture = NULL, fresh = FALSE, ...) { update_dot_check(...) + + if (!is.null(parameters)) { + parameters <- check_final_param(parameters) + } args <- list( penalty = enquo(penalty), mixture = enquo(mixture) ) + args <- update_main_parameters(args, parameters) + if (fresh) { object$args <- args } else { diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index f240d0a2c..8882e8ee5 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -106,17 +106,25 @@ print.nearest_neighbor <- function(x, ...) { #' @export #' @inheritParams update.boost_tree update.nearest_neighbor <- function(object, + parameters = NULL, neighbors = NULL, weight_func = NULL, dist_power = NULL, fresh = FALSE, ...) { update_dot_check(...) + + if (!is.null(parameters)) { + parameters <- check_final_param(parameters) + } + args <- list( neighbors = enquo(neighbors), weight_func = enquo(weight_func), dist_power = enquo(dist_power) ) + args <- update_main_parameters(args, parameters) + if (fresh) { object$args <- args } else { diff --git a/R/rand_forest.R b/R/rand_forest.R index 352178e2b..146efe497 100644 --- a/R/rand_forest.R +++ b/R/rand_forest.R @@ -140,15 +140,22 @@ print.rand_forest <- function(x, ...) { #' @export update.rand_forest <- function(object, + parameters = NULL, mtry = NULL, trees = NULL, min_n = NULL, fresh = FALSE, ...) { update_dot_check(...) + + if (!is.null(parameters)) { + parameters <- check_final_param(parameters) + } args <- list( mtry = enquo(mtry), trees = enquo(trees), min_n = enquo(min_n) ) + args <- update_main_parameters(args, parameters) + # TODO make these blocks into a function and document well if (fresh) { object$args <- args diff --git a/R/surv_reg.R b/R/surv_reg.R index b8966b1e6..14ac72714 100644 --- a/R/surv_reg.R +++ b/R/surv_reg.R @@ -111,12 +111,19 @@ print.surv_reg <- function(x, ...) { #' @method update surv_reg #' @rdname surv_reg #' @export -update.surv_reg <- function(object, dist = NULL, fresh = FALSE, ...) { +update.surv_reg <- function(object, parameters = NULL, dist = NULL, fresh = FALSE, ...) { update_dot_check(...) + + if (!is.null(parameters)) { + parameters <- check_final_param(parameters) + } + args <- list( dist = enquo(dist) ) + args <- update_main_parameters(args, parameters) + if (fresh) { object$args <- args } else { diff --git a/R/svm_poly.R b/R/svm_poly.R index 5eb071950..026674182 100644 --- a/R/svm_poly.R +++ b/R/svm_poly.R @@ -106,11 +106,16 @@ print.svm_poly <- function(x, ...) { #' @export update.svm_poly <- function(object, + parameters = NULL, cost = NULL, degree = NULL, scale_factor = NULL, margin = NULL, fresh = FALSE, ...) { update_dot_check(...) + if (!is.null(parameters)) { + parameters <- check_final_param(parameters) + } + args <- list( cost = enquo(cost), degree = enquo(degree), @@ -118,6 +123,8 @@ update.svm_poly <- margin = enquo(margin) ) + args <- update_main_parameters(args, parameters) + if (fresh) { object$args <- args } else { diff --git a/R/svm_rbf.R b/R/svm_rbf.R index 0fe3d39a6..345a2c11b 100644 --- a/R/svm_rbf.R +++ b/R/svm_rbf.R @@ -104,17 +104,24 @@ print.svm_rbf <- function(x, ...) { #' @export update.svm_rbf <- function(object, + parameters = NULL, cost = NULL, rbf_sigma = NULL, margin = NULL, fresh = FALSE, ...) { update_dot_check(...) + if (!is.null(parameters)) { + parameters <- check_final_param(parameters) + } + args <- list( cost = enquo(cost), rbf_sigma = enquo(rbf_sigma), margin = enquo(margin) ) + args <- update_main_parameters(args, parameters) + if (fresh) { object$args <- args } else { diff --git a/docs/dev/articles/articles/Classification.html b/docs/dev/articles/articles/Classification.html index 3f9a3f45f..a99b16458 100644 --- a/docs/dev/articles/articles/Classification.html +++ b/docs/dev/articles/articles/Classification.html @@ -109,12 +109,12 @@

Classification Example

#> Registered S3 method overwritten by 'xts': #> method from #> as.zoo.xts zoo -#> ── Attaching packages ───────────────────────────────────── tidymodels 0.0.2 ── +#> ── Attaching packages ─────────────────────────────────────────────── tidymodels 0.0.2 ── #> ✔ broom 0.5.1 ✔ purrr 0.3.3 #> ✔ dials 0.0.3.9001 ✔ recipes 0.1.7.9001 #> ✔ dplyr 0.8.3 ✔ rsample 0.0.5 #> ✔ infer 0.4.0 ✔ yardstick 0.0.3.9000 -#> ── Conflicts ──────────────────────────────────────── tidymodels_conflicts() ── +#> ── Conflicts ────────────────────────────────────────────────── tidymodels_conflicts() ── #> ✖ purrr::discard() masks scales::discard() #> ✖ dplyr::filter() masks stats::filter() #> ✖ dplyr::lag() masks stats::lag() @@ -185,17 +185,17 @@

Classification Example

#> # A tibble: 1 x 3 #> .metric .estimator .estimate #> <chr> <chr> <dbl> -#> 1 roc_auc binary 0.825 +#> 1 roc_auc binary 0.824 test_results %>% accuracy(truth = Status, nnet_class) #> # A tibble: 1 x 3 #> .metric .estimator .estimate #> <chr> <chr> <dbl> -#> 1 accuracy binary 0.806 +#> 1 accuracy binary 0.801 test_results %>% conf_mat(truth = Status, nnet_class) #> Truth #> Prediction bad good -#> bad 193 96 -#> good 120 704 +#> bad 185 93 +#> good 128 707 +#>
+param_values <- tibble::tibble(mtry = 10, tree_depth = 5) + +model %>% update(param_values)
#> Boosted Tree Model Specification (unknown) +#> +#> Main Arguments: +#> mtry = 10 +#> min_n = 3 +#> tree_depth = 5 +#>
model %>% update(param_values, mtry = 3)
#> Boosted Tree Model Specification (unknown) +#> +#> Main Arguments: +#> mtry = 10 +#> min_n = 3 +#> tree_depth = 5 +#>
+param_values$verbose <- 0 +# Fails due to engine argument +# model %>% update(param_values)