From b9cc6576587f476ba205915934153c67e4344af4 Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 22 Oct 2025 10:41:25 -0400 Subject: [PATCH 01/15] initial definitions for classification and regression --- R/rand_forest.R | 43 +++-- R/rand_forest_data.R | 449 +++++++++++++++++++++++++++++++++---------- 2 files changed, 371 insertions(+), 121 deletions(-) diff --git a/R/rand_forest.R b/R/rand_forest.R index f15b5e906..c810a07b0 100644 --- a/R/rand_forest.R +++ b/R/rand_forest.R @@ -34,12 +34,17 @@ #' @export rand_forest <- - function(mode = "unknown", engine = "ranger", mtry = NULL, trees = NULL, min_n = NULL) { - + function( + mode = "unknown", + engine = "ranger", + mtry = NULL, + trees = NULL, + min_n = NULL + ) { args <- list( - mtry = enquo(mtry), - trees = enquo(trees), - min_n = enquo(min_n) + mtry = enquo(mtry), + trees = enquo(trees), + min_n = enquo(min_n) ) new_model_spec( @@ -60,15 +65,19 @@ rand_forest <- #' @rdname parsnip_update #' @export update.rand_forest <- - function(object, - parameters = NULL, - mtry = NULL, trees = NULL, min_n = NULL, - fresh = FALSE, ...) { - + function( + object, + parameters = NULL, + mtry = NULL, + trees = NULL, + min_n = NULL, + fresh = FALSE, + ... + ) { args <- list( - mtry = enquo(mtry), - trees = enquo(trees), - min_n = enquo(min_n) + mtry = enquo(mtry), + trees = enquo(trees), + min_n = enquo(min_n) ) update_spec( @@ -109,8 +118,10 @@ translate.rand_forest <- function(x, engine = x$engine, ...) { # See "Details" in ?ml_random_forest_classifier. `feature_subset_strategy` # should be character even if it contains a number. - if (any(names(arg_vals) == "feature_subset_strategy") && - isTRUE(is.numeric(quo_get_expr(arg_vals$feature_subset_strategy)))) { + if ( + any(names(arg_vals) == "feature_subset_strategy") && + isTRUE(is.numeric(quo_get_expr(arg_vals$feature_subset_strategy))) + ) { arg_vals$feature_subset_strategy <- paste(quo_get_expr(arg_vals$feature_subset_strategy)) } @@ -118,7 +129,6 @@ translate.rand_forest <- function(x, engine = x$engine, ...) { # add checks to error trap or change things for this method if (engine == "ranger") { - if (any(names(arg_vals) == "importance")) { if (isTRUE(is.logical(quo_get_expr(arg_vals$importance)))) { cli::cli_abort( @@ -170,4 +180,3 @@ check_args.rand_forest <- function(object, call = rlang::caller_env()) { # move translate checks here? invisible(object) } - diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index 598bceaa9..76894bdb7 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -1,8 +1,12 @@ # wrappers for ranger ranger_class_pred <- - function(results, object) { + function(results, object) { if (results$treetype == "Probability estimation") { - res <- colnames(results$predictions)[apply(results$predictions, 1, which.max)] + res <- colnames(results$predictions)[apply( + results$predictions, + 1, + which.max + )] } else { res <- results$predictions } @@ -10,27 +14,38 @@ ranger_class_pred <- } ranger_num_confint <- function(object, new_data, ...) { - hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level) / 2 const <- qnorm(hf_lvl, lower.tail = FALSE) res <- tibble( - .pred = predict(object$fit, data = new_data, type = "response", ...)$predictions + .pred = predict( + object$fit, + data = new_data, + type = "response", + ... 
+    )$predictions
   )
   std_error <- predict(object$fit, data = new_data, type = "se", ...)$se
   res$.pred_lower <- res$.pred - const * std_error
   res$.pred_upper <- res$.pred + const * std_error
   res$.pred <- NULL

-  if (object$spec$method$pred$conf_int$extras$std_error)
+  if (object$spec$method$pred$conf_int$extras$std_error) {
     res$.std_error <- std_error
+  }
   res
 }

 ranger_class_confint <- function(object, new_data, ...) {
-  hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2
+  hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level) / 2
   const <- qnorm(hf_lvl, lower.tail = FALSE)
-  pred <- predict(object$fit, data = new_data, type = "response", ...)$predictions
+  pred <- predict(
+    object$fit,
+    data = new_data,
+    type = "response",
+    ...
+  )$predictions
   pred <- as_tibble(pred)
   std_error <- predict(object$fit, data = new_data, type = "se", ...)$se
@@ -51,8 +66,9 @@ ranger_class_confint <- function(object, new_data, ...) {
   col_names <- paste0(c(".pred_lower_", ".pred_upper_"), lvl)
   res <- res[, col_names]

-  if (object$spec$method$pred$conf_int$extras$std_error)
+  if (object$spec$method$pred$conf_int$extras$std_error) {
     res <- bind_cols(res, std_error)
+  }

   res
 }
@@ -76,11 +92,79 @@ ranger_confint <- function(object, new_data, ...) {

 # ------------------------------------------------------------------------------

+grf_prob_convert <- function(x, object) {
+  lvls <- levels(object$fit$Y.orig)
+  x <- x$predictions
+  colnames(x) <- lvls
+  tibble::as_tibble(x, .name_repair = "minimal")
+}
+
+grf_cls_convert <- function(x, object) {
+  res <- grf_prob_convert(x, object)
+  cls_ind <- apply(res, 1, which.max)
+  lvls <- levels(object$fit$Y.orig)
+  factor(lvls[cls_ind], levels = lvls)
+}
+
+grf_conf_int <- function(object, new_data) {
+  raw_pred <- predict(object$fit, new_data, estimate.variance = TRUE)
+
+  hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level) / 2
+
+  std_err <- sqrt(raw_pred$variance.estimates)
+
+  # z-value for a two-sided interval; use the upper tail (as in the ranger
+  # helpers above) so that `const` is positive and the bounds are ordered
+  const <- stats::qnorm(hf_lvl, lower.tail = FALSE)
+
+  if (inherits(object$fit, "regression_forest")) {
+    res <-
+      tibble(
+        .pred_lower = raw_pred$predictions - const * std_err,
+        .pred_upper = raw_pred$predictions + const * std_err
+      )
+
+    if (object$spec$method$pred$conf_int$extras$std_error) {
+      res$.std_error <- std_err
+    }
+  } else if (inherits(object$fit, "probability_forest")) {
+    lowers <- raw_pred$predictions - const * std_err
+    uppers <- raw_pred$predictions + const * std_err
+
+    lowers <- tibble::as_tibble(lowers, .name_repair = "minimal")
+    uppers <- tibble::as_tibble(uppers, .name_repair = "minimal")
+
+    names(lowers) <- paste0(".pred_lower_", names(lowers))
+    names(uppers) <- paste0(".pred_upper_", names(uppers))
+
+    res <- vctrs::vec_cbind(lowers, uppers)
+
+    if (object$spec$method$pred$conf_int$extras$std_error) {
+      std_err <- tibble::as_tibble(std_err, .name_repair = "minimal")
+      names(std_err) <- paste0(".std_error_", names(std_err))
+      res <- vctrs::vec_cbind(res, std_err)
+    }
+  } else {
+    cli::cli_abort(
+      "No confidence interval implementation for objects with class
+      {.cls {class(object$fit)[1]}}."
+    )
+  }
+
+  res
+}
+
+# ------------------------------------------------------------------------------
+
 set_new_model("rand_forest")

 set_model_mode("rand_forest", "classification")
 set_model_mode("rand_forest", "regression")
 set_model_mode("rand_forest", "censored regression")
+set_model_mode("rand_forest", "quantile regression")

 # ------------------------------------------------------------------------------
 # ranger components

@@
-124,12 +208,11 @@ set_fit( data = c(x = "x", y = "y", weights = "case.weights"), protect = c("x", "y", "weights"), func = c(pkg = "ranger", fun = "ranger"), - defaults = - list( - num.threads = 1, - verbose = FALSE, - seed = expr(sample.int(10 ^ 5, 1)) - ) + defaults = list( + num.threads = 1, + verbose = FALSE, + seed = expr(sample.int(10^5, 1)) + ) ) ) @@ -154,12 +237,11 @@ set_fit( data = c(x = "x", y = "y", weights = "case.weights"), protect = c("x", "y", "weights"), func = c(pkg = "ranger", fun = "ranger"), - defaults = - list( - num.threads = 1, - verbose = FALSE, - seed = expr(sample.int(10 ^ 5, 1)) - ) + defaults = list( + num.threads = 1, + verbose = FALSE, + seed = expr(sample.int(10^5, 1)) + ) ) ) @@ -184,14 +266,13 @@ set_pred( pre = NULL, post = ranger_class_pred, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - type = "response", - seed = expr(sample.int(10 ^ 5, 1)), - verbose = FALSE - ) + args = list( + object = quote(object$fit), + data = quote(new_data), + type = "response", + seed = expr(sample.int(10^5, 1)), + verbose = FALSE + ) ) ) @@ -202,7 +283,7 @@ set_pred( type = "prob", value = list( pre = function(x, object) { - if (object$fit$forest$treetype != "Probability estimation") + if (object$fit$forest$treetype != "Probability estimation") { cli::cli_abort( c( "`ranger` model does not appear to use class probabilities.", @@ -210,6 +291,7 @@ set_pred( ), call = call2("predict") ) + } x }, post = function(x, object) { @@ -217,13 +299,12 @@ set_pred( as_tibble(x) }, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - seed = expr(sample.int(10 ^ 5, 1)), - verbose = FALSE - ) + args = list( + object = quote(object$fit), + data = quote(new_data), + seed = expr(sample.int(10^5, 1)), + verbose = FALSE + ) ) ) @@ -236,12 +317,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "ranger_confint"), - args = - list( - object = quote(object), - new_data = quote(new_data), - seed = expr(sample.int(10^5, 1)) - ) + args = list( + object = quote(object), + new_data = quote(new_data), + seed = expr(sample.int(10^5, 1)) + ) ) ) @@ -254,12 +334,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - seed = expr(sample.int(10 ^ 5, 1)) - ) + args = list( + object = quote(object$fit), + data = quote(new_data), + seed = expr(sample.int(10^5, 1)) + ) ) ) @@ -270,21 +349,20 @@ set_pred( type = "numeric", value = list( pre = NULL, - post = function(results, object) - results$predictions, + post = function(results, object) { + results$predictions + }, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - type = "response", - seed = expr(sample.int(10 ^ 5, 1)), - verbose = FALSE - ) + args = list( + object = quote(object$fit), + data = quote(new_data), + type = "response", + seed = expr(sample.int(10^5, 1)), + verbose = FALSE + ) ) ) - set_pred( model = "rand_forest", eng = "ranger", @@ -294,12 +372,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "ranger_confint"), - args = - list( - object = quote(object), - new_data = quote(new_data), - seed = expr(sample.int(10^5, 1)) - ) + args = list( + object = quote(object), + new_data = quote(new_data), + seed = expr(sample.int(10^5, 1)) + ) ) ) set_pred( @@ -311,12 +388,11 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - 
seed = expr(sample.int(10 ^ 5, 1)) - ) + args = list( + object = quote(object$fit), + data = quote(new_data), + seed = expr(sample.int(10^5, 1)) + ) ) ) @@ -324,9 +400,19 @@ set_pred( # randomForest components set_model_engine("rand_forest", "classification", "randomForest") -set_model_engine("rand_forest", "regression", "randomForest") -set_dependency("rand_forest", "randomForest", "randomForest", mode = "regression") -set_dependency("rand_forest", "randomForest", "randomForest", mode = "classification") +set_model_engine("rand_forest", "regression", "randomForest") +set_dependency( + "rand_forest", + "randomForest", + "randomForest", + mode = "regression" +) +set_dependency( + "rand_forest", + "randomForest", + "randomForest", + mode = "classification" +) set_model_arg( model = "rand_forest", @@ -361,8 +447,7 @@ set_fit( interface = "data.frame", protect = c("x", "y"), func = c(pkg = "randomForest", fun = "randomForest"), - defaults = - list() + defaults = list() ) ) @@ -386,8 +471,7 @@ set_fit( interface = "data.frame", protect = c("x", "y"), func = c(pkg = "randomForest", fun = "randomForest"), - defaults = - list() + defaults = list() ) ) @@ -412,9 +496,7 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list(object = quote(object$fit), - newdata = quote(new_data)) + args = list(object = quote(object$fit), newdata = quote(new_data)) ) ) @@ -427,9 +509,7 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list(object = quote(object$fit), - newdata = quote(new_data)) + args = list(object = quote(object$fit), newdata = quote(new_data)) ) ) @@ -458,12 +538,11 @@ set_pred( as_tibble(as.data.frame(x)) }, func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "prob" - ) + args = list( + object = quote(object$fit), + newdata = quote(new_data), + type = "prob" + ) ) ) @@ -476,9 +555,7 @@ set_pred( pre = NULL, post = NULL, func = c(fun = "predict"), - args = - list(object = quote(object$fit), - newdata = quote(new_data)) + args = list(object = quote(object$fit), newdata = quote(new_data)) ) ) @@ -523,7 +600,7 @@ set_fit( data = c(formula = "formula", data = "x"), protect = c("x", "formula", "type"), func = c(pkg = "sparklyr", fun = "ml_random_forest"), - defaults = list(seed = expr(sample.int(10 ^ 5, 1))) + defaults = list(seed = expr(sample.int(10^5, 1))) ) ) @@ -548,7 +625,7 @@ set_fit( data = c(formula = "formula", data = "x"), protect = c("x", "formula", "type"), func = c(pkg = "sparklyr", fun = "ml_random_forest"), - defaults = list(seed = expr(sample.int(10 ^ 5, 1))) + defaults = list(seed = expr(sample.int(10^5, 1))) ) ) @@ -573,9 +650,7 @@ set_pred( pre = NULL, post = format_spark_num, func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list(x = quote(object$fit), - dataset = quote(new_data)) + args = list(x = quote(object$fit), dataset = quote(new_data)) ) ) @@ -588,9 +663,7 @@ set_pred( pre = NULL, post = format_spark_class, func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list(x = quote(object$fit), - dataset = quote(new_data)) + args = list(x = quote(object$fit), dataset = quote(new_data)) ) ) @@ -603,8 +676,176 @@ set_pred( pre = NULL, post = format_spark_probs, func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list(x = quote(object$fit), - dataset = quote(new_data)) + args = list(x = quote(object$fit), dataset = quote(new_data)) + ) +) + + +# ------------------------------------------------------------------------------ +# grf components + 
+set_model_engine("rand_forest", mode = "classification", eng = "grf") +set_model_engine("rand_forest", mode = "regression", eng = "grf") +set_model_engine("rand_forest", mode = "quantile regression", eng = "grf") +set_dependency("rand_forest", "grf", "grf", mode = "classification") +set_dependency("rand_forest", "grf", "grf", mode = "regression") +set_dependency("rand_forest", "grf", "grf", mode = "quantile regression") + +set_model_arg( + model = "rand_forest", + eng = "grf", + parsnip = "mtry", + original = "mtry", + func = list(pkg = "dials", fun = "mtry"), + has_submodel = FALSE +) +set_model_arg( + model = "rand_forest", + eng = "grf", + parsnip = "trees", + original = "num.trees", + func = list(pkg = "dials", fun = "trees"), + has_submodel = FALSE +) +set_model_arg( + model = "rand_forest", + eng = "grf", + parsnip = "min_n", + original = "min.node.size", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) + +set_fit( + model = "rand_forest", + eng = "grf", + mode = "classification", + value = list( + interface = "data.frame", + data = c(x = "X", y = "Y", weights = "sample.weights"), + protect = c("x", "y", "weights"), + func = c(pkg = "grf", fun = "probability_forest"), + defaults = list( + num.threads = 1 + ) + ) +) + +set_encoding( + model = "rand_forest", + eng = "grf", + mode = "classification", + options = list( + predictor_indicators = "one_hot", + compute_intercept = FALSE, + remove_intercept = TRUE, + allow_sparse_x = FALSE + ) +) + +set_fit( + model = "rand_forest", + eng = "grf", + mode = "regression", + value = list( + interface = "data.frame", + data = c(x = "X", y = "Y", weights = "case.weights"), + protect = c("x", "y", "weights"), + func = c(pkg = "grf", fun = "regression_forest"), + defaults = list( + num.threads = 1 + ) + ) +) + +set_encoding( + model = "rand_forest", + eng = "grf", + mode = "regression", + options = list( + predictor_indicators = "one_hot", + compute_intercept = FALSE, + remove_intercept = TRUE, + allow_sparse_x = FALSE + ) +) + +set_pred( + model = "rand_forest", + eng = "grf", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = grf_cls_convert, + func = c(fun = "predict"), + args = list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) + +set_pred( + model = "rand_forest", + eng = "grf", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = grf_prob_convert, + func = c(fun = "predict"), + args = list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) + +set_pred( + model = "rand_forest", + eng = "grf", + mode = "classification", + type = "conf_int", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "grf_conf_int"), + args = list( + object = quote(object), + new_data = quote(new_data) + ) + ) +) + +set_pred( + model = "rand_forest", + eng = "grf", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = function(results, object) results$predictions, + func = c(fun = "predict"), + args = list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) + +set_pred( + model = "rand_forest", + eng = "grf", + mode = "regression", + type = "conf_int", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "grf_conf_int"), + args = list( + object = quote(object), + new_data = quote(new_data) + ) ) ) From bd92eedafa21d4c7409ddd2fbac1782b833aad26 Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 22 Oct 2025 10:50:19 -0400 Subject: [PATCH 02/15] enable quantile regression --- 
R/rand_forest_data.R | 49 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index 76894bdb7..6d90081a1 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -157,6 +157,10 @@ grf_conf_int <- function( res } +qrf_quantile_convert <- function(x, object) { + matrix_to_quantile_pred(x$predictions, object) +} + # ------------------------------------------------------------------------------ set_new_model("rand_forest") @@ -849,3 +853,48 @@ set_pred( ) ) ) + +set_fit( + model = "rand_forest", + eng = "grf", + mode = "quantile regression", + value = list( + interface = "data.frame", + data = c(x = "X", y = "Y", weights = "case.weights"), + protect = c("x", "y", "weights"), + func = c(pkg = "grf", fun = "quantile_forest"), + defaults = list( + num.threads = 1, + quantiles = quote(quantile_levels) + ) + ) +) + +set_encoding( + model = "rand_forest", + eng = "grf", + mode = "quantile regression", + options = list( + predictor_indicators = "one_hot", + compute_intercept = FALSE, + remove_intercept = TRUE, + allow_sparse_x = FALSE + ) +) + +set_pred( + model = "rand_forest", + eng = "grf", + mode = "quantile regression", + type = "quantile", + value = list( + pre = NULL, + post = qrf_quantile_convert, + func = c(fun = "predict"), + args = list( + object = expr(object$fit), + newdata = expr(new_data), + quantiles = NULL + ) + ) +) From 791601d8a8c237f11694d736bea939ffa0da44d7 Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 22 Oct 2025 13:19:02 -0400 Subject: [PATCH 03/15] documentation --- NEWS.md | 2 + R/aaa_archive.R | 8 +- R/rand_forest_grf.R | 13 +++ man/augment.Rd | 19 ++++ man/details_rand_forest_grf.Rd | 168 ++++++++++++++++++++++++++++++++ man/rmd/rand_forest_grf.Rmd | 105 ++++++++++++++++++++ man/rmd/rand_forest_grf.md | 146 +++++++++++++++++++++++++++ vignettes/articles/Examples.Rmd | 118 ++++++++++++++++++++++ 8 files changed, 577 insertions(+), 2 deletions(-) create mode 100644 R/rand_forest_grf.R create mode 100644 man/details_rand_forest_grf.Rd create mode 100644 man/rmd/rand_forest_grf.Rmd create mode 100644 man/rmd/rand_forest_grf.md diff --git a/NEWS.md b/NEWS.md index be6c445f2..bb9f9fc59 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # parsnip (development version) +* Enable generalized random forest (`grf`) models for classification, regression, and quantile regression modes. (#1288) + # parsnip 1.3.3 * Bug fix in how tunable parameters were configured for brulee neural networks. 
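A minimal sketch of what this NEWS entry enables (assuming the grf package is installed; it mirrors the registrations added in `R/rand_forest_data.R` above):

```r
library(parsnip)

# regression with the new "grf" engine; the intervals come from grf's
# variance estimates
grf_reg <-
  rand_forest(trees = 500) |>
  set_engine("grf") |>
  set_mode("regression") |>
  fit(mpg ~ ., data = mtcars)

predict(grf_reg, mtcars[1:3, ])
predict(grf_reg, mtcars[1:3, ], type = "conf_int")

# quantile regression: the quantile levels must be declared with the mode
grf_qr <-
  rand_forest() |>
  set_engine("grf") |>
  set_mode("quantile regression", quantile_levels = c(0.1, 0.5, 0.9)) |>
  fit(mpg ~ ., data = mtcars)

predict(grf_qr, mtcars[1:3, ])  # returns a `.pred_quantile` column
```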
diff --git a/R/aaa_archive.R b/R/aaa_archive.R index 01e0397bf..81c6f8611 100644 --- a/R/aaa_archive.R +++ b/R/aaa_archive.R @@ -1,4 +1,4 @@ -# no fmt +# fmt: skip model_info_table <- tibble::tribble( ~model, ~mode, ~engine, ~pkg, @@ -21,6 +21,7 @@ model_info_table <- "bag_tree", "classification", "rpart", "baguette", "bart", "classification", "dbarts", NA, "boost_tree", "classification", "C5.0", NA, + "boost_tree", "classification", "catboost", "bonsai", "boost_tree", "classification", "h2o", "agua", "boost_tree", "classification", "h2o_gbm", "agua", "boost_tree", "classification", "lightgbm", "bonsai", @@ -69,6 +70,7 @@ model_info_table <- "null_model", "classification", "parsnip", NA, "pls", "classification", "mixOmics", "plsmod", "rand_forest", "classification", "aorsf", "bonsai", + "rand_forest", "classification", "grf", NA, "rand_forest", "classification", "h2o", "agua", "rand_forest", "classification", "partykit", "bonsai", "rand_forest", "classification", "randomForest", NA, @@ -82,11 +84,13 @@ model_info_table <- "svm_rbf", "classification", "kernlab", NA, "svm_rbf", "classification", "liquidSVM", NA, "linear_reg", "quantile regression", "quantreg", NA, + "rand_forest", "quantile regression", "grf", NA, "auto_ml", "regression", "h2o", "agua", "bag_mars", "regression", "earth", "baguette", "bag_mlp", "regression", "nnet", "baguette", "bag_tree", "regression", "rpart", "baguette", "bart", "regression", "dbarts", NA, + "boost_tree", "regression", "catboost", "bonsai", "boost_tree", "regression", "h2o", "agua", "boost_tree", "regression", "h2o_gbm", "agua", "boost_tree", "regression", "lightgbm", "bonsai", @@ -130,6 +134,7 @@ model_info_table <- "poisson_reg", "regression", "stan_glmer", "multilevelmod", "poisson_reg", "regression", "zeroinfl", "poissonreg", "rand_forest", "regression", "aorsf", "bonsai", + "rand_forest", "regression", "grf", NA, "rand_forest", "regression", "h2o", "agua", "rand_forest", "regression", "partykit", "bonsai", "rand_forest", "regression", "randomForest", NA, @@ -145,4 +150,3 @@ model_info_table <- "svm_rbf", "regression", "kernlab", NA, "svm_rbf", "regression", "liquidSVM", NA ) - diff --git a/R/rand_forest_grf.R b/R/rand_forest_grf.R new file mode 100644 index 000000000..f7b14388e --- /dev/null +++ b/R/rand_forest_grf.R @@ -0,0 +1,13 @@ +#' Random forests via grf +#' +#' The \pkg{grf} fits models that create a large number of decision +#' trees, each independent of the others. The final prediction uses all +#' predictions from the individual trees and combines them. +#' +#' @includeRmd man/rmd/rand_forest_grf.md details +#' +#' @name details_rand_forest_grf +#' @keywords internal +NULL + +# See inst/README-DOCS.md for a description of how these files are processed diff --git a/man/augment.Rd b/man/augment.Rd index 90d777159..821c803e2 100644 --- a/man/augment.Rd +++ b/man/augment.Rd @@ -47,6 +47,13 @@ probability of censoring weights (IPCW) are also created (see \code{tidymodels.o page in the references below). This enables the user to compute performance metrics in the \pkg{yardstick} package. } + +\subsection{Quantile Regression}{ + +For quantile regression models, a \code{.pred_quantile} column is added that +contains the quantile predictions for each row. 
This column has a special
+class \code{"quantile_pred"} and can be unnested using \code{\link[tidyr:unnest]{tidyr::unnest()}}.
+}
 }
 \examples{
 \dontshow{if (!parsnip:::is_cran_check() & rlang::is_installed("modeldata")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf}
@@ -89,6 +96,18 @@ augment(cls_form, cls_tst[, -3])

 augment(cls_xy, cls_tst)
 augment(cls_xy, cls_tst[, -3])
+
+# ------------------------------------------------------------------------------
+
+# Quantile regression example
+qr_form <-
+  linear_reg() |>
+  set_engine("quantreg") |>
+  set_mode("quantile regression", quantile_levels = c(0.25, 0.5, 0.75)) |>
+  fit(mpg ~ ., data = car_trn)
+
+augment(qr_form, car_tst)
+augment(qr_form, car_tst[, -1])
 \dontshow{\}) # examplesIf}
 }
 \references{
diff --git a/man/details_rand_forest_grf.Rd b/man/details_rand_forest_grf.Rd
new file mode 100644
index 000000000..081cd05c8
--- /dev/null
+++ b/man/details_rand_forest_grf.Rd
@@ -0,0 +1,168 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/rand_forest_grf.R
+\name{details_rand_forest_grf}
+\alias{details_rand_forest_grf}
+\title{Random forests via grf}
+\description{
+The \pkg{grf} package fits models that create a large number of decision
+trees, each independent of the others. The final prediction uses all
+predictions from the individual trees and combines them.
+}
+\details{
+For this engine, there are multiple modes: classification, regression,
+and quantile regression
+\subsection{Tuning Parameters}{
+
+This model has 3 tuning parameters:
+\itemize{
+\item \code{mtry}: # Randomly Selected Predictors (type: integer, default: see
+below)
+\item \code{trees}: # Trees (type: integer, default: 2000L)
+\item \code{min_n}: Minimal Node Size (type: integer, default: 5L)
+}
+
+\code{mtry} depends on the number of columns. If there are \code{p} predictors,
+the default value of \code{mtry} is \code{min(ceiling(sqrt(p) + 20), p)}.
+}
+
+\subsection{Translation from parsnip to the original package (regression)}{
+
+See
+\href{https://grf-labs.github.io/grf/reference/regression_forest.html}{\code{?regression_forest}}
+
+\if{html}{\out{
<div class="sourceCode r">}}\preformatted{rand_forest(
+  mtry = integer(1),
+  trees = integer(1),
+  min_n = integer(1)
+) |>
+  set_engine("grf") |>
+  set_mode("regression") |>
+  translate()
+}\if{html}{\out{
</div>}}
+
+\if{html}{\out{
<div class="sourceCode">}}\preformatted{## Random Forest Model Specification (regression)
+## 
+## Main Arguments:
+##   mtry = integer(1)
+##   trees = integer(1)
+##   min_n = integer(1)
+## 
+## Computational engine: grf 
+## 
+## Model fit template:
+## grf::regression_forest(x = missing_arg(), y = missing_arg(), 
+##     weights = missing_arg(), mtry = min_cols(~integer(1), x), 
+##     num.trees = integer(1), min.node.size = min_rows(~integer(1), 
+##         x), num.threads = 1)
+}\if{html}{\out{
</div>}}
+}
+
+\subsection{Translation from parsnip to the original package (classification)}{
+
+See
+\href{https://grf-labs.github.io/grf/reference/probability_forest.html}{\code{?probability_forest}}
+
+\if{html}{\out{
<div class="sourceCode r">}}\preformatted{rand_forest(
+  mtry = integer(1),
+  trees = integer(1),
+  min_n = integer(1)
+) |>
+  set_engine("grf") |>
+  set_mode("classification") |>
+  translate()
+}\if{html}{\out{
</div>}}
+
+\if{html}{\out{
<div class="sourceCode">}}\preformatted{## Random Forest Model Specification (classification)
+## 
+## Main Arguments:
+##   mtry = integer(1)
+##   trees = integer(1)
+##   min_n = integer(1)
+## 
+## Computational engine: grf 
+## 
+## Model fit template:
+## grf::probability_forest(x = missing_arg(), y = missing_arg(), 
+##     weights = missing_arg(), mtry = min_cols(~integer(1), x), 
+##     num.trees = integer(1), min.node.size = min_rows(~integer(1), 
+##         x), num.threads = 1)
+}\if{html}{\out{
</div>}}
+}
+
+\subsection{Translation from parsnip to the original package (quantile regression)}{
+
+See
+\href{https://grf-labs.github.io/grf/reference/quantile_forest.html}{\code{?quantile_forest}}
+
+When specifying \emph{any} quantile regression model, the user must specify
+the quantile levels \emph{a priori}.
+
+\if{html}{\out{
<div class="sourceCode r">}}\preformatted{rand_forest(
+  mtry = integer(1),
+  trees = integer(1),
+  min_n = integer(1)
+) |>
+  set_engine("grf") |>
+  set_mode("quantile regression", quantile_levels = (1:3) / 4) |>
+  translate()
+}\if{html}{\out{
</div>}}
+
+\if{html}{\out{
<div class="sourceCode">}}\preformatted{## Random Forest Model Specification (quantile regression)
+## 
+## Main Arguments:
+##   mtry = integer(1)
+##   trees = integer(1)
+##   min_n = integer(1)
+## 
+## Computational engine: grf 
+## 
+## Model fit template:
+## grf::quantile_forest(x = missing_arg(), y = missing_arg(), weights = missing_arg(), 
+##     mtry = min_cols(~integer(1), x), num.trees = integer(1), 
+##     min.node.size = min_rows(~integer(1), x), num.threads = 1, 
+##     quantiles = quantile_levels)
+
+## Quantile levels: 0.25, 0.5, and 0.75.
+}\if{html}{\out{
</div>}}
+}
+
+\subsection{Preprocessing requirements}{
+
+This method \emph{does} require qualitative predictors to be converted to a
+numeric format (manually). When using parsnip, a one-hot encoding is
+automatically used to do this.
+}
+
+\subsection{Other notes}{
+
+By default, parallel processing is turned off. When tuning, it is more
+efficient to parallelize over the resamples and tuning parameters. To
+parallelize the construction of the trees within the \code{grf} model, change
+the \code{num.threads} argument via \code{\link[=set_engine]{set_engine()}}.
+
+For \code{grf} confidence intervals, the intervals are constructed using the
+form \verb{estimate +/- z * std_error}. For classification probabilities,
+these values can fall outside of \verb{[0, 1]} and will be coerced to be in
+this range.
+}
+
+\subsection{Case weights}{
+
+The regression and classification models enable the use of case weights.
+The quantile regression mode does not.
+}
+
+\subsection{Examples}{
+
+The “Fitting and Predicting with parsnip” article contains
+\href{https://parsnip.tidymodels.org/articles/articles/Examples.html#rand-forest-grf}{examples}
+for \code{rand_forest()} with the \code{"grf"} engine.
+}
+
+\subsection{References}{
+
+Athey, Susan, Julie Tibshirani, and Stefan Wager. “Generalized Random
+Forests”. \emph{Annals of Statistics}, 47(2), 2019.
+}
+}
+\keyword{internal}
diff --git a/man/rmd/rand_forest_grf.Rmd b/man/rmd/rand_forest_grf.Rmd
new file mode 100644
index 000000000..05f26ffe4
--- /dev/null
+++ b/man/rmd/rand_forest_grf.Rmd
@@ -0,0 +1,105 @@
+```{r}
+#| child: aaa.Rmd
+#| include: false
+```
+
+`r descr_models("rand_forest", "grf")`
+
+## Tuning Parameters
+
+```{r}
+#| label: grf-param-info
+#| echo: false
+defaults <- 
+  tibble::tibble(parsnip = c("mtry", "trees", "min_n"),
+                 default = c("see below", "2000L", "5L"))
+
+param <-
+  rand_forest() |> 
+  set_engine("grf") |> 
+  make_parameter_list(defaults)
+```
+
+This model has `r nrow(param)` tuning parameters:
+
+```{r}
+#| label: grf-param-list
+#| echo: false
+#| results: asis
+param$item
+```
+
+`mtry` depends on the number of columns. If there are `p` predictors, the default value of `mtry` is `min(ceiling(sqrt(p) + 20), p)`.
+
+## Translation from parsnip to the original package (regression)
+
+See [`?regression_forest`](https://grf-labs.github.io/grf/reference/regression_forest.html)
+
+```{r}
+#| label: grf-reg
+rand_forest(
+  mtry = integer(1),
+  trees = integer(1),
+  min_n = integer(1)
+) |>
+  set_engine("grf") |>
+  set_mode("regression") |>
+  translate()
+```
+
+## Translation from parsnip to the original package (classification)
+
+See [`?probability_forest`](https://grf-labs.github.io/grf/reference/probability_forest.html)
+
+```{r}
+#| label: grf-cls
+rand_forest(
+  mtry = integer(1),
+  trees = integer(1),
+  min_n = integer(1)
+) |>
+  set_engine("grf") |>
+  set_mode("classification") |>
+  translate()
+```
+
+## Translation from parsnip to the original package (quantile regression)
+
+See [`?quantile_forest`](https://grf-labs.github.io/grf/reference/quantile_forest.html)

When specifying _any_ quantile regression model, the user must specify the quantile levels _a priori_.
+ +```{r} +#| label: grf-quant +rand_forest( + mtry = integer(1), + trees = integer(1), + min_n = integer(1) +) |> + set_engine("grf") |> + set_mode("quantile regression", quantile_levels = (1:3) / 4) |> + translate() +``` + +## Preprocessing requirements + +This method _does_ require qualitative predictors to be converted to a numeric format (manually). When using parsnip, a one-hot encoding is automatically used to do this. + +## Other notes + +By default, parallel processing is turned off. When tuning, it is more efficient to parallelize over the resamples and tuning parameters. To parallelize the construction of the trees within the `grf` model, change the `num.threads` argument via [set_engine()]. + +For `grf` confidence intervals, the intervals are constructed using the form `estimate +/- z * std_error`. For classification probabilities, these values can fall outside of `[0, 1]` and will be coerced to be in this range. + +## Case weights + +The regression and classification models enable the use of case weights. The quantile regression mode does not. + +## Examples + +The "Fitting and Predicting with parsnip" article contains [examples](https://parsnip.tidymodels.org/articles/articles/Examples.html#rand-forest-grf) for `rand_forest()` with the `"grf"` engine. + +## References + +Athey, Susan, Julie Tibshirani, and Stefan Wager. "Generalized Random Forests". _Annals of Statistics_, 47(2), 2019. + diff --git a/man/rmd/rand_forest_grf.md b/man/rmd/rand_forest_grf.md new file mode 100644 index 000000000..307c3246f --- /dev/null +++ b/man/rmd/rand_forest_grf.md @@ -0,0 +1,146 @@ + + + +For this engine, there are multiple modes: classification, regression, and quantile regression + +## Tuning Parameters + + + +This model has 3 tuning parameters: + +- `mtry`: # Randomly Selected Predictors (type: integer, default: see below) + +- `trees`: # Trees (type: integer, default: 2000L) + +- `min_n`: Minimal Node Size (type: integer, default: 5L) + +`mtry` depends on the number of columns. If there are `p` predictors, the default value of `mtry` is `min(ceiling(sqrt(p) + 20), p)`. 
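+
+As a quick, illustrative check of that default (a sketch in plain R, not generated output):
+
+``` r
+# default mtry per the formula above
+default_mtry <- function(p) min(ceiling(sqrt(p) + 20), p)
+default_mtry(10)    # 10: capped at the number of predictors
+default_mtry(1000)  # 52: ceiling(sqrt(1000) + 20)
+```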
+
+## Translation from parsnip to the original package (regression)
+
+See [`?regression_forest`](https://grf-labs.github.io/grf/reference/regression_forest.html)
+
+
+``` r
+rand_forest(
+  mtry = integer(1),
+  trees = integer(1),
+  min_n = integer(1)
+) |>
+  set_engine("grf") |>
+  set_mode("regression") |>
+  translate()
+```
+
+```
+## Random Forest Model Specification (regression)
+## 
+## Main Arguments:
+##   mtry = integer(1)
+##   trees = integer(1)
+##   min_n = integer(1)
+## 
+## Computational engine: grf 
+## 
+## Model fit template:
+## grf::regression_forest(x = missing_arg(), y = missing_arg(), 
+##     weights = missing_arg(), mtry = min_cols(~integer(1), x), 
+##     num.trees = integer(1), min.node.size = min_rows(~integer(1), 
+##         x), num.threads = 1)
+```
+
+## Translation from parsnip to the original package (classification)
+
+See [`?probability_forest`](https://grf-labs.github.io/grf/reference/probability_forest.html)
+
+
+``` r
+rand_forest(
+  mtry = integer(1),
+  trees = integer(1),
+  min_n = integer(1)
+) |>
+  set_engine("grf") |>
+  set_mode("classification") |>
+  translate()
+```
+
+```
+## Random Forest Model Specification (classification)
+## 
+## Main Arguments:
+##   mtry = integer(1)
+##   trees = integer(1)
+##   min_n = integer(1)
+## 
+## Computational engine: grf 
+## 
+## Model fit template:
+## grf::probability_forest(x = missing_arg(), y = missing_arg(), 
+##     weights = missing_arg(), mtry = min_cols(~integer(1), x), 
+##     num.trees = integer(1), min.node.size = min_rows(~integer(1), 
+##         x), num.threads = 1)
+```
+
+## Translation from parsnip to the original package (quantile regression)
+
+See [`?quantile_forest`](https://grf-labs.github.io/grf/reference/quantile_forest.html)
+
+When specifying _any_ quantile regression model, the user must specify the quantile levels _a priori_.
+
+
+``` r
+rand_forest(
+  mtry = integer(1),
+  trees = integer(1),
+  min_n = integer(1)
+) |>
+  set_engine("grf") |>
+  set_mode("quantile regression", quantile_levels = (1:3) / 4) |>
+  translate()
+```
+
+```
+## Random Forest Model Specification (quantile regression)
+## 
+## Main Arguments:
+##   mtry = integer(1)
+##   trees = integer(1)
+##   min_n = integer(1)
+## 
+## Computational engine: grf 
+## 
+## Model fit template:
+## grf::quantile_forest(x = missing_arg(), y = missing_arg(), weights = missing_arg(), 
+##     mtry = min_cols(~integer(1), x), num.trees = integer(1), 
+##     min.node.size = min_rows(~integer(1), x), num.threads = 1, 
+##     quantiles = quantile_levels)
+```
+
+```
+## Quantile levels: 0.25, 0.5, and 0.75.
+```
+
+## Preprocessing requirements
+
+This method _does_ require qualitative predictors to be converted to a numeric format (manually). When using parsnip, a one-hot encoding is automatically used to do this.
+
+## Other notes
+
+By default, parallel processing is turned off. When tuning, it is more efficient to parallelize over the resamples and tuning parameters. To parallelize the construction of the trees within the `grf` model, change the `num.threads` argument via [set_engine()].
+
+For `grf` confidence intervals, the intervals are constructed using the form `estimate +/- z * std_error`. For classification probabilities, these values can fall outside of `[0, 1]` and will be coerced to be in this range.
+
+## Case weights
+
+The regression and classification models enable the use of case weights. The quantile regression mode does not.
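+
+As an illustrative sketch (assuming the registrations in this patch), weights are passed to `fit()` via its `case_weights` argument:
+
+``` r
+wts <- hardhat::frequency_weights(rep(1:2, length.out = nrow(mtcars)))
+
+rand_forest(trees = 500) |>
+  set_engine("grf") |>
+  set_mode("regression") |>
+  fit(mpg ~ ., data = mtcars, case_weights = wts)
+```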
+ +## Examples + +The "Fitting and Predicting with parsnip" article contains [examples](https://parsnip.tidymodels.org/articles/articles/Examples.html#rand-forest-grf) for `rand_forest()` with the `"grf"` engine. + +## References + +Athey, Susan, Julie Tibshirani, and Stefan Wager. "Generalized Random Forests". _Annals of Statistics_, 47(2), 2019. + diff --git a/vignettes/articles/Examples.Rmd b/vignettes/articles/Examples.Rmd index c885a4507..d03c30109 100644 --- a/vignettes/articles/Examples.Rmd +++ b/vignettes/articles/Examples.Rmd @@ -1688,7 +1688,125 @@ The following examples use consistent data sets throughout. For regression, we u +
+ + With the `"grf"` engine + +

Regression Example (`grf`)

+
+```{r}
+#| echo: false
+knitr::spin_child("template-reg-chicago.R")
+```
+
+We can define the model with specific parameters:
+
+```{r}
+rf_reg_spec <-
+  rand_forest(trees = 200, min_n = 5) |>
+  # This model can be used for classification, regression, or quantile
+  # regression, so set the mode
+  set_mode("regression") |>
+  set_engine("grf")
+rf_reg_spec
+```
+
+Now we create the model fit object:
+
+```{r}
+set.seed(1)
+rf_reg_fit <- rf_reg_spec |> fit(ridership ~ ., data = Chicago_train)
+rf_reg_fit
+```
+
+For the holdout data, we can predict the mean value as well as confidence intervals for the predictions:
+
+```{r}
+predict(rf_reg_fit, Chicago_test)
+predict(rf_reg_fit, Chicago_test, type = "conf_int")
+```
+
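+The intervals follow the `estimate +/- z * std_error` form noted in the engine documentation. A sketch of checking that relationship (illustrative only; it assumes the objects above and that the `std_error` prediction option is honored for this engine):
+
+``` r
+# the default level is 0.95, so z = qnorm(0.975)
+mean_pred <- predict(rf_reg_fit, Chicago_test)
+ci_pred <- predict(rf_reg_fit, Chicago_test, type = "conf_int", std_error = TRUE)
+all.equal(
+  ci_pred$.pred_upper - mean_pred$.pred,
+  qnorm(0.975) * ci_pred$.std_error
+)
+```
+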

Classification Example (`grf`)

+
+```{r}
+#| echo: false
+knitr::spin_child("template-cls-two-class.R")
+```
+
+We can define the model with specific parameters:
+
+```{r}
+rf_cls_spec <-
+  rand_forest(trees = 200, min_n = 5) |>
+  # This model can be used for classification, regression, or quantile
+  # regression, so set the mode
+  set_mode("classification") |>
+  set_engine("grf")
+rf_cls_spec
+```
+
+Now we create the model fit object:
+
+```{r}
+set.seed(1)
+rf_cls_fit <- rf_cls_spec |> fit(Class ~ ., data = data_train)
+rf_cls_fit
+```
+
+The holdout data can be predicted for hard class predictions, class probabilities, and confidence intervals. We'll bind these together into one tibble:
+
+```{r}
+bind_cols(
+  predict(rf_cls_fit, data_test),
+  predict(rf_cls_fit, data_test, type = "prob"),
+  predict(rf_cls_fit, data_test, type = "conf_int")
+)
+```
+
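+The confidence intervals come back with one pair of bound columns per class level; a quick sketch (assuming the objects above):
+
+``` r
+# per-class lower/upper bound columns, e.g., .pred_lower_Class1,
+# .pred_upper_Class1, and so on
+names(predict(rf_cls_fit, data_test, type = "conf_int"))
+```
+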

Quantile regression Example (`grf`)

+
+```{r}
+#| echo: false
+knitr::spin_child("template-reg-sacramento.R")
+```
+
+We can define the model, but we must also set the model mode. For these models, the quantile levels of the distribution that we would like to predict have to be specified along with the mode, using the `quantile_levels` argument. Let's predict the 0.25, 0.50, and 0.75 quantiles:
+
+```{r}
+grf_quant_spec <-
+  rand_forest() |>
+  set_engine("grf") |>
+  set_mode("quantile regression", quantile_levels = (1:3) / 4)
+grf_quant_spec
+```
+
+Now we create the model fit object:
+
+```{r}
+set.seed(1)
+grf_quant_fit <- grf_quant_spec |> fit(price ~ sqft, data = sac_train)
+grf_quant_fit
+```
+
+The holdout data can be predicted:
+
+```{r}
+quant_pred <- predict(grf_quant_fit, sac_test)
+quant_pred
+```
+
+`.pred_quantile` is a vector type that contains all of the quantile predictions for each row. You can convert it to a rectangular data set using either of:
+
+```{r}
+as.matrix(quant_pred$.pred_quantile)
+
+# or
+as_tibble(quant_pred$.pred_quantile)
+```
+
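+Per the `augment()` documentation elsewhere in this patch, a `quantile_pred` column can also be unnested with tidyr; a sketch (assuming the `quant_pred` object above):
+
+``` r
+# one row per observation/quantile-level pair after unnesting
+quant_pred |>
+  dplyr::mutate(.row = dplyr::row_number()) |>
+  tidyr::unnest(.pred_quantile)
+```
+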
+ ## `svm_linear()` models
From 1151a2cfce1cd0c83c9fbac93641697e00e0e9b0 Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 22 Oct 2025 13:30:38 -0400 Subject: [PATCH 04/15] testing update --- tests/testthat/_snaps/args_and_modes.md | 2 +- tests/testthat/_snaps/rand_forest.md | 2 +- tests/testthat/_snaps/registration.md | 45 +++++---- tests/testthat/helper-objects.R | 16 +-- tests/testthat/test-registration.R | 127 ++++++++++++++---------- 5 files changed, 113 insertions(+), 79 deletions(-) diff --git a/tests/testthat/_snaps/args_and_modes.md b/tests/testthat/_snaps/args_and_modes.md index a8e1cb44c..445529066 100644 --- a/tests/testthat/_snaps/args_and_modes.md +++ b/tests/testthat/_snaps/args_and_modes.md @@ -12,7 +12,7 @@ set_mode(rand_forest()) Condition Error in `set_mode()`: - ! Available modes for model type rand_forest are: "unknown", "classification", "regression", and "censored regression". + ! Available modes for model type rand_forest are: "unknown", "classification", "regression", "censored regression", and "quantile regression". --- diff --git a/tests/testthat/_snaps/rand_forest.md b/tests/testthat/_snaps/rand_forest.md index bd9fc5d1c..28bfca8fd 100644 --- a/tests/testthat/_snaps/rand_forest.md +++ b/tests/testthat/_snaps/rand_forest.md @@ -21,7 +21,7 @@ res <- translate(set_engine(rand_forest(mode = "classification"), NULL)) Condition Error in `set_engine()`: - ! Missing engine. Possible mode/engine combinations are: classification {ranger, randomForest, spark} and regression {ranger, randomForest, spark}. + ! Missing engine. Possible mode/engine combinations are: classification {ranger, randomForest, spark, grf}, quantile regression {grf}, and regression {ranger, randomForest, spark, grf}. --- diff --git a/tests/testthat/_snaps/registration.md b/tests/testthat/_snaps/registration.md index 575f18df6..9e90b65b7 100644 --- a/tests/testthat/_snaps/registration.md +++ b/tests/testthat/_snaps/registration.md @@ -363,11 +363,12 @@ show_model_info("rand_forest") Output Information for `rand_forest` - modes: unknown, classification, regression, censored regression + modes: unknown, classification, regression, censored regression, quantile regression engines: - classification: randomForest, ranger1, spark - regression: randomForest, ranger1, spark + classification: grf1, randomForest, ranger1, spark + quantile regression: grf1 + regression: grf1, randomForest, ranger1, spark 1The model can use case weights. 
@@ -384,24 +385,34 @@ mtry --> feature_subset_strategy trees --> num_trees min_n --> min_instances_per_node + grf: + mtry --> mtry + trees --> num.trees + min_n --> min.node.size fit modules: - engine mode - ranger classification - ranger regression - randomForest classification - randomForest regression - spark classification - spark regression + engine mode + ranger classification + ranger regression + randomForest classification + randomForest regression + spark classification + spark regression + grf classification + grf regression + grf quantile regression prediction modules: - mode engine methods - classification randomForest class, prob, raw - classification ranger class, conf_int, prob, raw - classification spark class, prob - regression randomForest numeric, raw - regression ranger conf_int, numeric, raw - regression spark numeric + mode engine methods + classification grf class, conf_int, prob + classification randomForest class, prob, raw + classification ranger class, conf_int, prob, raw + classification spark class, prob + quantile regression grf quantile + regression grf conf_int, numeric + regression randomForest numeric, raw + regression ranger conf_int, numeric, raw + regression spark numeric --- diff --git a/tests/testthat/helper-objects.R b/tests/testthat/helper-objects.R index 86f61ed78..068c34ca5 100644 --- a/tests/testthat/helper-objects.R +++ b/tests/testthat/helper-objects.R @@ -1,6 +1,6 @@ -ctrl <- control_parsnip(verbosity = 1, catch = FALSE) -caught_ctrl <- control_parsnip(verbosity = 1, catch = TRUE) -quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE) +ctrl <- control_parsnip(verbosity = 1, catch = FALSE) +caught_ctrl <- control_parsnip(verbosity = 1, catch = TRUE) +quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE) run_glmnet <- utils::compareVersion('3.6.0', as.character(getRversion())) > 0 @@ -29,7 +29,7 @@ if (rlang::is_installed("modeldata")) { # ------------------------------------------------------------------------------ - hpc <- hpc_data[1:150, c(2:5, 8)] + hpc <- modeldata::hpc_data[1:150, c(2:5, 8)] num_hpc_pred <- names(hpc)[1:4] class_tab <- table(hpc$class, dnn = NULL) hpc_bad <- @@ -37,7 +37,7 @@ if (rlang::is_installed("modeldata")) { dplyr::mutate(big_num = Inf) set.seed(352) - mlp_dat <- hpc[order(runif(150)),] + mlp_dat <- hpc[order(runif(150)), ] tr_mlp_dat <- mlp_dat[1:140, ] te_mlp_dat <- mlp_dat[141:150, ] @@ -46,7 +46,7 @@ if (rlang::is_installed("modeldata")) { mlp_hpc_pred_list <- names(hpc)[1:4] nnet_hpc_pred_list <- names(hpc)[1:4] - hpc_nnet_dat <- hpc_data[1:150, c(2:5, 8)] + hpc_nnet_dat <- modeldata::hpc_data[1:150, c(2:5, 8)] # ------------------------------------------------------------------------------ @@ -56,7 +56,7 @@ if (rlang::is_installed("modeldata")) { fit(compounds ~ ., data = hpc) lending_club <- - lending_club |> + modeldata::lending_club |> dplyr::slice(1:200) |> dplyr::mutate(big_num = Inf) @@ -73,7 +73,7 @@ if (rlang::is_installed("modeldata")) { dplyr::select(price, beds, baths, sqft, latitude, longitude) sac_train <- Sacramento_small[-(1:5), ] - sac_test <- Sacramento_small[ 1:5 , ] + sac_test <- Sacramento_small[1:5, ] # ------------------------------------------------------------------------------ # For sparse tibble testing diff --git a/tests/testthat/test-registration.R b/tests/testthat/test-registration.R index 7120f5177..f6e4d43cc 100644 --- a/tests/testthat/test-registration.R +++ b/tests/testthat/test-registration.R @@ -3,8 +3,14 @@ test_that('adding a new model', { mod_items <- 
get_model_env() |> rlang::env_names() sponges <- grep("sponge", mod_items, value = TRUE) - exp_obj <- c('sponge_modes', 'sponge_fit', 'sponge_args', - 'sponge_predict', 'sponge_pkgs', 'sponge') + exp_obj <- c( + 'sponge_modes', + 'sponge_fit', + 'sponge_args', + 'sponge_predict', + 'sponge_pkgs', + 'sponge' + ) expect_equal(sort(sponges), sort(exp_obj)) expect_equal( @@ -12,36 +18,45 @@ test_that('adding a new model', { tibble(engine = character(0), mode = character(0)) ) -expect_equal( - get_from_env("sponge_pkgs"), - tibble(engine = character(0), pkg = list(), mode = character(0)) -) - -expect_equal( - get_from_env("sponge_modes"), "unknown" -) - -expect_equal( - get_from_env("sponge_args"), - dplyr::tibble(engine = character(0), parsnip = character(0), - original = character(0), func = vector("list"), - has_submodel = logical(0)) -) - -expect_equal( - get_from_env("sponge_fit"), - tibble(engine = character(0), mode = character(0), value = vector("list")) -) - -expect_equal( - get_from_env("sponge_predict"), - tibble(engine = character(0), mode = character(0), - type = character(0), value = vector("list")) -) - -expect_snapshot(error = TRUE, set_new_model()) -expect_snapshot(error = TRUE, set_new_model(2)) -expect_snapshot(error = TRUE, set_new_model(letters[1:2])) + expect_equal( + get_from_env("sponge_pkgs"), + tibble(engine = character(0), pkg = list(), mode = character(0)) + ) + + expect_equal( + get_from_env("sponge_modes"), + "unknown" + ) + + expect_equal( + get_from_env("sponge_args"), + dplyr::tibble( + engine = character(0), + parsnip = character(0), + original = character(0), + func = vector("list"), + has_submodel = logical(0) + ) + ) + + expect_equal( + get_from_env("sponge_fit"), + tibble(engine = character(0), mode = character(0), value = vector("list")) + ) + + expect_equal( + get_from_env("sponge_predict"), + tibble( + engine = character(0), + mode = character(0), + type = character(0), + value = vector("list") + ) + ) + + expect_snapshot(error = TRUE, set_new_model()) + expect_snapshot(error = TRUE, set_new_model(2)) + expect_snapshot(error = TRUE, set_new_model(letters[1:2])) }) @@ -58,7 +73,6 @@ test_that('adding a new mode', { expect_equal(get_from_env("sponge_modes"), c("unknown", "classification")) expect_snapshot(error = TRUE, set_model_mode("sponge")) - }) @@ -75,7 +89,10 @@ test_that('adding a new engine', { expect_equal(get_from_env("sponge_modes"), c("unknown", "classification")) expect_snapshot(error = TRUE, set_model_engine("sponge", eng = "gum")) - expect_snapshot(error = TRUE, set_model_engine("sponge", mode = "classification")) + expect_snapshot( + error = TRUE, + set_model_engine("sponge", mode = "classification") + ) expect_snapshot( error = TRUE, set_model_engine("sponge", mode = "regression", eng = "gum") @@ -90,7 +107,10 @@ test_that('adding a new package', { expect_snapshot(error = TRUE, set_dependency("sponge", "gum", letters[1:2])) expect_snapshot(error = TRUE, set_dependency("sponge", "gummies", "trident")) - expect_snapshot(error = TRUE, set_dependency("sponge", "gum", "trident", mode = "regression")) + expect_snapshot( + error = TRUE, + set_dependency("sponge", "gum", "trident", mode = "regression") + ) expect_equal( get_from_env("sponge_pkgs"), @@ -100,16 +120,20 @@ test_that('adding a new package', { set_dependency("sponge", "gum", "juicy-fruit", mode = "classification") expect_equal( get_from_env("sponge_pkgs"), - tibble(engine = "gum", - pkg = list(c("trident", "juicy-fruit")), - mode = "classification") + tibble( + engine = "gum", + pkg = 
list(c("trident", "juicy-fruit")), + mode = "classification" + ) ) expect_equal( get_dependency("sponge"), - tibble(engine = "gum", - pkg = list(c("trident", "juicy-fruit")), - mode = "classification") + tibble( + engine = "gum", + pkg = list(c("trident", "juicy-fruit")), + mode = "classification" + ) ) }) @@ -140,9 +164,13 @@ test_that('adding a new argument', { expect_equal( get_from_env("sponge_args"), - tibble(engine = "gum", parsnip = "modeling", original = "modelling", - func = list(list(pkg = "foo", fun = "bar")), - has_submodel = FALSE) + tibble( + engine = "gum", + parsnip = "modeling", + original = "modelling", + func = list(list(pkg = "foo", fun = "bar")), + has_submodel = FALSE + ) ) expect_snapshot( @@ -252,7 +280,6 @@ test_that('adding a new argument', { }) - # ------------------------------------------------------------------------------ test_that('adding a new fit', { @@ -273,7 +300,7 @@ test_that('adding a new fit', { fit_env_data <- get_from_env("sponge_fit") expect_equal( - fit_env_data[ 1:2], + fit_env_data[1:2], tibble(engine = "gum", mode = "classification") ) @@ -405,7 +432,7 @@ test_that('adding a new predict method', { pred_env_data <- get_from_env("sponge_predict") expect_equal( - pred_env_data[ 1:3], + pred_env_data[1:3], tibble(engine = "gum", mode = "classification", type = "class") ) @@ -415,7 +442,7 @@ test_that('adding a new predict method', { ) expect_equal( - get_pred_type("sponge", "class")[ 1:3], + get_pred_type("sponge", "class")[1:3], tibble(engine = "gum", mode = "classification", type = "class") ) @@ -446,7 +473,6 @@ test_that('adding a new predict method', { ) ) - expect_snapshot( error = TRUE, set_pred( @@ -520,11 +546,9 @@ test_that('adding a new predict method', { value = class_vals_2 ) ) - }) - test_that('showing model info', { expect_snapshot(show_model_info("rand_forest")) @@ -532,4 +556,3 @@ test_that('showing model info', { # notation would be ambiguous (#1000) expect_snapshot(show_model_info("mlp")) }) - From b5284b0a8d55489a003bc5b6a934de85c06334c7 Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 22 Oct 2025 13:32:31 -0400 Subject: [PATCH 05/15] snapshot updates --- tests/testthat/_snaps/adds.md | 8 + tests/testthat/_snaps/boost_tree.md | 40 +++ tests/testthat/_snaps/boost_tree_C5.0.md | 36 +++ tests/testthat/_snaps/boost_tree_xgboost.md | 102 +++++++ tests/testthat/_snaps/convert_data.md | 68 +++++ tests/testthat/_snaps/decision_tree.md | 54 ++++ tests/testthat/_snaps/descriptors.md | 8 + tests/testthat/_snaps/extract.md | 56 ++++ tests/testthat/_snaps/failed_models.md | 51 ++++ tests/testthat/_snaps/fit_interfaces.md | 78 ++++++ tests/testthat/_snaps/gen_additive_model.md | 19 ++ tests/testthat/_snaps/linear_reg.md | 233 ++++++++++++++++ tests/testthat/_snaps/linear_reg_quantreg.md | 9 + tests/testthat/_snaps/logistic_reg.md | 184 +++++++++++++ tests/testthat/_snaps/mars.md | 68 +++++ tests/testthat/_snaps/misc.md | 257 ++++++++++++++++++ tests/testthat/_snaps/mlp.md | 132 +++++++++ tests/testthat/_snaps/mlp_keras.md | 18 ++ tests/testthat/_snaps/mlp_nnet.md | 8 + tests/testthat/_snaps/multinom_reg.md | 121 +++++++++ .../testthat/_snaps/nearest_neighbor_kknn.md | 26 ++ tests/testthat/_snaps/nullmodel.md | 48 ++++ tests/testthat/_snaps/predict_formats.md | 36 +++ tests/testthat/_snaps/rand_forest_ranger.md | 90 ++++++ tests/testthat/_snaps/sparsevctrs.md | 145 ++++++++++ tests/testthat/_snaps/svm_linear.md | 65 +++++ tests/testthat/_snaps/svm_poly.md | 33 +++ tests/testthat/_snaps/svm_rbf.md | 53 ++++ tests/testthat/test-rand_forest.R | 
18 +- 29 files changed, 2058 insertions(+), 6 deletions(-) create mode 100644 tests/testthat/_snaps/adds.md create mode 100644 tests/testthat/_snaps/boost_tree_C5.0.md create mode 100644 tests/testthat/_snaps/boost_tree_xgboost.md create mode 100644 tests/testthat/_snaps/convert_data.md create mode 100644 tests/testthat/_snaps/decision_tree.md create mode 100644 tests/testthat/_snaps/descriptors.md create mode 100644 tests/testthat/_snaps/extract.md create mode 100644 tests/testthat/_snaps/failed_models.md create mode 100644 tests/testthat/_snaps/fit_interfaces.md create mode 100644 tests/testthat/_snaps/gen_additive_model.md create mode 100644 tests/testthat/_snaps/linear_reg.md create mode 100644 tests/testthat/_snaps/linear_reg_quantreg.md create mode 100644 tests/testthat/_snaps/logistic_reg.md create mode 100644 tests/testthat/_snaps/mars.md create mode 100644 tests/testthat/_snaps/misc.md create mode 100644 tests/testthat/_snaps/mlp.md create mode 100644 tests/testthat/_snaps/mlp_keras.md create mode 100644 tests/testthat/_snaps/mlp_nnet.md create mode 100644 tests/testthat/_snaps/multinom_reg.md create mode 100644 tests/testthat/_snaps/nearest_neighbor_kknn.md create mode 100644 tests/testthat/_snaps/nullmodel.md create mode 100644 tests/testthat/_snaps/predict_formats.md create mode 100644 tests/testthat/_snaps/rand_forest_ranger.md create mode 100644 tests/testthat/_snaps/sparsevctrs.md create mode 100644 tests/testthat/_snaps/svm_linear.md create mode 100644 tests/testthat/_snaps/svm_poly.md create mode 100644 tests/testthat/_snaps/svm_rbf.md diff --git a/tests/testthat/_snaps/adds.md b/tests/testthat/_snaps/adds.md new file mode 100644 index 000000000..22f574c07 --- /dev/null +++ b/tests/testthat/_snaps/adds.md @@ -0,0 +1,8 @@ +# adding row indicies + + Code + add_rowindex(as.matrix(mtcars)) + Condition + Error in `add_rowindex()`: + ! `x` should be a data frame. + diff --git a/tests/testthat/_snaps/boost_tree.md b/tests/testthat/_snaps/boost_tree.md index 827d73d76..d4b75aa71 100644 --- a/tests/testthat/_snaps/boost_tree.md +++ b/tests/testthat/_snaps/boost_tree.md @@ -1,3 +1,43 @@ +# updating + + Code + update(set_engine(boost_tree(trees = 1), "C5.0", noGlobalPruning = TRUE), + trees = tune(), noGlobalPruning = tune()) + Output + Boosted Tree Model Specification (unknown mode) + + Main Arguments: + trees = tune() + + Engine-Specific Arguments: + noGlobalPruning = tune() + + Computational engine: C5.0 + + +# bad input + + Code + boost_tree(mode = "bogus") + Condition + Error in `boost_tree()`: + ! "bogus" is not a known mode for model `boost_tree()`. + +--- + + Code + translate(boost_tree(mode = "classification"), engine = NULL) + Message + Used `engine = 'xgboost'` for translation. + Output + Boosted Tree Model Specification (classification) + + Computational engine: xgboost + + Model fit template: + parsnip::xgb_train(x = missing_arg(), y = missing_arg(), weights = missing_arg(), + nthread = 1, verbose = 0) + # check_args() works Code diff --git a/tests/testthat/_snaps/boost_tree_C5.0.md b/tests/testthat/_snaps/boost_tree_C5.0.md new file mode 100644 index 000000000..e8e8e95db --- /dev/null +++ b/tests/testthat/_snaps/boost_tree_C5.0.md @@ -0,0 +1,36 @@ +# C5.0 execution + + Code + res <- fit(lc_basic, funded_amnt ~ term, data = lending_club, engine = "C5.0", + control = ctrl) + Condition + Error in `.convert_form_to_xy_fit()`: + ! The argument `engine` cannot be used to create the data. + Possible arguments are subset or weights. 
+ +# submodel prediction + + Code + multi_predict(class_fit, newdata = wa_churn[1:4, vars], trees = 4, type = "prob") + Condition + Error in `multi_predict()`: + ! Please use `new_data` instead of `newdata`. + +# argument checks for data dimensions + + Code + f_fit <- fit(spec, species ~ ., data = penguins) + Condition + Warning: + ! 1000 samples were requested but there were 333 rows in the data. + i 333 will be used. + +--- + + Code + xy_fit <- fit_xy(spec, x = penguins[, -1], y = penguins$species) + Condition + Warning: + ! 1000 samples were requested but there were 333 rows in the data. + i 333 will be used. + diff --git a/tests/testthat/_snaps/boost_tree_xgboost.md b/tests/testthat/_snaps/boost_tree_xgboost.md new file mode 100644 index 000000000..135b9a2d0 --- /dev/null +++ b/tests/testthat/_snaps/boost_tree_xgboost.md @@ -0,0 +1,102 @@ +# xgboost execution, classification + + Code + res <- parsnip::fit(hpc_xgboost, class ~ novar, data = hpc, control = ctrl) + Condition + Error: + ! object 'novar' not found + +# submodel prediction + + Code + multi_predict(class_fit, newdata = wa_churn[1:4, vars], trees = 5, type = "prob") + Condition + Error in `multi_predict()`: + ! Please use `new_data` instead of `newdata`. + +# validation sets + + Code + reg_fit <- fit(set_engine(boost_tree(trees = 20, mode = "regression"), + "xgboost", validation = 3), mpg ~ ., data = mtcars[-(1:4), ]) + Condition + Error in `parsnip::xgb_train()`: + ! `validation` should be on [0, 1). + +# early stopping + + Code + reg_fit <- fit(set_engine(boost_tree(trees = 20, stop_iter = 30, mode = "regression"), + "xgboost", validation = 0.1), mpg ~ ., data = mtcars[-(1:4), ]) + Condition + Warning: + `early_stop` was reduced to 19. + +--- + + Code + reg_fit <- fit(set_engine(boost_tree(trees = 20, stop_iter = 0, mode = "regression"), + "xgboost", validation = 0.1), mpg ~ ., data = mtcars[-(1:4), ]) + Condition + Error in `parsnip::xgb_train()`: + ! `early_stop` should be on [2, 20). + +# xgboost data conversion + + Code + from_df <- parsnip:::as_xgb_data(mtcar_x, mtcars_y, event_level = "second") + Condition + Warning: + `event_level` can only be set for binary outcomes. + +# argument checks for data dimensions + + Code + f_fit <- fit(spec, species ~ ., data = penguins, control = ctrl) + Condition + Warning: + ! 1000 samples were requested but there were 333 rows in the data. + i 333 will be used. + +--- + + Code + xy_fit <- fit_xy(spec, x = penguins_dummy, y = penguins$species, control = ctrl) + Condition + Warning: + ! 1000 samples were requested but there were 333 rows in the data. + i 333 will be used. + +# count/proportion parameters + + Code + fit(set_mode(set_engine(boost_tree(mtry = 0.9, trees = 4), "xgboost"), + "regression"), mpg ~ ., data = mtcars) + Condition + Error in `xgb_train()`: + ! The option `counts = TRUE` was used but `colsample_bynode` was given as 0.9. + i Please use a value >= 1 or use `counts = FALSE`. + +# interface to param arguments + + ! Please supply elements of the `params` list argument as main arguments to `set_engine()` rather than as part of `params`. + i See `?details_boost_tree_xgboost` for more information. + +--- + + ! Please supply elements of the `params` list argument as main arguments to `set_engine()` rather than as part of `params`. + i See `?details_boost_tree_xgboost` for more information. + +--- + + ! The argument `watchlist` is guarded by parsnip and will not be passed to `xgb.train()`. + +--- + + ! 
The arguments `watchlist` and `data` are guarded by parsnip and will not be passed to `xgb.train()`. + +--- + + ! Please supply elements of the `params` list argument as main arguments to `set_engine()` rather than as part of `params`. + i See `?details_boost_tree_xgboost` for more information. + diff --git a/tests/testthat/_snaps/convert_data.md b/tests/testthat/_snaps/convert_data.md new file mode 100644 index 000000000..d864f282a --- /dev/null +++ b/tests/testthat/_snaps/convert_data.md @@ -0,0 +1,68 @@ +# numeric y and mixed x, fail missing data + + Code + .convert_form_to_xy_fit(rate ~ ., data = Puromycin_miss, na.action = na.fail, + indicators = "traditional", remove_intercept = TRUE) + Condition + Error in `na.fail.default()`: + ! missing values in object + +# numeric x and factor y + + Code + expected <- glm(class ~ ., data = hpc, x = TRUE, y = TRUE, family = binomial()) + Condition + Warning: + glm.fit: fitted probabilities numerically 0 or 1 occurred + +# bad args + + Code + .convert_form_to_xy_fit(mpg ~ ., data = mtcars, composition = "tibble", + indicators = "traditional", remove_intercept = TRUE) + Condition + Error: + ! `composition` should be either "data.frame", "matrix", or "dgCMatrix". + +--- + + Code + .convert_form_to_xy_fit(mpg ~ ., data = mtcars, weights = letters[1:nrow(mtcars)], + indicators = "traditional", remove_intercept = TRUE) + Condition + Error: + ! `weights` must be a numeric vector. + +--- + + Code + .convert_xy_to_form_fit(mtcars$disp, mtcars$mpg, remove_intercept = TRUE) + Condition + Error: + ! `x` cannot be a vector. + +--- + + Code + .convert_xy_to_form_fit(mtcars[, 1:3], mtcars[, 2:5], remove_intercept = TRUE) + Condition + Error in `.convert_xy_to_form_fit()`: + ! `x` and `y` have the names "cyl" and "disp" in common. + i Please ensure that `x` and `y` don't share any column names. + +# convert to matrix + + Code + parsnip::maybe_matrix(ames[, c("Year_Built", "Neighborhood")]) + Condition + Error in `parsnip::maybe_matrix()`: + ! The column "Neighborhood" is non-numeric, so the data cannot be converted to a numeric matrix. + +--- + + Code + parsnip::maybe_matrix(Chicago[, c("ridership", "date")]) + Condition + Error in `parsnip::maybe_matrix()`: + ! The column "date" is non-numeric, so the data cannot be converted to a numeric matrix. + diff --git a/tests/testthat/_snaps/decision_tree.md b/tests/testthat/_snaps/decision_tree.md new file mode 100644 index 000000000..d40576a7e --- /dev/null +++ b/tests/testthat/_snaps/decision_tree.md @@ -0,0 +1,54 @@ +# updating + + Code + update(set_engine(decision_tree(cost_complexity = 0.1), "rpart", model = FALSE), + cost_complexity = tune(), model = tune()) + Output + Decision Tree Model Specification (unknown mode) + + Main Arguments: + cost_complexity = tune() + + Engine-Specific Arguments: + model = tune() + + Computational engine: rpart + + +# bad input + + "bogus" is not a known mode for model `decision_tree()`. + +--- + + Please set the mode in the model specification (`?parsnip::model_spec()`). + +--- + + Please set the mode in the model specification (`?parsnip::model_spec()`). + +--- + + Code + try(translate(decision_tree(), engine = NULL), silent = TRUE) + Message + Used `engine = 'rpart'` for translation. + +# argument checks for data dimensions + + Code + f_fit <- fit(spec, body_mass_g ~ ., data = penguins) + Condition + Warning: + ! 1000 samples were requested but there were 333 rows in the data. + i 333 samples will be used. 
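Both decision-tree warnings come from parsnip's data-dimension guards, which cap a requested value at what the data can support. `min_rows()` is the helper that does this; a minimal sketch:

``` r
library(parsnip)

# mtcars has 32 rows, so a request for 1000 warns and is capped at 32
min_rows(1000, mtcars)
```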
+ +--- + + Code + xy_fit <- fit_xy(spec, x = penguins[, -6], y = penguins$body_mass_g) + Condition + Warning: + ! 1000 samples were requested but there were 333 rows in the data. + i 333 samples will be used. + diff --git a/tests/testthat/_snaps/descriptors.md b/tests/testthat/_snaps/descriptors.md new file mode 100644 index 000000000..b0bc248f8 --- /dev/null +++ b/tests/testthat/_snaps/descriptors.md @@ -0,0 +1,8 @@ +# can be temporarily overriden at evaluation time + + Code + .cols() + Condition + Error in `descr_env$.cols()`: + ! Descriptor context not set + diff --git a/tests/testthat/_snaps/extract.md b/tests/testthat/_snaps/extract.md new file mode 100644 index 000000000..5dddb6af5 --- /dev/null +++ b/tests/testthat/_snaps/extract.md @@ -0,0 +1,56 @@ +# extract + + Code + extract_spec_parsnip(x_no_spec) + Condition + Error in `extract_spec_parsnip()`: + ! The model fit does not have a model spec. + i This is an internal error that was detected in the parsnip package. + Please report it at with a reprex () and the full backtrace. + +--- + + Code + extract_fit_engine(x_no_fit) + Condition + Error in `extract_fit_engine()`: + ! The model fit does not have an engine fit. + i This is an internal error that was detected in the parsnip package. + Please report it at with a reprex () and the full backtrace. + +# extract parameter set from model with no loaded implementation + + Code + extract_parameter_set_dials(bt_mod) + Condition + Error: + ! parsnip could not locate an implementation for `bag_tree` regression model specifications. + i The parsnip extension package baguette implements support for this specification. + i Please install (if needed) and load to continue. + +--- + + Code + extract_parameter_dials(bt_mod, parameter = "min_n") + Condition + Error: + ! parsnip could not locate an implementation for `bag_tree` regression model specifications. + i The parsnip extension package baguette implements support for this specification. + i Please install (if needed) and load to continue. + +# extract single parameter from model with no parameters + + Code + extract_parameter_dials(lm_model, parameter = "none there") + Condition + Error in `extract_parameter_dials()`: + ! No parameter exists with id "none there". + +# extract_fit_time() works + + Code + extract_fit_time(lm_fit) + Condition + Error in `extract_fit_time()`: + ! This model was fit before `extract_fit_time()` was added. + diff --git a/tests/testthat/_snaps/failed_models.md b/tests/testthat/_snaps/failed_models.md new file mode 100644 index 000000000..3c12f0feb --- /dev/null +++ b/tests/testthat/_snaps/failed_models.md @@ -0,0 +1,51 @@ +# numeric model + + Code + num_res <- predict(lm_mod, hpc_bad[1:11, -1]) + Condition + Warning: + Model fit failed; cannot make predictions. + +--- + + Code + ci_res <- predict(lm_mod, hpc_bad[1:11, -1], type = "conf_int") + Condition + Warning: + Model fit failed; cannot make predictions. + +--- + + Code + pi_res <- predict(lm_mod, hpc_bad[1:11, -1], type = "pred_int") + Condition + Warning: + Model fit failed; cannot make predictions. + +# classification model + + Code + cls_res <- predict(log_reg, dplyr::select(dplyr::slice(lending_club, 1:7), + -Class)) + Condition + Warning: + Model fit failed; cannot make predictions. + +--- + + Code + prb_res <- predict(log_reg, dplyr::select(dplyr::slice(lending_club, 1:7), + -Class), type = "prob") + Condition + Warning: + Model fit failed; cannot make predictions. 
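These prediction warnings are produced after a fit that was run with `catch = TRUE`, which stores the engine error in the fit object instead of rethrowing it; later `predict()` calls then warn rather than error. A minimal sketch, where `bad_data` is a hypothetical data set that makes the engine fail:

``` r
library(parsnip)

# catch = TRUE captures fitting errors instead of stopping
ctrl <- control_parsnip(catch = TRUE)

# lm_bad <- fit(linear_reg(), mpg ~ ., data = bad_data, control = ctrl)
# predict(lm_bad, bad_data)  # warns: "Model fit failed; cannot make predictions."
```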
+ +--- + + Code + ci_res <- predict(log_reg, dplyr::select(dplyr::slice(lending_club, 1:7), + -Class), type = "conf_int") + Condition + Warning: + Model fit failed; cannot make predictions. + diff --git a/tests/testthat/_snaps/fit_interfaces.md b/tests/testthat/_snaps/fit_interfaces.md new file mode 100644 index 000000000..15f131073 --- /dev/null +++ b/tests/testthat/_snaps/fit_interfaces.md @@ -0,0 +1,78 @@ +# wrong args + + Code + tester_xy(NULL, x = sprk, y = hpc, model = rmod) + Condition + Error in `tester_xy()`: + ! `x` should be a , not an object. + +--- + + Code + tester(NULL, f, data = as.matrix(hpc[, 1:4])) + Condition + Error in `tester()`: + ! `data` should be a , not a double matrix. + +# unknown modes + + Code + fit(mars_spec, am ~ ., data = mtcars) + Condition + Error in `fit()`: + ! Please set the mode in the model specification (`?parsnip::model_spec()`). + +--- + + Code + fit_xy(mars_spec, x = mtcars[, -1], y = mtcars[, 1]) + Condition + Error in `fit_xy()`: + ! Please set the mode in the model specification (`?parsnip::model_spec()`). + +--- + + Code + fit_xy(mars_spec, x = lending_club[, 1:2], y = lending_club$Class) + Condition + Error in `fit_xy()`: + ! Please set the mode in the model specification (`?parsnip::model_spec()`). + +# misspecified formula argument + + Code + fit(linear_reg(), rec, mtcars) + Condition + Error in `fit()`: + ! The `formula` argument must be a formula. + i To fit a model with a recipe preprocessor, please use a workflow (`?workflows::workflow()`). + +--- + + Code + fit(linear_reg(), "boop", mtcars) + Condition + Error in `fit()`: + ! `formula` must be a formula, not the string "boop". + +# No loaded engines + + ! parsnip could not locate an implementation for `cubist_rules` model specifications. + i The parsnip extension package rules implements support for this specification. + i Please install (if needed) and load to continue. + + +--- + + ! parsnip could not locate an implementation for `poisson_reg` model specifications. + i The parsnip extension packages multilevelmod, poissonreg, and agua implement support for this specification. + i Please install (if needed) and load to continue. + + +--- + + ! parsnip could not locate an implementation for `cubist_rules` model specifications using the `Cubist` engine. + i The parsnip extension package rules implements support for this specification. + i Please install (if needed) and load to continue. + + diff --git a/tests/testthat/_snaps/gen_additive_model.md b/tests/testthat/_snaps/gen_additive_model.md new file mode 100644 index 000000000..c79a4c861 --- /dev/null +++ b/tests/testthat/_snaps/gen_additive_model.md @@ -0,0 +1,19 @@ +# regression + + Code + xy_res <- fit_xy(reg_mod, x = mtcars[, 1:5], y = mtcars$mpg, control = ctrl) + Condition + Error in `fit_xy()`: + ! Please use `fit()` rather than `fit_xy()` to train generalized additive models with the "mgcv" engine. + i See `?model_formula()` to learn more. + +# classification + + Code + xy_res <- fit_xy(cls_mod, x = two_class_dat[, 2:3], y = two_class_dat$Class, + control = ctrl) + Condition + Error in `fit_xy()`: + ! Please use `fit()` rather than `fit_xy()` to train generalized additive models with the "mgcv" engine. + i See `?model_formula()` to learn more. 
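The two `gen_additive_mod()` errors enforce the formula interface because mgcv's smooth terms live in the model formula itself, which `fit_xy()` cannot carry. A minimal sketch of the supported call, assuming mgcv is installed:

``` r
library(parsnip)

gam_spec <- gen_additive_mod() |>
  set_engine("mgcv") |>
  set_mode("regression")

# Smooths such as s() are supplied through fit()'s formula
gam_fit <- fit(gam_spec, mpg ~ s(disp, k = 5) + wt, data = mtcars)
```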
+ diff --git a/tests/testthat/_snaps/linear_reg.md b/tests/testthat/_snaps/linear_reg.md new file mode 100644 index 000000000..74aa5b25e --- /dev/null +++ b/tests/testthat/_snaps/linear_reg.md @@ -0,0 +1,233 @@ +# updating + + Code + update(set_engine(linear_reg(mixture = 0), "glmnet", nlambda = 10), mixture = tune(), + nlambda = tune()) + Output + Linear Regression Model Specification (regression) + + Main Arguments: + mixture = tune() + + Engine-Specific Arguments: + nlambda = tune() + + Computational engine: glmnet + + +# bad input + + Code + linear_reg(mode = "classification") + Condition + Error in `linear_reg()`: + ! "classification" is not a known mode for model `linear_reg()`. + +--- + + Code + translate(linear_reg(), engine = "wat?") + Condition + Error in `translate.default()`: + x Engine "wat?" is not supported for `linear_reg()` + i See `show_engines("linear_reg")`. + +--- + + Code + translate(linear_reg(), engine = NULL) + Condition + Error in `translate.default()`: + ! Please set an engine. + +# lm execution + + Code + res <- fit_xy(hpc_basic, x = hpc[, num_pred], y = hpc$class, control = ctrl) + Condition + Error in `check_outcome()`: + ! For a regression model, the outcome should be , not a object. + +--- + + Code + res <- fit_xy(hpc_basic, x = hpc[, num_pred], y = as.character(hpc$class), + control = ctrl) + Condition + Error in `check_outcome()`: + ! For a regression model, the outcome should be , not a character vector. + +--- + + Code + res <- fit(hpc_basic, hpc_bad_form, data = hpc, control = ctrl) + Condition + Error in `check_outcome()`: + ! For a regression model, the outcome should be , not a object. + +--- + + Code + lm_form_catch <- fit(hpc_basic, hpc_bad_form, data = hpc, control = caught_ctrl) + Condition + Error in `check_outcome()`: + ! For a regression model, the outcome should be , not a object. + +# glm execution + + Code + res <- fit_xy(hpc_glm, x = hpc[, num_pred], y = hpc$class, control = ctrl) + Condition + Error in `check_outcome()`: + ! For a regression model, the outcome should be , not a object. + +--- + + Code + res <- fit(hpc_glm, hpc_bad_form, data = hpc, control = ctrl) + Condition + Error in `check_outcome()`: + ! For a regression model, the outcome should be , not a object. + +--- + + Code + lm_form_catch <- fit(hpc_glm, hpc_bad_form, data = hpc, control = caught_ctrl) + Condition + Error in `check_outcome()`: + ! For a regression model, the outcome should be , not a object. + +# newdata error trapping + + Code + predict(res_xy, newdata = hpc[1:3, num_pred]) + Condition + Error in `predict()`: + ! Please use `new_data` instead of `newdata`. + +# show engine + + Code + show_engines("linear_re") + Condition + Error in `show_engines()`: + ! No results found for model function "x". + +# lm can handle rankdeficient predictions + + Code + preds <- predict(fit(linear_reg(), y ~ ., data = data), new_data = data2) + Condition + Warning in `predict.lm()`: + prediction from rank-deficient fit; consider predict(., rankdeficient="NA") + +# check_args() works + + Code + spec <- set_mode(set_engine(linear_reg(mixture = -1), "lm"), "regression") + fit(spec, compounds ~ ., hpc) + Condition + Error in `fit()`: + ! `mixture` must be a number between 0 and 1 or `NULL`, not the number -1. + +--- + + Code + spec <- set_mode(set_engine(linear_reg(penalty = -1), "lm"), "regression") + fit(spec, compounds ~ ., hpc) + Condition + Error in `fit()`: + ! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1. 
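The two `check_args()` failures pin down the valid ranges: `penalty` must be non-negative and `mixture` must fall in [0, 1]. A minimal sketch of a specification that passes both checks:

``` r
library(parsnip)

# penalty >= 0 and 0 <= mixture <= 1 satisfy check_args()
enet_spec <- linear_reg(penalty = 0.01, mixture = 0.5) |>
  set_engine("glmnet")

# translate() shows the glmnet::glmnet() call that parsnip will build
translate(enet_spec)
```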
+ +# prevent using a Poisson family + + Code + fit(set_engine(linear_reg(penalty = 1), "glmnet", family = poisson), mpg ~ ., + data = mtcars) + Condition + Error in `linear_reg()`: + ! A Poisson family was requested for `linear_reg()`. Please use `poisson_reg()` and the engines in the poissonreg package. + +--- + + Code + fit(set_engine(linear_reg(penalty = 1), "glmnet", family = stats::poisson), + mpg ~ ., data = mtcars) + Condition + Error in `linear_reg()`: + ! A Poisson family was requested for `linear_reg()`. Please use `poisson_reg()` and the engines in the poissonreg package. + +--- + + Code + fit(set_engine(linear_reg(penalty = 1), "glmnet", family = stats::poisson()), + mpg ~ ., data = mtcars) + Condition + Error in `linear_reg()`: + ! A Poisson family was requested for `linear_reg()`. Please use `poisson_reg()` and the engines in the poissonreg package. + +--- + + Code + fit(set_engine(linear_reg(penalty = 1), "glmnet", family = "poisson"), mpg ~ ., + data = mtcars) + Condition + Error in `linear_reg()`: + ! A Poisson family was requested for `linear_reg()`. Please use `poisson_reg()` and the engines in the poissonreg package. + +# tunables + + Code + tunable(linear_reg()) + Output + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id + +--- + + Code + tunable(set_engine(linear_reg(), "brulee")) + Output + # A tibble: 8 x 5 + name call_info source component component_id + + 1 epochs model_spec linear_reg engine + 2 penalty model_spec linear_reg main + 3 mixture model_spec linear_reg main + 4 learn_rate model_spec linear_reg engine + 5 momentum model_spec linear_reg engine + 6 batch_size model_spec linear_reg engine + 7 stop_iter model_spec linear_reg engine + 8 rate_schedule model_spec linear_reg engine + +--- + + Code + tunable(set_engine(linear_reg(), "glmnet")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec linear_reg main + 2 mixture model_spec linear_reg main + +--- + + Code + tunable(set_engine(linear_reg(), "quantreg")) + Output + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id + +--- + + Code + tunable(set_engine(linear_reg(), "keras")) + Output + # A tibble: 1 x 5 + name call_info source component component_id + + 1 penalty model_spec linear_reg main + diff --git a/tests/testthat/_snaps/linear_reg_quantreg.md b/tests/testthat/_snaps/linear_reg_quantreg.md new file mode 100644 index 000000000..11fbd80e2 --- /dev/null +++ b/tests/testthat/_snaps/linear_reg_quantreg.md @@ -0,0 +1,9 @@ +# linear quantile regression via quantreg - multiple quantiles + + Code + ten_quant_pred <- predict(ten_quant, new_data = sac_test, quantile_levels = (0: + 9) / 9) + Condition + Error in `predict()`: + ! When the mode is "quantile regression", `quantile_levels` are specified by `set_mode()`. + diff --git a/tests/testthat/_snaps/logistic_reg.md b/tests/testthat/_snaps/logistic_reg.md new file mode 100644 index 000000000..9d9368afd --- /dev/null +++ b/tests/testthat/_snaps/logistic_reg.md @@ -0,0 +1,184 @@ +# updating + + Code + update(set_engine(logistic_reg(mixture = 0), "glmnet", nlambda = 10), mixture = tune(), + nlambda = tune()) + Output + Logistic Regression Model Specification (classification) + + Main Arguments: + mixture = tune() + + Engine-Specific Arguments: + nlambda = tune() + + Computational engine: glmnet + + +# bad input + + Code + logistic_reg(mode = "regression") + Condition + Error in `logistic_reg()`: + ! 
"regression" is not a known mode for model `logistic_reg()`. + +--- + + Code + translate(set_engine(logistic_reg(mixture = 0.5), engine = "LiblineaR")) + Condition + Error in `translate()`: + ! For the LiblineaR engine, `mixture` must be 0 or 1. + +--- + + Code + res <- fit(dplyr::mutate(mtcars, cyl = as.factor(cyl)), logistic_reg(), cyl ~ + mpg, data = .) + Condition + Error in `UseMethod()`: + ! no applicable method for 'fit' applied to an object of class "data.frame" + +# glm execution + + Code + res <- fit(lc_basic, funded_amnt ~ term, data = lending_club, control = ctrl) + Condition + Error in `check_outcome()`: + ! For a classification model, the outcome should be a , not an integer vector. + +--- + + Code + glm_form_catch <- fit(lc_basic, funded_amnt ~ term, data = lending_club, + control = caught_ctrl) + Condition + Error in `check_outcome()`: + ! For a classification model, the outcome should be a , not an integer vector. + +--- + + Code + glm_xy_catch <- fit_xy(lc_basic, control = caught_ctrl, x = lending_club[, + num_pred], y = lending_club$total_bal_il) + Condition + Error in `check_outcome()`: + ! For a classification model, the outcome should be a , not an integer vector. + +# liblinear execution + + Code + res <- fit(ll_basic, funded_amnt ~ term, data = lending_club, control = ctrl) + Condition + Error in `check_outcome()`: + ! For a classification model, the outcome should be a , not an integer vector. + +--- + + Code + glm_form_catch <- fit(ll_basic, funded_amnt ~ term, data = lending_club, + control = caught_ctrl) + Condition + Error in `check_outcome()`: + ! For a classification model, the outcome should be a , not an integer vector. + +--- + + Code + glm_xy_catch <- fit_xy(ll_basic, control = caught_ctrl, x = lending_club[, + num_pred], y = lending_club$total_bal_il) + Condition + Error in `check_outcome()`: + ! For a classification model, the outcome should be a , not an integer vector. + +# check_args() works + + Code + spec <- set_mode(set_engine(logistic_reg(mixture = -1), "glm"), + "classification") + fit(spec, Class ~ ., lending_club) + Condition + Error in `fit()`: + ! `mixture` must be a number between 0 and 1 or `NULL`, not the number -1. + +--- + + Code + spec <- set_mode(set_engine(logistic_reg(penalty = -1), "glm"), + "classification") + fit(spec, Class ~ ., lending_club) + Condition + Error in `fit()`: + ! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1. + +--- + + Code + spec <- set_mode(set_engine(logistic_reg(mixture = 0.5), "LiblineaR"), + "classification") + fit(spec, Class ~ ., lending_club) + Condition + Error in `fit()`: + x For the LiblineaR engine, mixture must be 0 or 1, not 0.5. + i Choose a pure ridge model with `mixture = 0` or a pure lasso model with `mixture = 1`. + ! The Liblinear engine does not support other values. + +--- + + Code + spec <- set_mode(set_engine(logistic_reg(penalty = 0), "LiblineaR"), + "classification") + fit(spec, Class ~ ., lending_club) + Condition + Error in `fit()`: + ! For the LiblineaR engine, `penalty` must be `> 0`, not 0. 
+ +# tunables + + Code + tunable(logistic_reg()) + Output + # A tibble: 0 x 5 + # i 5 variables: name , call_info , source , component , + # component_id + +--- + + Code + tunable(set_engine(logistic_reg(), "brulee")) + Output + # A tibble: 9 x 5 + name call_info source component component_id + + 1 epochs model_spec logistic_reg engine + 2 penalty model_spec logistic_reg main + 3 mixture model_spec logistic_reg main + 4 learn_rate model_spec logistic_reg engine + 5 momentum model_spec logistic_reg engine + 6 batch_size model_spec logistic_reg engine + 7 class_weights model_spec logistic_reg engine + 8 stop_iter model_spec logistic_reg engine + 9 rate_schedule model_spec logistic_reg engine + +--- + + Code + tunable(set_engine(logistic_reg(), "glmnet")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec logistic_reg main + 2 mixture model_spec logistic_reg main + +--- + + Code + tunable(set_engine(logistic_reg(), "keras")) + Output + # A tibble: 1 x 5 + name call_info source component component_id + + 1 penalty model_spec logistic_reg main + diff --git a/tests/testthat/_snaps/mars.md b/tests/testthat/_snaps/mars.md new file mode 100644 index 000000000..dc8775227 --- /dev/null +++ b/tests/testthat/_snaps/mars.md @@ -0,0 +1,68 @@ +# updating + + Code + update(expr1, num_terms = tune(), nk = tune()) + Output + MARS Model Specification (unknown mode) + + Main Arguments: + num_terms = tune() + + Engine-Specific Arguments: + nk = tune() + + Computational engine: earth + + +# bad input + + Code + translate(set_engine(mars(mode = "regression"))) + Condition + Error in `set_engine()`: + ! Missing engine. Possible mode/engine combinations are: classification {earth} and regression {earth}. + +--- + + Code + translate(set_engine(mars(), "wat?")) + Condition + Error in `set_engine()`: + x Engine "wat?" is not supported for `mars()` + i See `show_engines("mars")`. + +# submodel prediction + + Code + multi_predict(reg_fit, newdata = mtcars[1:4, -1], num_terms = 5) + Condition + Error in `multi_predict()`: + ! Please use `new_data` instead of `newdata`. + +# check_args() works + + Code + spec <- set_mode(set_engine(mars(prod_degree = 0), "earth"), "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `prod_degree` must be a whole number larger than or equal to 1 or `NULL`, not the number 0. + +--- + + Code + spec <- set_mode(set_engine(mars(num_terms = 0), "earth"), "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `num_terms` must be a whole number larger than or equal to 1 or `NULL`, not the number 0. + +--- + + Code + spec <- set_mode(set_engine(mars(prune_method = 2), "earth"), "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `prune_method` must be a single string or `NULL`, not the number 2. + diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md new file mode 100644 index 000000000..b1c6b05ef --- /dev/null +++ b/tests/testthat/_snaps/misc.md @@ -0,0 +1,257 @@ +# parsnip objects + + Code + predict(lm_idea, mtcars) + Condition + Error in `predict()`: + ! You must `fit()` your model specification (`?parsnip::model_spec()`) before you can use `predict()`. + +--- + + Code + multi_predict(lm_fit, mtcars) + Condition + Error in `multi_predict()`: + ! No `multi_predict()` method exists for objects with classes <_lm/model_fit>. + +--- + + Code + multi_predict(extract_fit_engine(mars_fit), mtcars) + Condition + Error in `multi_predict()`: + ! 
No `multi_predict()` method exists for objects with classes . + +# combine_words helper works + + Code + combine_words(1) + Output + 1 + +--- + + Code + combine_words(1:2) + Output + 1 and 2 + +--- + + Code + combine_words(1:3) + Output + 1, 2, and 3 + +--- + + Code + combine_words(1:4) + Output + 1, 2, 3, and 4 + +# model type functions message informatively with unknown implementation + + Code + set_mode(set_engine(bag_tree(), "rpart"), "regression") + Message + ! parsnip could not locate an implementation for `bag_tree` regression model specifications using the `rpart` engine. + i The parsnip extension package baguette implements support for this specification. + i Please install (if needed) and load to continue. + Output + Bagged Decision Tree Model Specification (regression) + + Main Arguments: + cost_complexity = 0 + min_n = 2 + + Computational engine: rpart + + +--- + + Code + set_mode(bag_tree(), "censored regression") + Message + ! parsnip could not locate an implementation for `bag_tree` censored regression model specifications. + i The parsnip extension package censored implements support for this specification. + i Please install (if needed) and load to continue. + Output + Bagged Decision Tree Model Specification (censored regression) + + Main Arguments: + cost_complexity = 0 + min_n = 2 + + Computational engine: rpart + + +--- + + Code + bag_tree() + Message + ! parsnip could not locate an implementation for `bag_tree` model specifications. + i The parsnip extension packages censored and baguette implement support for this specification. + i Please install (if needed) and load to continue. + Output + Bagged Decision Tree Model Specification (unknown mode) + + Main Arguments: + cost_complexity = 0 + min_n = 2 + + Computational engine: rpart + + +--- + + Code + set_engine(bag_tree(), "rpart") + Message + ! parsnip could not locate an implementation for `bag_tree` model specifications using the `rpart` engine. + i The parsnip extension packages censored and baguette implement support for this specification. + i Please install (if needed) and load to continue. + Output + Bagged Decision Tree Model Specification (unknown mode) + + Main Arguments: + cost_complexity = 0 + min_n = 2 + + Computational engine: rpart + + +# missing implementation checks prompt conservatively with old objects + + Code + bt + Message + ! parsnip could not locate an implementation for `bag_tree` model specifications. + i The parsnip extension packages censored and baguette implement support for this specification. + i Please install (if needed) and load to continue. + Output + Bagged Decision Tree Model Specification (regression) + + Main Arguments: + cost_complexity = 0 + min_n = 2 + + Computational engine: rpart + + +# set_engine works as a generic + + Code + set_engine(mtcars, "rpart") + Condition + Error in `set_engine()`: + ! `set_engine()` expected a model specification to be supplied to the `object` argument, but received a(n) `data.frame` object. + +# check_for_newdata points out correct context + + Code + fn(newdata = "boop!") + Condition + Error in `fn()`: + ! Please use `new_data` instead of `newdata`. + +# check_outcome works as expected + + Code + check_outcome(NULL, reg_spec) + Condition + Error: + ! `linear_reg()` was unable to find an outcome. + i Ensure that you have specified an outcome column and that it hasn't been removed in pre-processing. + +--- + + Code + check_outcome(tibble::new_tibble(list(), nrow = 10), reg_spec) + Condition + Error: + ! 
`linear_reg()` was unable to find an outcome. + i Ensure that you have specified an outcome column and that it hasn't been removed in pre-processing. + +--- + + Code + fit(reg_spec, ~mpg, mtcars) + Condition + Error: + ! `linear_reg()` was unable to find an outcome. + i Ensure that you have specified an outcome column and that it hasn't been removed in pre-processing. + +--- + + Code + fit_xy(reg_spec, data.frame(x = 1:5), y = NULL) + Condition + Error: + ! `linear_reg()` was unable to find an outcome. + i Ensure that you have specified an outcome column and that it hasn't been removed in pre-processing. + +--- + + Code + check_outcome(NULL, class_spec) + Condition + Error: + ! `logistic_reg()` was unable to find an outcome. + i Ensure that you have specified an outcome column and that it hasn't been removed in pre-processing. + +--- + + Code + check_outcome(tibble::new_tibble(list(), nrow = 10), class_spec) + Condition + Error: + ! `logistic_reg()` was unable to find an outcome. + i Ensure that you have specified an outcome column and that it hasn't been removed in pre-processing. + +--- + + Code + fit(class_spec, ~mpg, mtcars) + Condition + Error: + ! `logistic_reg()` was unable to find an outcome. + i Ensure that you have specified an outcome column and that it hasn't been removed in pre-processing. + +--- + + Code + check_outcome(1:2, cens_spec) + Condition + Error in `check_outcome()`: + ! For a censored regression model, the outcome should be a object, not an integer vector. + +# obtaining prediction columns + + Code + .get_prediction_column_names(1) + Condition + Error in `.get_prediction_column_names()`: + ! `x` should be an object with class or , not a number. + +--- + + Code + .get_prediction_column_names(unk_fit) + Condition + Error in `.get_prediction_column_names()`: + ! Prediction information could not be found for this `linear_reg()` with engine "lm" and mode "Depeche". Does a parsnip extension package need to be loaded? + +# register local models + + Code + translate(my_model(), "my_engine") + Output + my model Model Specification (regression) + + Computational engine: my_engine + + Model fit template: + my_model_fun(formula = missing_arg(), data = missing_arg()) + diff --git a/tests/testthat/_snaps/mlp.md b/tests/testthat/_snaps/mlp.md new file mode 100644 index 000000000..c8ff852db --- /dev/null +++ b/tests/testthat/_snaps/mlp.md @@ -0,0 +1,132 @@ +# updating + + Code + update(set_engine(mlp(mode = "classification", hidden_units = 2), "nnet", Hess = FALSE), + hidden_units = tune(), Hess = tune()) + Output + Single Layer Neural Network Model Specification (classification) + + Main Arguments: + hidden_units = tune() + + Engine-Specific Arguments: + Hess = tune() + + Computational engine: nnet + + +# bad input + + Code + mlp(mode = "time series") + Condition + Error in `mlp()`: + ! "time series" is not a known mode for model `mlp()`. + +--- + + Code + translate(set_engine(mlp(mode = "classification"), "wat?")) + Condition + Error in `set_engine()`: + x Engine "wat?" is not supported for `mlp()` + i See `show_engines("mlp")`. + +# check_args() works + + Code + spec <- set_mode(set_engine(mlp(penalty = -1), "keras"), "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1. + +--- + + Code + spec <- set_mode(set_engine(mlp(dropout = -1), "keras"), "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! 
`dropout` must be a number between 0 and 1 or `NULL`, not the number -1. + +--- + + Code + spec <- set_mode(set_engine(mlp(dropout = 1, penalty = 3), "keras"), + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! Both weight decay and dropout should not be specified. + +# tunables + + Code + tunable(set_engine(mlp(), "brulee")) + Output + # A tibble: 12 x 5 + name call_info source component component_id + + 1 epochs model_spec mlp main + 2 hidden_units model_spec mlp main + 3 activation model_spec mlp main + 4 penalty model_spec mlp main + 5 mixture model_spec mlp engine + 6 dropout model_spec mlp main + 7 learn_rate model_spec mlp main + 8 momentum model_spec mlp engine + 9 batch_size model_spec mlp engine + 10 class_weights model_spec mlp engine + 11 stop_iter model_spec mlp engine + 12 rate_schedule model_spec mlp engine + +--- + + Code + tunable(set_engine(mlp(), "brulee_two_layer")) + Output + # A tibble: 14 x 5 + name call_info source component component_id + + 1 epochs model_spec mlp main + 2 hidden_units model_spec mlp main + 3 hidden_units_2 model_spec mlp engine + 4 activation model_spec mlp main + 5 activation_2 model_spec mlp engine + 6 penalty model_spec mlp main + 7 mixture model_spec mlp engine + 8 dropout model_spec mlp main + 9 learn_rate model_spec mlp main + 10 momentum model_spec mlp engine + 11 batch_size model_spec mlp engine + 12 class_weights model_spec mlp engine + 13 stop_iter model_spec mlp engine + 14 rate_schedule model_spec mlp engine + +--- + + Code + tunable(set_engine(mlp(), "nnet")) + Output + # A tibble: 3 x 5 + name call_info source component component_id + + 1 hidden_units model_spec mlp main + 2 penalty model_spec mlp main + 3 epochs model_spec mlp main + +--- + + Code + tunable(set_engine(mlp(), "keras")) + Output + # A tibble: 5 x 5 + name call_info source component component_id + + 1 hidden_units model_spec mlp main + 2 penalty model_spec mlp main + 3 dropout model_spec mlp main + 4 epochs model_spec mlp main + 5 activation model_spec mlp main + diff --git a/tests/testthat/_snaps/mlp_keras.md b/tests/testthat/_snaps/mlp_keras.md new file mode 100644 index 000000000..c2fe8f026 --- /dev/null +++ b/tests/testthat/_snaps/mlp_keras.md @@ -0,0 +1,18 @@ +# keras execution, classification + + Code + res <- parsnip::fit(hpc_keras, class ~ novar, data = hpc, control = ctrl) + Condition + Error: + ! object 'novar' not found + +# all keras activation functions + + Code + parsnip::fit(set_engine(mlp(mode = "classification", hidden_units = 2, penalty = 0.01, + epochs = 2, activation = "invalid"), "keras", verbose = 0), Class ~ A + B, + data = modeldata::two_class_dat) + Condition + Error in `parsnip::keras_mlp()`: + ! `activation` should be one of: elu, exponential, gelu, hardsigmoid, linear, relu, selu, sigmoid, softmax, softplus, softsign, swish, and tanh, not "invalid". + diff --git a/tests/testthat/_snaps/mlp_nnet.md b/tests/testthat/_snaps/mlp_nnet.md new file mode 100644 index 000000000..2a9f5a173 --- /dev/null +++ b/tests/testthat/_snaps/mlp_nnet.md @@ -0,0 +1,8 @@ +# nnet execution, classification + + Code + res <- parsnip::fit(hpc_nnet, class ~ novar, data = hpc, control = ctrl) + Condition + Error: + ! 
object 'novar' not found + diff --git a/tests/testthat/_snaps/multinom_reg.md b/tests/testthat/_snaps/multinom_reg.md new file mode 100644 index 000000000..604c1adbc --- /dev/null +++ b/tests/testthat/_snaps/multinom_reg.md @@ -0,0 +1,121 @@ +# updating + + Code + update(set_engine(multinom_reg(mixture = 0), "glmnet", nlambda = 10), mixture = tune(), + nlambda = tune()) + Output + Multinomial Regression Model Specification (classification) + + Main Arguments: + mixture = tune() + + Engine-Specific Arguments: + nlambda = tune() + + Computational engine: glmnet + + +# bad input + + Code + multinom_reg(mode = "regression") + Condition + Error in `multinom_reg()`: + ! "regression" is not a known mode for model `multinom_reg()`. + +--- + + Code + translate(set_engine(multinom_reg(penalty = 0.1), "wat?")) + Condition + Error in `set_engine()`: + x Engine "wat?" is not supported for `multinom_reg()` + i See `show_engines("multinom_reg")`. + +--- + + Code + set_engine(multinom_reg(penalty = 0.1)) + Condition + Error in `set_engine()`: + ! Missing engine. Possible mode/engine combinations are: classification {glmnet, spark, keras, nnet, brulee}. + +# check_args() works + + Code + spec <- set_mode(set_engine(multinom_reg(mixture = -1), "keras"), + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `mixture` must be a number between 0 and 1 or `NULL`, not the number -1. + +--- + + Code + spec <- set_mode(set_engine(multinom_reg(penalty = -1), "keras"), + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1. + +# tunables + + Code + tunable(multinom_reg()) + Output + # A tibble: 1 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + +--- + + Code + tunable(set_engine(multinom_reg(), "brulee")) + Output + # A tibble: 9 x 5 + name call_info source component component_id + + 1 epochs model_spec multinom_reg engine + 2 penalty model_spec multinom_reg main + 3 mixture model_spec multinom_reg main + 4 learn_rate model_spec multinom_reg engine + 5 momentum model_spec multinom_reg engine + 6 batch_size model_spec multinom_reg engine + 7 class_weights model_spec multinom_reg engine + 8 stop_iter model_spec multinom_reg engine + 9 rate_schedule model_spec multinom_reg engine + +--- + + Code + tunable(set_engine(multinom_reg(), "nnet")) + Output + # A tibble: 1 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + +--- + + Code + tunable(set_engine(multinom_reg(), "glmnet")) + Output + # A tibble: 2 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + 2 mixture model_spec multinom_reg main + +--- + + Code + tunable(set_engine(multinom_reg(), "keras")) + Output + # A tibble: 1 x 5 + name call_info source component component_id + + 1 penalty model_spec multinom_reg main + diff --git a/tests/testthat/_snaps/nearest_neighbor_kknn.md b/tests/testthat/_snaps/nearest_neighbor_kknn.md new file mode 100644 index 000000000..22d617419 --- /dev/null +++ b/tests/testthat/_snaps/nearest_neighbor_kknn.md @@ -0,0 +1,26 @@ +# kknn execution + + Code + fit(hpc_basic, hpc_bad_form, data = hpc, control = ctrl) + Condition + Error: + ! object 'term' not found + +# argument checks for data dimensions + + Code + f_fit <- fit(spec, body_mass_g ~ ., data = penguins) + Condition + Warning: + ! 1000 samples were requested but there were 333 rows in the data. 
+ i 328 samples will be used. + +--- + + Code + xy_fit <- fit_xy(spec, x = penguins[, -6], y = penguins$body_mass_g) + Condition + Warning: + ! 1000 samples were requested but there were 333 rows in the data. + i 328 samples will be used. + diff --git a/tests/testthat/_snaps/nullmodel.md b/tests/testthat/_snaps/nullmodel.md new file mode 100644 index 000000000..ee456ea2b --- /dev/null +++ b/tests/testthat/_snaps/nullmodel.md @@ -0,0 +1,48 @@ +# bad input + + Code + translate(set_engine(null_model(mode = "regression"))) + Condition + Error in `set_engine()`: + ! Missing engine. Possible mode/engine combinations are: classification {parsnip} and regression {parsnip}. + +--- + + Code + translate(set_engine(null_model(), "wat?")) + Condition + Error in `set_engine()`: + x Engine "wat?" is not supported for `null_model()` + i See `show_engines("null_model")`. + +# nullmodel execution + + Code + res <- fit(set_engine(null_model(mode = "regression"), "parsnip"), hpc_bad_form, + data = hpc) + Condition + Error: + ! object 'term' not found + +# null_model printing + + Code + print(null_model(mode = "classification")) + Output + Null Model Specification (classification) + + Computational engine: parsnip + + +--- + + Code + print(translate(set_engine(null_model(mode = "classification"), "parsnip"))) + Output + Null Model Specification (classification) + + Computational engine: parsnip + + Model fit template: + parsnip::nullmodel(x = missing_arg(), y = missing_arg()) + diff --git a/tests/testthat/_snaps/predict_formats.md b/tests/testthat/_snaps/predict_formats.md new file mode 100644 index 000000000..44beda1c4 --- /dev/null +++ b/tests/testthat/_snaps/predict_formats.md @@ -0,0 +1,36 @@ +# predict(type = "prob") with level "class" (see #720) + + Code + predict(mod, type = "prob", new_data = x) + Condition + Error in `check_spec_levels()`: + ! The outcome variable `boop` has a level called "class". + i This value is reserved for parsnip's classification internals; please change the levels, perhaps with `forcats::fct_relevel()`. + +# non-factor classification + + Code + fit(set_engine(logistic_reg(), "glm"), class ~ ., data = dplyr::mutate(hpc, + class = class == "VF")) + Condition + Error in `check_outcome()`: + ! For a classification model, the outcome should be a , not a logical vector. + +--- + + Code + fit(set_engine(logistic_reg(), "glm"), class ~ ., data = dplyr::mutate(hpc, + class = ifelse(class == "VF", 1, 0))) + Condition + Error in `check_outcome()`: + ! For a classification model, the outcome should be a , not a double vector. + +--- + + Code + fit(set_engine(multinom_reg(), "glmnet"), class ~ ., data = dplyr::mutate(hpc, + class = as.character(class))) + Condition + Error in `check_outcome()`: + ! For a classification model, the outcome should be a , not a character vector. + diff --git a/tests/testthat/_snaps/rand_forest_ranger.md b/tests/testthat/_snaps/rand_forest_ranger.md new file mode 100644 index 000000000..aa5334b00 --- /dev/null +++ b/tests/testthat/_snaps/rand_forest_ranger.md @@ -0,0 +1,90 @@ +# ranger classification execution + + Code + res <- fit(lc_ranger, funded_amnt ~ Class + term, data = lending_club, control = ctrl) + Condition + Error in `check_outcome()`: + ! For a classification model, the outcome should be a , not an integer vector. + +# ranger classification probabilities + + Code + parsnip:::predict_classprob.model_fit(no_prob_model, new_data = lending_club[1: + 6, num_pred]) + Condition + Error in `predict()`: + ! 
`ranger` model does not appear to use class probabilities. + i Was the model fit with `probability = TRUE`? + +# ranger regression intervals + + Code + rgr_se <- predict(extract_fit_engine(xy_fit), data = head(ames_x, 3), type = "se")$ + se + Condition + Warning in `rInfJack()`: + Sample size <=20, no calibration performed. + Warning in `sqrt()`: + NaNs produced + +--- + + Code + parsnip_int <- predict(xy_fit, new_data = head(ames_x, 3), type = "conf_int", + std_error = TRUE, level = 0.93) + Condition + Warning in `rInfJack()`: + Sample size <=20, no calibration performed. + Warning in `sqrt()`: + NaNs produced + +# ranger classification intervals + + Code + rgr_se <- predict(extract_fit_engine(lc_fit), data = tail(lending_club), type = "se")$ + se + Condition + Warning in `rInfJack()`: + Sample size <=20, no calibration performed. + Warning in `rInfJack()`: + Sample size <=20, no calibration performed. + Warning in `sqrt()`: + NaNs produced + +--- + + Code + parsnip_int <- predict(lc_fit, new_data = tail(lending_club), type = "conf_int", + std_error = TRUE, level = 0.93) + Condition + Warning in `rInfJack()`: + Sample size <=20, no calibration performed. + Warning in `rInfJack()`: + Sample size <=20, no calibration performed. + Warning in `sqrt()`: + NaNs produced + +# argument checks for data dimensions + + Code + f_fit <- fit(spec, body_mass_g ~ ., data = penguins) + Condition + Warning: + ! 1000 columns were requested but there were 6 predictors in the data. + i 6 predictors will be used. + Warning: + ! 1000 samples were requested but there were 333 rows in the data. + i 333 samples will be used. + +--- + + Code + xy_fit <- fit_xy(spec, x = penguins[, -6], y = penguins$body_mass_g) + Condition + Warning: + ! 1000 columns were requested but there were 6 predictors in the data. + i 6 predictors will be used. + Warning: + ! 1000 samples were requested but there were 333 rows in the data. + i 333 samples will be used. + diff --git a/tests/testthat/_snaps/sparsevctrs.md b/tests/testthat/_snaps/sparsevctrs.md new file mode 100644 index 000000000..b7d04e160 --- /dev/null +++ b/tests/testthat/_snaps/sparsevctrs.md @@ -0,0 +1,145 @@ +# sparse tibble can be passed to `fit() - supported + + Code + xgb_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data) + Condition + Error in `fit()`: + ! Sparse data cannot be used with formula interface. Please use `fit_xy()` instead. + +# sparse tibble can be passed to `fit() - unsupported + + Code + lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ]) + Condition + Warning: + `data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse. + +# sparse matrix can be passed to `fit() - supported + + Code + xgb_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data) + Condition + Error in `fit()`: + ! Sparse data cannot be used with formula interface. Please use `fit_xy()` instead. + +# sparse matrix can be passed to `fit() - unsupported + + Code + lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ]) + Condition + Warning: + `data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse. + +# sparse tibble can be passed to `fit_xy() - unsupported + + Code + lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]) + Condition + Warning: + `x` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse. 
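The sparsevctrs snapshots distinguish engines that accept sparse input from those that densify with a warning or refuse outright. One consequence is that sparse matrices must go through `fit_xy()`, since the formula interface rejects them. A minimal sketch, using a small sparse matrix built from mtcars (the Matrix package is assumed):

``` r
library(parsnip)

# A sparse predictor matrix; column names are required by fit_xy()
x_sparse <- Matrix::Matrix(as.matrix(mtcars[, -1]), sparse = TRUE)

xgb_spec <- boost_tree(trees = 20) |>
  set_engine("xgboost") |>
  set_mode("regression")

# Sparse data is only supported via fit_xy(), not fit()
# xgb_fit <- fit_xy(xgb_spec, x = x_sparse, y = mtcars$mpg)
```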
+ +# sparse matrices can be passed to `fit_xy() - unsupported + + Code + lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]) + Condition + Error in `fit_xy()`: + ! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that. + +# sparse tibble can be passed to `predict() - unsupported + + Code + preds <- predict(lm_fit, sparse_mtcars) + Condition + Warning: + `x` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse. + +# sparse matrices can be passed to `predict() - unsupported + + Code + predict(lm_fit, sparse_mtcars) + Condition + Error in `predict()`: + ! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that. + +# sparse data work with xgboost engine + + Code + xgb_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data) + Condition + Error in `fit()`: + ! Sparse data cannot be used with formula interface. Please use `fit_xy()` instead. + +# to_sparse_data_frame() is used correctly + + Code + fit_xy(spec, x = mtcars[, -1], y = mtcars[, 1]) + Condition + Error in `to_sparse_data_frame()`: + ! x is not sparse + +--- + + Code + fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1]) + Condition + Error in `to_sparse_data_frame()`: + ! x is spare, and sparse is not allowed + +--- + + Code + fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1]) + Condition + Error in `to_sparse_data_frame()`: + ! x is spare, and sparse is allowed + +# maybe_sparse_matrix() is used correctly + + Code + fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1]) + Condition + Error in `maybe_sparse_matrix()`: + ! sparse vectors detected + +--- + + Code + fit_xy(spec, x = mtcars[, -1], y = mtcars[, 1]) + Condition + Error in `maybe_sparse_matrix()`: + ! no sparse vectors detected + +--- + + Code + fit_xy(spec, x = as.data.frame(mtcars)[, -1], y = as.data.frame(mtcars)[, 1]) + Condition + Error in `maybe_sparse_matrix()`: + ! no sparse vectors detected + +--- + + Code + fit_xy(spec, x = tibble::as_tibble(mtcars)[, -1], y = tibble::as_tibble(mtcars)[, + 1]) + Condition + Error in `maybe_sparse_matrix()`: + ! no sparse vectors detected + +# we don't run as.matrix() on sparse matrix for glmnet pred #1210 + + Code + predict(lm_fit, hotel_data) + Condition + Error in `predict.elnet()`: + ! data is sparse + +# fit() errors if sparse matrix has no colnames + + Code + fit(spec, avg_price_per_room ~ ., data = hotel_data) + Condition + Error in `fit()`: + ! `x` must have column names. + diff --git a/tests/testthat/_snaps/svm_linear.md b/tests/testthat/_snaps/svm_linear.md new file mode 100644 index 000000000..288664aa4 --- /dev/null +++ b/tests/testthat/_snaps/svm_linear.md @@ -0,0 +1,65 @@ +# updating + + Code + update(set_engine(svm_linear(mode = "regression", cost = 2), "kernlab", cross = 10), + cross = tune(), cost = tune()) + Output + Linear Support Vector Machine Model Specification (regression) + + Main Arguments: + cost = tune() + + Engine-Specific Arguments: + cross = tune() + + Computational engine: kernlab + + +# bad input + + Code + translate(set_engine(svm_linear(mode = "regression"), NULL)) + Condition + Error in `set_engine()`: + ! Missing engine. Possible mode/engine combinations are: classification {LiblineaR, kernlab} and regression {LiblineaR, kernlab}. + +--- + + Code + svm_linear(mode = "reallyunknown") + Condition + Error in `svm_linear()`: + ! "reallyunknown" is not a known mode for model `svm_linear()`. 
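The `svm_linear()` snapshots cover the two usual specification mistakes: an unregistered mode and a missing engine. `show_engines()` lists what is available; a minimal sketch:

``` r
library(parsnip)

# Registered mode/engine combinations for this model
show_engines("svm_linear")

svm_spec <- svm_linear(cost = 1) |>
  set_engine("LiblineaR") |>
  set_mode("classification")
```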
+ +--- + + Code + translate(set_engine(svm_linear(mode = "regression"), "LiblineaR", type = 3)) + Condition + Error in `translate()`: + ! The LiblineaR engine argument `type = 3` does not correspond to an SVM regression model. + +--- + + Code + translate(set_engine(svm_linear(mode = "classification"), "LiblineaR", type = 11)) + Condition + Error in `translate()`: + ! The LiblineaR engine argument of `type = 11` does not correspond to an SVM classification model. + +# linear svm classification prediction: LiblineaR + + Code + predict(cls_form, hpc_no_m[ind, -5], type = "prob") + Condition + Error in `predict()`: + ! No "prob" prediction method available for this model. `type` should be one of: "class" and "raw". + +--- + + Code + predict(cls_xy_form, hpc_no_m[ind, -5], type = "prob") + Condition + Error in `predict()`: + ! No "prob" prediction method available for this model. `type` should be one of: "class" and "raw". + diff --git a/tests/testthat/_snaps/svm_poly.md b/tests/testthat/_snaps/svm_poly.md new file mode 100644 index 000000000..b3c007f13 --- /dev/null +++ b/tests/testthat/_snaps/svm_poly.md @@ -0,0 +1,33 @@ +# updating + + Code + update(set_engine(svm_poly(mode = "regression", degree = 2), "kernlab", cross = 10), + degree = tune(), cross = tune()) + Output + Polynomial Support Vector Machine Model Specification (regression) + + Main Arguments: + degree = tune() + + Engine-Specific Arguments: + cross = tune() + + Computational engine: kernlab + + +# bad input + + Code + svm_poly(mode = "reallyunknown") + Condition + Error in `svm_poly()`: + ! "reallyunknown" is not a known mode for model `svm_poly()`. + +--- + + Code + translate(set_engine(svm_poly(), NULL)) + Condition + Error in `set_engine()`: + ! Missing engine. Possible mode/engine combinations are: classification {kernlab} and regression {kernlab}. + diff --git a/tests/testthat/_snaps/svm_rbf.md b/tests/testthat/_snaps/svm_rbf.md new file mode 100644 index 000000000..321bb5694 --- /dev/null +++ b/tests/testthat/_snaps/svm_rbf.md @@ -0,0 +1,53 @@ +# engine arguments + + Code + translate(kernlab_cv, "kernlab")$method$fit$args + Output + $x + missing_arg() + + $data + missing_arg() + + $cross + + expr: ^10 + env: empty + + $kernel + [1] "rbfdot" + + +# updating + + Code + update(set_engine(svm_rbf(mode = "regression", rbf_sigma = 0.3), "kernlab", + cross = 10), rbf_sigma = tune(), cross = tune()) + Output + Radial Basis Function Support Vector Machine Model Specification (regression) + + Main Arguments: + rbf_sigma = tune() + + Engine-Specific Arguments: + cross = tune() + + Computational engine: kernlab + + +# bad input + + Code + svm_rbf(mode = "reallyunknown") + Condition + Error in `svm_rbf()`: + ! "reallyunknown" is not a known mode for model `svm_rbf()`. + +--- + + Code + translate(set_engine(svm_rbf(mode = "regression"), NULL)) + Condition + Error in `set_engine()`: + ! Missing engine. Possible mode/engine combinations are: classification {kernlab, liquidSVM} and regression {kernlab, liquidSVM}. 
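The `svm_rbf()` entries close out the same pattern for the RBF kernel: unknown modes and missing engines are rejected up front. A minimal sketch of a valid kernlab specification, with `translate()` showing the generated fit template:

``` r
library(parsnip)

rbf_spec <- svm_rbf(cost = 1, rbf_sigma = 0.01) |>
  set_engine("kernlab") |>
  set_mode("classification")

# Prints the kernlab::ksvm() call that parsnip will evaluate at fit time
translate(rbf_spec)
```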
+ diff --git a/tests/testthat/test-rand_forest.R b/tests/testthat/test-rand_forest.R index 15d4619ac..a93de445b 100644 --- a/tests/testthat/test-rand_forest.R +++ b/tests/testthat/test-rand_forest.R @@ -1,4 +1,3 @@ - test_that('updating', { expect_snapshot( rand_forest(mode = "regression", mtry = 2) |> @@ -8,12 +7,19 @@ test_that('updating', { }) test_that('bad input', { - expect_snapshot(res <- - translate(rand_forest(mode = "classification") |> - set_engine(NULL)), - error = TRUE) + expect_snapshot( + res <- + translate( + rand_forest(mode = "classification") |> + set_engine(NULL) + ), + error = TRUE + ) expect_snapshot(error = TRUE, rand_forest(mode = "time series")) - expect_snapshot(error = TRUE, translate(rand_forest(mode = "classification") |> set_engine("wat?"))) + expect_snapshot( + error = TRUE, + translate(rand_forest(mode = "classification") |> set_engine("wat?")) + ) }) test_that("check_args() works", { From 20b7bd296e7ef5d305d169251cd051998a051e3c Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 22 Oct 2025 14:08:48 -0400 Subject: [PATCH 06/15] fix typo --- man/details_rand_forest_grf.Rd | 6 +++--- man/rmd/rand_forest_grf.Rmd | 6 +++--- man/rmd/rand_forest_grf.md | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/man/details_rand_forest_grf.Rd b/man/details_rand_forest_grf.Rd index 081cd05c8..85b46767d 100644 --- a/man/details_rand_forest_grf.Rd +++ b/man/details_rand_forest_grf.Rd @@ -28,7 +28,7 @@ the default value of \code{mtry} is \code{min(ceiling(sqrt(p) + 20), p)}. \subsection{Translation from parsnip to the original package (regression)}{ See -\href{\%22https://grf-labs.github.io/grf/reference/regression_forest.html}{\code{?regression_forest}} +\href{https://grf-labs.github.io/grf/reference/regression_forest.html}{\code{?regression_forest}} \if{html}{\out{
}}\preformatted{rand_forest( mtry = integer(1), @@ -60,7 +60,7 @@ See \subsection{Translation from parsnip to the original package (classification)}{ See -\href{\%22https://grf-labs.github.io/grf/reference/probability_forest.html}{\code{?probability_forest}} +\href{https://grf-labs.github.io/grf/reference/probability_forest.html}{\code{?probability_forest}} \if{html}{\out{
}}\preformatted{rand_forest( mtry = integer(1), @@ -92,7 +92,7 @@ See \subsection{Translation from parsnip to the original package (quantile regression)}{ See -\href{\%22https://grf-labs.github.io/grf/reference/quantile_forest.html}{\code{?quantile_forest}} +\href{https://grf-labs.github.io/grf/reference/quantile_forest.html}{\code{?quantile_forest}} When specifying \emph{any} quantile regression model, the user must specify the quantile levels \emph{a priori}. diff --git a/man/rmd/rand_forest_grf.Rmd b/man/rmd/rand_forest_grf.Rmd index 05f26ffe4..6422dd5f7 100644 --- a/man/rmd/rand_forest_grf.Rmd +++ b/man/rmd/rand_forest_grf.Rmd @@ -33,7 +33,7 @@ param$item ## Translation from parsnip to the original package (regression) -See [`?regression_forest`]("https://grf-labs.github.io/grf/reference/regression_forest.html) +See [`?regression_forest`](https://grf-labs.github.io/grf/reference/regression_forest.html) ```{r} #| label: grf-reg @@ -49,7 +49,7 @@ rand_forest( ## Translation from parsnip to the original package (classification) -See [`?probability_forest`]("https://grf-labs.github.io/grf/reference/probability_forest.html) +See [`?probability_forest`](https://grf-labs.github.io/grf/reference/probability_forest.html) ```{r} #| label: grf-cls @@ -65,7 +65,7 @@ rand_forest( ## Translation from parsnip to the original package (quantile regression) -See [`?quantile_forest`]("https://grf-labs.github.io/grf/reference/quantile_forest.html) +See [`?quantile_forest`](https://grf-labs.github.io/grf/reference/quantile_forest.html) When specifying _any_ quantile regression model, the user must specify the quantile levels _a priori_. diff --git a/man/rmd/rand_forest_grf.md b/man/rmd/rand_forest_grf.md index 307c3246f..5c00b9b90 100644 --- a/man/rmd/rand_forest_grf.md +++ b/man/rmd/rand_forest_grf.md @@ -19,7 +19,7 @@ This model has 3 tuning parameters: ## Translation from parsnip to the original package (regression) -See [`?regression_forest`]("https://grf-labs.github.io/grf/reference/regression_forest.html) +See [`?regression_forest`](https://grf-labs.github.io/grf/reference/regression_forest.html) ``` r @@ -52,7 +52,7 @@ rand_forest( ## Translation from parsnip to the original package (classification) -See [`?probability_forest`]("https://grf-labs.github.io/grf/reference/probability_forest.html) +See [`?probability_forest`](https://grf-labs.github.io/grf/reference/probability_forest.html) ``` r @@ -85,7 +85,7 @@ rand_forest( ## Translation from parsnip to the original package (quantile regression) -See [`?quantile_forest`]("https://grf-labs.github.io/grf/reference/quantile_forest.html) +See [`?quantile_forest`](https://grf-labs.github.io/grf/reference/quantile_forest.html) When specifying _any_ quantile regression model, the user must specify the quantile levels _a priori_. From 3facc66b7d392cbc770de3603a8430b590a7f652 Mon Sep 17 00:00:00 2001 From: topepo Date: Mon, 27 Oct 2025 17:04:01 -0400 Subject: [PATCH 07/15] air formatting --- R/fit.R | 139 +++++++++++++++++++++++++++++--------------------------- 1 file changed, 71 insertions(+), 68 deletions(-) diff --git a/R/fit.R b/R/fit.R index 4a62a1ff7..287a3223b 100644 --- a/R/fit.R +++ b/R/fit.R @@ -109,12 +109,13 @@ #' @export #' @export fit.model_spec fit.model_spec <- - function(object, - formula, - data, - case_weights = NULL, - control = control_parsnip(), - ... + function( + object, + formula, + data, + case_weights = NULL, + control = control_parsnip(), + ... 
) { if (object$mode == "unknown") { cli::cli_abort( @@ -135,7 +136,6 @@ fit.model_spec <- } check_formula(formula) - if (is_sparse_matrix(data)) { data <- sparsevctrs::coerce_to_sparse_tibble(data, rlang::caller_env(0)) } @@ -153,12 +153,14 @@ fit.model_spec <- eng_vals <- possible_engines(object) object$engine <- eng_vals[1] if (control$verbosity > 0) { - cli::cli_warn("Engine set to {.val {object$engine}}.") + cli::cli_warn("Engine set to {.val {object$engine}}.") } } if (all(c("x", "y") %in% names(dots))) { - cli::cli_abort("{.fn fit.model_spec} is for the formula methods. Use {.fn fit_xy} instead.") + cli::cli_abort( + "{.fn fit.model_spec} is for the formula methods. Use {.fn fit_xy} instead." + ) } cl <- match.call(expand.dots = TRUE) # Create an environment with the evaluated argument objects. This will be @@ -186,11 +188,12 @@ fit.model_spec <- fit_interface <- check_interface(eval_env$formula, eval_env$data, cl, object) - if (object$engine == "spark" && !inherits(eval_env$data, "tbl_spark")) + if (object$engine == "spark" && !inherits(eval_env$data, "tbl_spark")) { cli::cli_abort( - "spark objects can only be used with the formula interface to {.fn fit} + "spark objects can only be used with the formula interface to {.fn fit} with a spark data object." - ) + ) + } # populate `method` with the details for this model type object <- add_methods(object, engine = object$engine) @@ -208,30 +211,27 @@ fit.model_spec <- switch( interfaces, # homogeneous combinations: - formula_formula = - form_form( - object = object, - control = control, - env = eval_env - ), + formula_formula = form_form( + object = object, + control = control, + env = eval_env + ), # heterogenous combinations - formula_matrix = - form_xy( - object = object, - control = control, - env = eval_env, - target = object$method$fit$interface, - ... - ), - formula_data.frame = - form_xy( - object = object, - control = control, - env = eval_env, - target = object$method$fit$interface, - ... - ), + formula_matrix = form_xy( + object = object, + control = control, + env = eval_env, + target = object$method$fit$interface, + ... + ), + formula_data.frame = form_xy( + object = object, + control = control, + env = eval_env, + target = object$method$fit$interface, + ... + ), cli::cli_abort("{.val {interfaces}} is unknown.") ) @@ -239,7 +239,7 @@ fit.model_spec <- model_classes <- class(res$fit) class(res) <- c(paste0("_", model_classes[1]), "model_fit") res -} + } # ------------------------------------------------------------------------------ @@ -247,12 +247,13 @@ fit.model_spec <- #' @export #' @export fit_xy.model_spec fit_xy.model_spec <- - function(object, - x, - y, - case_weights = NULL, - control = control_parsnip(), - ... + function( + object, + x, + y, + case_weights = NULL, + control = control_parsnip(), + ... ) { if (object$mode == "unknown") { cli::cli_abort( @@ -329,32 +330,32 @@ fit_xy.model_spec <- switch( interfaces, # homogeneous combinations: - matrix_matrix = , data.frame_matrix = - xy_xy( - object = object, - env = eval_env, - control = control, - target = "matrix", - ... - ), - - data.frame_data.frame = , matrix_data.frame = - xy_xy( - object = object, - env = eval_env, - control = control, - target = "data.frame", - ... - ), + matrix_matrix = , + data.frame_matrix = xy_xy( + object = object, + env = eval_env, + control = control, + target = "matrix", + ... + ), + + data.frame_data.frame = , + matrix_data.frame = xy_xy( + object = object, + env = eval_env, + control = control, + target = "data.frame", + ... 
+ ), # heterogenous combinations - matrix_formula = , data.frame_formula = - xy_form( - object = object, - env = eval_env, - control = control, - ... - ), + matrix_formula = , + data.frame_formula = xy_form( + object = object, + env = eval_env, + control = control, + ... + ), cli::cli_abort("{.val {interfaces}} is unknown.") ) res$censor_probs <- reverse_km(object, eval_env) @@ -368,7 +369,9 @@ fit_xy.model_spec <- eval_mod <- function(e, capture = FALSE, catch = FALSE, envir = NULL, ...) { if (capture) { if (catch) { - junk <- capture.output(res <- try(eval_tidy(e, env = envir, ...), silent = TRUE)) + junk <- capture.output( + res <- try(eval_tidy(e, env = envir, ...), silent = TRUE) + ) } else { junk <- capture.output(res <- eval_tidy(e, env = envir, ...)) } @@ -391,13 +394,13 @@ check_interface <- function(formula, data, cl, model, call = caller_env()) { # Determine the `fit()` interface form_interface <- !is.null(formula) & !is.null(data) - if (form_interface) + if (form_interface) { return("formula") + } cli::cli_abort("Error when checking the interface.", call = call) } check_xy_interface <- function(x, y, cl, model, call = caller_env()) { - sparse_ok <- allow_sparse(model) sparse_x <- inherits(x, "dgCMatrix") if (!sparse_ok & sparse_x) { From 2808e7fcacd1e0747db360b9e7ea6ed45bc3e929 Mon Sep 17 00:00:00 2001 From: topepo Date: Mon, 27 Oct 2025 17:04:12 -0400 Subject: [PATCH 08/15] fix case weight entries --- R/rand_forest_data.R | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index 6d90081a1..9b695da97 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -727,7 +727,7 @@ set_fit( value = list( interface = "data.frame", data = c(x = "X", y = "Y", weights = "sample.weights"), - protect = c("x", "y", "weights"), + protect = c("X", "Y", "weights"), func = c(pkg = "grf", fun = "probability_forest"), defaults = list( num.threads = 1 @@ -753,8 +753,8 @@ set_fit( mode = "regression", value = list( interface = "data.frame", - data = c(x = "X", y = "Y", weights = "case.weights"), - protect = c("x", "y", "weights"), + data = c(x = "X", y = "Y", weights = "sample.weights"), + protect = c("X", "Y", "weights"), func = c(pkg = "grf", fun = "regression_forest"), defaults = list( num.threads = 1 @@ -860,8 +860,8 @@ set_fit( mode = "quantile regression", value = list( interface = "data.frame", - data = c(x = "X", y = "Y", weights = "case.weights"), - protect = c("x", "y", "weights"), + data = c(x = "X", y = "Y"), + protect = c("X", "Y"), func = c(pkg = "grf", fun = "quantile_forest"), defaults = list( num.threads = 1, From db4f2360e733a4ed09be1a1b1ede7c7da2a703e2 Mon Sep 17 00:00:00 2001 From: topepo Date: Mon, 27 Oct 2025 17:04:18 -0400 Subject: [PATCH 09/15] redoc --- DESCRIPTION | 3 ++- man/C5_rules.Rd | 2 +- man/add_rowindex.Rd | 2 +- man/augment.Rd | 2 +- man/bart.Rd | 2 +- man/boost_tree.Rd | 2 +- man/condense_control.Rd | 2 +- man/control_parsnip.Rd | 2 +- man/ctree_train.Rd | 2 +- man/decision_tree.Rd | 2 +- man/doc-tools.Rd | 2 +- man/dot-get_prediction_column_names.Rd | 2 +- man/dot-model_param_name_key.Rd | 2 +- man/extract-parsnip.Rd | 2 +- man/fit.Rd | 2 +- man/fit_control.Rd | 2 +- man/gen_additive_mod.Rd | 2 +- man/get_model_env.Rd | 2 +- man/glm_grouped.Rd | 2 +- man/has_multi_predict.Rd | 2 +- man/linear_reg.Rd | 2 +- man/logistic_reg.Rd | 2 +- man/mars.Rd | 2 +- man/max_mtry_formula.Rd | 2 +- man/min_cols.Rd | 2 +- man/mlp.Rd | 2 +- man/model_db.Rd | 2 +- man/model_fit.Rd | 2 +- man/multinom_reg.Rd | 
2 +- man/nearest_neighbor.Rd | 2 +- man/null_model.Rd | 2 +- man/nullmodel.Rd | 2 +- man/parsnip-package.Rd | 2 +- man/parsnip_update.Rd | 10 +++++----- man/predict.model_fit.Rd | 2 +- man/proportional_hazards.Rd | 2 +- man/rand_forest.Rd | 2 +- man/repair_call.Rd | 2 +- man/required_pkgs.model_spec.Rd | 2 +- man/rmd/discrim_linear_sparsediscrim.md | 2 +- man/rmd/discrim_quad_sparsediscrim.md | 2 +- man/rmd/rand_forest_grf.Rmd | 2 ++ man/rmd/rand_forest_grf.md | 13 +++++++------ man/rule_fit.Rd | 2 +- man/set_args.Rd | 2 +- man/set_engine.Rd | 2 +- man/set_new_model.Rd | 2 +- man/show_engines.Rd | 2 +- man/survival_reg.Rd | 2 +- man/svm_linear.Rd | 2 +- man/svm_poly.Rd | 2 +- man/svm_rbf.Rd | 2 +- man/tidy.nullmodel.Rd | 2 +- man/translate.Rd | 2 +- man/varying_args.Rd | 2 +- 55 files changed, 67 insertions(+), 63 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 9fe2a410f..2291795fa 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -59,6 +59,7 @@ Suggests: modeldata, nlme, prodlim, + quantreg, ranger (>= 0.12.0), remotes, rmarkdown, @@ -80,4 +81,4 @@ Config/testthat/edition: 3 Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.2 +RoxygenNote: 7.3.3 diff --git a/man/C5_rules.Rd b/man/C5_rules.Rd index 83f60ee69..cd5ce5da8 100644 --- a/man/C5_rules.Rd +++ b/man/C5_rules.Rd @@ -55,7 +55,7 @@ C5_rules(argument = !!value) }\if{html}{\out{
}} } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("C5_rules") C5_rules() diff --git a/man/add_rowindex.Rd b/man/add_rowindex.Rd index ae2d9e7f4..fa4a5e678 100644 --- a/man/add_rowindex.Rd +++ b/man/add_rowindex.Rd @@ -16,7 +16,7 @@ The same data frame with a column of 1-based integers named \code{.row}. Add a column of row numbers to a data frame } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} mtcars |> add_rowindex() \dontshow{\}) # examplesIf} } diff --git a/man/augment.Rd b/man/augment.Rd index 821c803e2..49ea151df 100644 --- a/man/augment.Rd +++ b/man/augment.Rd @@ -56,7 +56,7 @@ class \code{"quantile_pred"} and can be unnested using \code{\link[tidyr:unnest] } } \examples{ -\dontshow{if (!parsnip:::is_cran_check() & rlang::is_installed("modeldata")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check() & rlang::is_installed("modeldata")) withAutoprint(\{ # examplesIf} car_trn <- mtcars[11:32,] car_tst <- mtcars[ 1:10,] diff --git a/man/bart.Rd b/man/bart.Rd index da6cc7335..f7c84a6c3 100644 --- a/man/bart.Rd +++ b/man/bart.Rd @@ -75,7 +75,7 @@ bart(argument = !!value) }\if{html}{\out{
}} } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("bart") bart(mode = "regression", trees = 5) diff --git a/man/boost_tree.Rd b/man/boost_tree.Rd index a36a4de25..97f4f49f1 100644 --- a/man/boost_tree.Rd +++ b/man/boost_tree.Rd @@ -81,7 +81,7 @@ boost_tree(argument = !!value) }\if{html}{\out{}} } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("boost_tree") boost_tree(mode = "classification", trees = 20) diff --git a/man/condense_control.Rd b/man/condense_control.Rd index d347bcd3e..e006893ca 100644 --- a/man/condense_control.Rd +++ b/man/condense_control.Rd @@ -27,7 +27,7 @@ throughout the tidymodels packages. It is now assumed that each control function is either a subset or a superset of another control function. } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} ctrl <- control_parsnip(catch = TRUE) ctrl$allow_par <- TRUE str(ctrl) diff --git a/man/control_parsnip.Rd b/man/control_parsnip.Rd index a281331c8..b82f6e975 100644 --- a/man/control_parsnip.Rd +++ b/man/control_parsnip.Rd @@ -27,7 +27,7 @@ Pass options to the \code{\link[=fit.model_spec]{fit.model_spec()}} function to output and computations } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} control_parsnip(verbosity = 2L) \dontshow{\}) # examplesIf} } diff --git a/man/ctree_train.Rd b/man/ctree_train.Rd index dc0cf484b..5bc3c858b 100644 --- a/man/ctree_train.Rd +++ b/man/ctree_train.Rd @@ -74,7 +74,7 @@ These functions are slightly different APIs for \code{\link[partykit:ctree]{part arguments (as opposed to being specified in \code{\link[partykit:ctree_control]{partykit::ctree_control()}}). } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} if (rlang::is_installed(c("modeldata", "partykit"))) { data(bivariate, package = "modeldata") ctree_train(Class ~ ., data = bivariate_train) diff --git a/man/decision_tree.Rd b/man/decision_tree.Rd index fbdb31742..456a7d366 100644 --- a/man/decision_tree.Rd +++ b/man/decision_tree.Rd @@ -56,7 +56,7 @@ decision_tree(argument = !!value) }\if{html}{\out{}} } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("decision_tree") decision_tree(mode = "classification", tree_depth = 5) diff --git a/man/doc-tools.Rd b/man/doc-tools.Rd index 7257c6a3b..8f6ee6364 100644 --- a/man/doc-tools.Rd +++ b/man/doc-tools.Rd @@ -50,7 +50,7 @@ the References section). Most parsnip users will not need to use these functions or documentation. 
} \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} # See this file for step-by-step instructions. system.file("README-DOCS.md", package = "parsnip") diff --git a/man/dot-get_prediction_column_names.Rd b/man/dot-get_prediction_column_names.Rd index e316107ca..7add7c851 100644 --- a/man/dot-get_prediction_column_names.Rd +++ b/man/dot-get_prediction_column_names.Rd @@ -19,7 +19,7 @@ A list with elements \code{"estimate"} and \code{"probabilities"}. columns for the primary prediction types for a model. } \examples{ -\dontshow{if (!parsnip:::is_cran_check() & rlang::is_installed("modeldata")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check() & rlang::is_installed("modeldata")) withAutoprint(\{ # examplesIf} library(dplyr) library(modeldata) data("two_class_dat") diff --git a/man/dot-model_param_name_key.Rd b/man/dot-model_param_name_key.Rd index ea391d5b3..d91632bf9 100644 --- a/man/dot-model_param_name_key.Rd +++ b/man/dot-model_param_name_key.Rd @@ -22,7 +22,7 @@ tuning parameter names, the standardized parsnip parameter names, and the argument names to the underlying fit function for the engine. } \examples{ -\dontshow{if (!parsnip:::is_cran_check() & rlang::is_installed("dials")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check() & rlang::is_installed("dials")) withAutoprint(\{ # examplesIf} mod <- linear_reg(penalty = tune("regularization"), mixture = tune()) |> set_engine("glmnet") diff --git a/man/extract-parsnip.Rd b/man/extract-parsnip.Rd index 95355f55d..19155db57 100644 --- a/man/extract-parsnip.Rd +++ b/man/extract-parsnip.Rd @@ -70,7 +70,7 @@ or silently generating incorrect predictions. }\if{html}{\out{}} } \examples{ -\dontshow{if (!parsnip:::is_cran_check() & rlang::is_installed("dials")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check() & rlang::is_installed("dials")) withAutoprint(\{ # examplesIf} lm_spec <- linear_reg() |> set_engine("lm") lm_fit <- fit(lm_spec, mpg ~ ., data = mtcars) diff --git a/man/fit.Rd b/man/fit.Rd index 295312238..67bf23b54 100644 --- a/man/fit.Rd +++ b/man/fit.Rd @@ -110,7 +110,7 @@ Sparse data is supported, with the use of the \code{x} argument in \code{fit_xy( compatibility. } \examples{ -\dontshow{if (!parsnip:::is_cran_check() & rlang::is_installed("modeldata")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check() & rlang::is_installed("modeldata")) withAutoprint(\{ # examplesIf} # Although `glm()` only has a formula interface, different # methods for specifying the model can be used diff --git a/man/fit_control.Rd b/man/fit_control.Rd index aba2781d9..1b3f89fe7 100644 --- a/man/fit_control.Rd +++ b/man/fit_control.Rd @@ -32,7 +32,7 @@ output and computations \code{fit_control()} is deprecated in favor of \code{control_parsnip()}. 
} \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} fit_control(verbosity = 2L) \dontshow{\}) # examplesIf} } diff --git a/man/gen_additive_mod.Rd b/man/gen_additive_mod.Rd index 8829dcb3f..b038c7ced 100644 --- a/man/gen_additive_mod.Rd +++ b/man/gen_additive_mod.Rd @@ -54,7 +54,7 @@ gen_additive_mod(argument = !!value) }\if{html}{\out{}} } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("gen_additive_mod") gen_additive_mod() diff --git a/man/get_model_env.Rd b/man/get_model_env.Rd index 3b20d554d..d9b8c6458 100644 --- a/man/get_model_env.Rd +++ b/man/get_model_env.Rd @@ -29,7 +29,7 @@ These functions read and write to the environment where the package stores information about model specifications. } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} # Access the model data: current_code <- get_model_env() ls(envir = current_code) diff --git a/man/glm_grouped.Rd b/man/glm_grouped.Rd index e8697c209..3557cbfbc 100644 --- a/man/glm_grouped.Rd +++ b/man/glm_grouped.Rd @@ -36,7 +36,7 @@ each factor level so that the outcome can be given to the formula as "number of events" format for binomial data. } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} #---------------------------------------------------------------------------- # The same data set formatted three ways diff --git a/man/has_multi_predict.Rd b/man/has_multi_predict.Rd index 37279991a..5525f296e 100644 --- a/man/has_multi_predict.Rd +++ b/man/has_multi_predict.Rd @@ -44,7 +44,7 @@ returns the names of the arguments to \code{multi_predict()} for this model (if any). } \examples{ -\dontshow{if (!parsnip:::is_cran_check() & rlang::is_installed("kknn")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check() & rlang::is_installed("kknn")) withAutoprint(\{ # examplesIf} lm_model_idea <- linear_reg() |> set_engine("lm") has_multi_predict(lm_model_idea) lm_model_fit <- fit(lm_model_idea, mpg ~ ., data = mtcars) diff --git a/man/linear_reg.Rd b/man/linear_reg.Rd index e34f45652..a29cb157c 100644 --- a/man/linear_reg.Rd +++ b/man/linear_reg.Rd @@ -54,7 +54,7 @@ linear_reg(argument = !!value) }\if{html}{\out{}} } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("linear_reg") linear_reg() diff --git a/man/logistic_reg.Rd b/man/logistic_reg.Rd index b0e464906..e98ef99bc 100644 --- a/man/logistic_reg.Rd +++ b/man/logistic_reg.Rd @@ -67,7 +67,7 @@ This model fits a classification model for binary outcomes; for multiclass outcomes, see \code{\link[=multinom_reg]{multinom_reg()}}. 
} \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("logistic_reg") logistic_reg() diff --git a/man/mars.Rd b/man/mars.Rd index 2b893b066..748378f72 100644 --- a/man/mars.Rd +++ b/man/mars.Rd @@ -56,7 +56,7 @@ mars(argument = !!value) }\if{html}{\out{}} } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("mars") mars(mode = "regression", num_terms = 5) diff --git a/man/max_mtry_formula.Rd b/man/max_mtry_formula.Rd index 6dbe33e76..b89dd7c7a 100644 --- a/man/max_mtry_formula.Rd +++ b/man/max_mtry_formula.Rd @@ -24,7 +24,7 @@ This function potentially caps the value of \code{mtry} based on a formula and data set. This is a safe approach for survival and/or multivariate models. } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} # should be 9 max_mtry_formula(200, cbind(wt, mpg) ~ ., data = mtcars) \dontshow{\}) # examplesIf} diff --git a/man/min_cols.Rd b/man/min_cols.Rd index edb5d9d8a..14ca91b4e 100644 --- a/man/min_cols.Rd +++ b/man/min_cols.Rd @@ -31,7 +31,7 @@ fit. These functions check the possible range of the data and adjust them if needed (with a warning). } \examples{ -\dontshow{if (!parsnip:::is_cran_check() & rlang::is_installed("kknn") & rlang::is_installed("ranger")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check() & rlang::is_installed("kknn") & rlang::is_installed("ranger")) withAutoprint(\{ # examplesIf} nearest_neighbor(neighbors= 100) |> set_engine("kknn") |> set_mode("regression") |> diff --git a/man/mlp.Rd b/man/mlp.Rd index 842de7ab0..12e3a848c 100644 --- a/man/mlp.Rd +++ b/man/mlp.Rd @@ -71,7 +71,7 @@ mlp(argument = !!value) }\if{html}{\out{}} } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("mlp") mlp(mode = "classification", penalty = 0.01) diff --git a/man/model_db.Rd b/man/model_db.Rd index ec674b9ea..e658f8adf 100644 --- a/man/model_db.Rd +++ b/man/model_db.Rd @@ -12,7 +12,7 @@ This is used in the RStudio add-in and captures information about mode specifications in various R packages. } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} data(model_db) \dontshow{\}) # examplesIf} } diff --git a/man/model_fit.Rd b/man/model_fit.Rd index 5265cb7cd..350540ee2 100644 --- a/man/model_fit.Rd +++ b/man/model_fit.Rd @@ -40,7 +40,7 @@ This class and structure is the basis for how \pkg{parsnip} stores model objects after seeing the data and applying a model. } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} # Keep the `x` matrix if the data are not too big. 
spec_obj <- diff --git a/man/multinom_reg.Rd b/man/multinom_reg.Rd index 7f6314f26..eabf5ce52 100644 --- a/man/multinom_reg.Rd +++ b/man/multinom_reg.Rd @@ -66,7 +66,7 @@ This model fits a classification model for multiclass outcomes; for binary outcomes, see \code{\link[=logistic_reg]{logistic_reg()}}. } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("multinom_reg") multinom_reg() diff --git a/man/nearest_neighbor.Rd b/man/nearest_neighbor.Rd index 9aadaa6b9..f7ce1bc47 100644 --- a/man/nearest_neighbor.Rd +++ b/man/nearest_neighbor.Rd @@ -60,7 +60,7 @@ nearest_neighbor(argument = !!value) }\if{html}{\out{}} } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("nearest_neighbor") nearest_neighbor(neighbors = 11) diff --git a/man/null_model.Rd b/man/null_model.Rd index c12a03f1f..e89962bf3 100644 --- a/man/null_model.Rd +++ b/man/null_model.Rd @@ -64,7 +64,7 @@ call. For this type of model, the template of the fit calls are below: } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} null_model(mode = "regression") \dontshow{\}) # examplesIf} } diff --git a/man/nullmodel.Rd b/man/nullmodel.Rd index 7585422de..21d8e4c38 100644 --- a/man/nullmodel.Rd +++ b/man/nullmodel.Rd @@ -57,7 +57,7 @@ probabilities are requested, the percentage of the training set samples with the most prevalent class is returned. } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} outcome <- factor(sample(letters[1:2], size = 100, diff --git a/man/parsnip-package.Rd b/man/parsnip-package.Rd index 2e074dc3b..86ce93064 100644 --- a/man/parsnip-package.Rd +++ b/man/parsnip-package.Rd @@ -30,7 +30,7 @@ Authors: Other contributors: \itemize{ \item Emil Hvitfeldt \email{emil.hvitfeldt@posit.co} [contributor] - \item Posit Software, PBC (03wc8by49) [copyright holder, funder] + \item Posit Software, PBC (\href{https://ror.org/03wc8by49}{ROR}) [copyright holder, funder] } } diff --git a/man/parsnip_update.Rd b/man/parsnip_update.Rd index d6e652c82..0a7782b62 100644 --- a/man/parsnip_update.Rd +++ b/man/parsnip_update.Rd @@ -471,7 +471,7 @@ If parameters of a model specification need to be modified, \code{update()} can be used in lieu of recreating the object from scratch. 
} \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} # ------------------------------------------------------------------------------ @@ -480,7 +480,7 @@ model update(model, trees = 1) update(model, trees = 1, fresh = TRUE) \dontshow{\}) # examplesIf} -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} # ------------------------------------------------------------------------------ @@ -489,13 +489,13 @@ model update(model, committees = 1) update(model, committees = 1, fresh = TRUE) \dontshow{\}) # examplesIf} -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} model <- pls(predictor_prop = 0.1) model update(model, predictor_prop = 1) update(model, predictor_prop = 1, fresh = TRUE) \dontshow{\}) # examplesIf} -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} # ------------------------------------------------------------------------------ model <- rule_fit(trees = 10, min_n = 2) @@ -503,7 +503,7 @@ model update(model, trees = 1) update(model, trees = 1, fresh = TRUE) \dontshow{\}) # examplesIf} -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} model <- boost_tree(mtry = 10, min_n = 3) model update(model, mtry = 1) diff --git a/man/predict.model_fit.Rd b/man/predict.model_fit.Rd index 15bce19f8..f30fe9e56 100644 --- a/man/predict.model_fit.Rd +++ b/man/predict.model_fit.Rd @@ -138,7 +138,7 @@ produces. Set \code{increasing = FALSE} to suppress this behavior. } } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} library(dplyr) lm_model <- diff --git a/man/proportional_hazards.Rd b/man/proportional_hazards.Rd index de1b966ee..4283d7c46 100644 --- a/man/proportional_hazards.Rd +++ b/man/proportional_hazards.Rd @@ -65,7 +65,7 @@ survival model be specified via the formula interface. Proportional hazards models include the Cox model. 
} \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("proportional_hazards") proportional_hazards(mode = "censored regression") diff --git a/man/rand_forest.Rd b/man/rand_forest.Rd index 1ec974b0e..5ac829482 100644 --- a/man/rand_forest.Rd +++ b/man/rand_forest.Rd @@ -58,7 +58,7 @@ rand_forest(argument = !!value) }\if{html}{\out{}} } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("rand_forest") rand_forest(mode = "classification", trees = 2000) diff --git a/man/repair_call.Rd b/man/repair_call.Rd index 5c153b4f1..9b3e40ed1 100644 --- a/man/repair_call.Rd +++ b/man/repair_call.Rd @@ -28,7 +28,7 @@ other functions. For example, some arguments may still be quosures and the functions and methods. } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} fitted_model <- linear_reg() |> diff --git a/man/required_pkgs.model_spec.Rd b/man/required_pkgs.model_spec.Rd index 4460576d5..d29fee4e1 100644 --- a/man/required_pkgs.model_spec.Rd +++ b/man/required_pkgs.model_spec.Rd @@ -23,7 +23,7 @@ A character vector Determine required packages for a model } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} should_fail <- try(required_pkgs(linear_reg(engine = NULL)), silent = TRUE) should_fail diff --git a/man/rmd/discrim_linear_sparsediscrim.md b/man/rmd/discrim_linear_sparsediscrim.md index 3af4442ff..b7efb62f4 100644 --- a/man/rmd/discrim_linear_sparsediscrim.md +++ b/man/rmd/discrim_linear_sparsediscrim.md @@ -41,7 +41,7 @@ discrim_linear(regularization_method = character(0)) |> ## ## Model fit template: ## discrim::fit_regularized_linear(x = missing_arg(), y = missing_arg(), -## method = character(0)) +## regularization_method = character(0)) ``` ## Preprocessing requirements diff --git a/man/rmd/discrim_quad_sparsediscrim.md b/man/rmd/discrim_quad_sparsediscrim.md index b14b1bf63..2d663468c 100644 --- a/man/rmd/discrim_quad_sparsediscrim.md +++ b/man/rmd/discrim_quad_sparsediscrim.md @@ -40,7 +40,7 @@ discrim_quad(regularization_method = character(0)) |> ## ## Model fit template: ## discrim::fit_regularized_quad(x = missing_arg(), y = missing_arg(), -## method = character(0)) +## regularization_method = character(0)) ``` ## Preprocessing requirements diff --git a/man/rmd/rand_forest_grf.Rmd b/man/rmd/rand_forest_grf.Rmd index 6422dd5f7..a22382cb7 100644 --- a/man/rmd/rand_forest_grf.Rmd +++ b/man/rmd/rand_forest_grf.Rmd @@ -85,6 +85,8 @@ rand_forest( This method _does_ require qualitative predictors to be converted to a numeric format (manually). When using parsnip, a one-hot encoding is automatically used to do this. +If there are missing values in the predictors, the model will use case-wise deletion to remove them. + ## Other notes By default, parallel processing is turned off. When tuning, it is more efficient to parallelize over the resamples and tuning parameters. 
To parallelize the construction of the trees within the `grf` model, change the `num.threads` argument via [set_engine()]. diff --git a/man/rmd/rand_forest_grf.md b/man/rmd/rand_forest_grf.md index 5c00b9b90..0c0c0c340 100644 --- a/man/rmd/rand_forest_grf.md +++ b/man/rmd/rand_forest_grf.md @@ -44,7 +44,7 @@ rand_forest( ## Computational engine: grf ## ## Model fit template: -## grf::regression_forest(x = missing_arg(), y = missing_arg(), +## grf::regression_forest(X = missing_arg(), Y = missing_arg(), ## weights = missing_arg(), mtry = min_cols(~integer(1), x), ## num.trees = integer(1), min.node.size = min_rows(~integer(1), ## x), num.threads = 1) @@ -77,7 +77,7 @@ rand_forest( ## Computational engine: grf ## ## Model fit template: -## grf::probability_forest(x = missing_arg(), y = missing_arg(), +## grf::probability_forest(X = missing_arg(), Y = missing_arg(), ## weights = missing_arg(), mtry = min_cols(~integer(1), x), ## num.trees = integer(1), min.node.size = min_rows(~integer(1), ## x), num.threads = 1) @@ -112,10 +112,9 @@ rand_forest( ## Computational engine: grf ## ## Model fit template: -## grf::quantile_forest(x = missing_arg(), y = missing_arg(), weights = missing_arg(), -## mtry = min_cols(~integer(1), x), num.trees = integer(1), -## min.node.size = min_rows(~integer(1), x), num.threads = 1, -## quantiles = quantile_levels) +## grf::quantile_forest(X = missing_arg(), Y = missing_arg(), mtry = min_cols(~integer(1), +## x), num.trees = integer(1), min.node.size = min_rows(~integer(1), +## x), num.threads = 1, quantiles = quantile_levels) ``` ``` @@ -126,6 +125,8 @@ rand_forest( This method _does_ require qualitative predictors to be converted to a numeric format (manually). When using parsnip, a one-hot encoding is automatically used to do this. +If there are missing values in the predictors, the model will use case-wise deletion to remove them. + ## Other notes By default, parallel processing is turned off. When tuning, it is more efficient to parallelize over the resamples and tuning parameters. To parallelize the construction of the trees within the `grf` model, change the `num.threads` argument via [set_engine()]. diff --git a/man/rule_fit.Rd b/man/rule_fit.Rd index 5fc65b082..7cce87dbf 100644 --- a/man/rule_fit.Rd +++ b/man/rule_fit.Rd @@ -89,7 +89,7 @@ rule_fit(argument = !!value) }\if{html}{\out{}} } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("rule_fit") rule_fit() diff --git a/man/set_args.Rd b/man/set_args.Rd index 5a9bca054..692c48638 100644 --- a/man/set_args.Rd +++ b/man/set_args.Rd @@ -36,7 +36,7 @@ An updated model object. \code{set_args()} will replace existing values of the arguments. 
} \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} rand_forest() rand_forest() |> diff --git a/man/set_engine.Rd b/man/set_engine.Rd index 274703754..c1a838be2 100644 --- a/man/set_engine.Rd +++ b/man/set_engine.Rd @@ -54,7 +54,7 @@ argument to be passed directly to the engine fitting function, like } } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} # First, set main arguments using the standardized names logistic_reg(penalty = 0.01, mixture = 1/3) |> # Now specify how you want to fit the model with another argument diff --git a/man/set_new_model.Rd b/man/set_new_model.Rd index bb08bb812..d2aa1c95f 100644 --- a/man/set_new_model.Rd +++ b/man/set_new_model.Rd @@ -145,7 +145,7 @@ accommodate a sparse matrix representation for predictors during fitting and tuning. } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} # set_new_model("shallow_learning_model") # Show the information about a model: diff --git a/man/show_engines.Rd b/man/show_engines.Rd index 31656ed73..14f5fc735 100644 --- a/man/show_engines.Rd +++ b/man/show_engines.Rd @@ -19,7 +19,7 @@ the \pkg{poissonreg} package adds additional engines for the \code{\link[=poisso model and these are not available unless \pkg{poissonreg} is loaded. } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("linear_reg") \dontshow{\}) # examplesIf} } diff --git a/man/survival_reg.Rd b/man/survival_reg.Rd index 286b5f989..d1494e756 100644 --- a/man/survival_reg.Rd +++ b/man/survival_reg.Rd @@ -47,7 +47,7 @@ Since survival models typically involve censoring (and require the use of survival model be specified via the formula interface. 
} \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("survival_reg") survival_reg(mode = "censored regression", dist = "weibull") diff --git a/man/svm_linear.Rd b/man/svm_linear.Rd index 4a0ff97e4..4fab79222 100644 --- a/man/svm_linear.Rd +++ b/man/svm_linear.Rd @@ -50,7 +50,7 @@ svm_linear(argument = !!value) }\if{html}{\out{}} } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("svm_linear") svm_linear(mode = "classification") diff --git a/man/svm_poly.Rd b/man/svm_poly.Rd index 071dac91e..34fce7737 100644 --- a/man/svm_poly.Rd +++ b/man/svm_poly.Rd @@ -62,7 +62,7 @@ svm_poly(argument = !!value) }\if{html}{\out{}} } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("svm_poly") svm_poly(mode = "classification", degree = 1.2) diff --git a/man/svm_rbf.Rd b/man/svm_rbf.Rd index 3be0777b1..77ec54772 100644 --- a/man/svm_rbf.Rd +++ b/man/svm_rbf.Rd @@ -60,7 +60,7 @@ svm_rbf(argument = !!value) }\if{html}{\out{}} } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} show_engines("svm_rbf") svm_rbf(mode = "classification", rbf_sigma = 0.2) diff --git a/man/tidy.nullmodel.Rd b/man/tidy.nullmodel.Rd index 81d078004..482af6295 100644 --- a/man/tidy.nullmodel.Rd +++ b/man/tidy.nullmodel.Rd @@ -18,7 +18,7 @@ A tibble with column \code{value}. Return the results of \code{nullmodel} as a tibble } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} nullmodel(mtcars[,-1], mtcars$mpg) |> tidy() \dontshow{\}) # examplesIf} diff --git a/man/translate.Rd b/man/translate.Rd index 072486d6f..1a5ec154e 100644 --- a/man/translate.Rd +++ b/man/translate.Rd @@ -42,7 +42,7 @@ to understand what the underlying syntax would be. It should not be used to modify the model specification. } \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} lm_spec <- linear_reg(penalty = 0.01) # `penalty` is translated to `lambda` diff --git a/man/varying_args.Rd b/man/varying_args.Rd index c019d947f..f24852356 100644 --- a/man/varying_args.Rd +++ b/man/varying_args.Rd @@ -38,7 +38,7 @@ or a \code{recipe} is used. For a \code{model_spec}, the first class is used. Fo a \code{recipe}, the unique step \code{id} is used. 
} \examples{ -\dontshow{if (!parsnip:::is_cran_check()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (!parsnip:::is_cran_check()) withAutoprint(\{ # examplesIf} # List all possible varying args for the random forest spec rand_forest() |> varying_args() From 7ed3c0fe93b01a2a18bdda2be5e1cff4dbbf6cca Mon Sep 17 00:00:00 2001 From: topepo Date: Mon, 27 Oct 2025 17:07:21 -0400 Subject: [PATCH 10/15] fix test --- man/details_discrim_linear_sparsediscrim.Rd | 2 +- man/details_discrim_quad_sparsediscrim.Rd | 2 +- man/details_rand_forest_grf.Rd | 14 ++++++++------ tests/testthat/_snaps/registration.md | 2 +- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/man/details_discrim_linear_sparsediscrim.Rd b/man/details_discrim_linear_sparsediscrim.Rd index 3597eebcd..f1f79bacb 100644 --- a/man/details_discrim_linear_sparsediscrim.Rd +++ b/man/details_discrim_linear_sparsediscrim.Rd @@ -51,7 +51,7 @@ discrim_linear(regularization_method = character(0)) |> ## ## Model fit template: ## discrim::fit_regularized_linear(x = missing_arg(), y = missing_arg(), -## method = character(0)) +## regularization_method = character(0)) }\if{html}{\out{}} } diff --git a/man/details_discrim_quad_sparsediscrim.Rd b/man/details_discrim_quad_sparsediscrim.Rd index 489dd06f4..f101499bb 100644 --- a/man/details_discrim_quad_sparsediscrim.Rd +++ b/man/details_discrim_quad_sparsediscrim.Rd @@ -49,7 +49,7 @@ discrim_quad(regularization_method = character(0)) |> ## ## Model fit template: ## discrim::fit_regularized_quad(x = missing_arg(), y = missing_arg(), -## method = character(0)) +## regularization_method = character(0)) }\if{html}{\out{}} } diff --git a/man/details_rand_forest_grf.Rd b/man/details_rand_forest_grf.Rd index 85b46767d..67fc28804 100644 --- a/man/details_rand_forest_grf.Rd +++ b/man/details_rand_forest_grf.Rd @@ -50,7 +50,7 @@ See ## Computational engine: grf ## ## Model fit template: -## grf::regression_forest(x = missing_arg(), y = missing_arg(), +## grf::regression_forest(X = missing_arg(), Y = missing_arg(), ## weights = missing_arg(), mtry = min_cols(~integer(1), x), ## num.trees = integer(1), min.node.size = min_rows(~integer(1), ## x), num.threads = 1) @@ -82,7 +82,7 @@ See ## Computational engine: grf ## ## Model fit template: -## grf::probability_forest(x = missing_arg(), y = missing_arg(), +## grf::probability_forest(X = missing_arg(), Y = missing_arg(), ## weights = missing_arg(), mtry = min_cols(~integer(1), x), ## num.trees = integer(1), min.node.size = min_rows(~integer(1), ## x), num.threads = 1) @@ -117,10 +117,9 @@ the quantile levels \emph{a priori}. ## Computational engine: grf ## ## Model fit template: -## grf::quantile_forest(x = missing_arg(), y = missing_arg(), weights = missing_arg(), -## mtry = min_cols(~integer(1), x), num.trees = integer(1), -## min.node.size = min_rows(~integer(1), x), num.threads = 1, -## quantiles = quantile_levels) +## grf::quantile_forest(X = missing_arg(), Y = missing_arg(), mtry = min_cols(~integer(1), +## x), num.trees = integer(1), min.node.size = min_rows(~integer(1), +## x), num.threads = 1, quantiles = quantile_levels) ## Quantile levels: 0.25, 0.5, and 0.75. }\if{html}{\out{}} @@ -131,6 +130,9 @@ the quantile levels \emph{a priori}. This method \emph{does} require qualitative predictors to be converted to a numeric format (manually). When using parsnip, a one-hot encoding is automatically used to do this. 
+ +If there are missing values in the predictors, the model will use +case-wise deletion to remove them. } \subsection{Other notes}{ diff --git a/tests/testthat/_snaps/registration.md b/tests/testthat/_snaps/registration.md index 9e90b65b7..c3523e4f0 100644 --- a/tests/testthat/_snaps/registration.md +++ b/tests/testthat/_snaps/registration.md @@ -367,7 +367,7 @@ engines: classification: grf1, randomForest, ranger1, spark - quantile regression: grf1 + quantile regression: grf regression: grf1, randomForest, ranger1, spark 1The model can use case weights. From 85232657e1b853b41c23b5505ac72db082cacb7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= Date: Mon, 27 Oct 2025 20:30:11 -0400 Subject: [PATCH 11/15] remove quantreg from suggests --- DESCRIPTION | 1 - 1 file changed, 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 2291795fa..89f6e2b5a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -59,7 +59,6 @@ Suggests: modeldata, nlme, prodlim, - quantreg, ranger (>= 0.12.0), remotes, rmarkdown, From d37a51bde4b08a6dfd6389be49f70140be630dc6 Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 28 Oct 2025 08:50:30 -0400 Subject: [PATCH 12/15] check for quantreg install --- R/augment.R | 20 +++++++++++--------- man/augment.Rd | 20 +++++++++++--------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/R/augment.R b/R/augment.R index e87578266..563d0cef4 100644 --- a/R/augment.R +++ b/R/augment.R @@ -86,15 +86,17 @@ #' #' # ------------------------------------------------------------------------------ #' -#' # Quantile regression example -#' qr_form <- -#' linear_reg() |> -#' set_engine("quantreg") |> -#' set_mode("quantile regression", quantile_levels = c(0.25, 0.5, 0.75)) |> -#' fit(mpg ~ ., data = car_trn) -#' -#' augment(qr_form, car_tst) -#' augment(qr_form, car_tst[, -1]) +#' if (rlang::is_installed("quantreg")) { +#' # Quantile regression example +#' qr_form <- +#' linear_reg() |> +#' set_engine("quantreg") |> +#' set_mode("quantile regression", quantile_levels = c(0.25, 0.5, 0.75)) |> +#' fit(mpg ~ ., data = car_trn) +#' +#' augment(qr_form, car_tst) +#' augment(qr_form, car_tst[, -1]) +#' } #' augment.model_fit <- function(x, new_data, eval_time = NULL, ...) 
{ new_data <- tibble::new_tibble(new_data) diff --git a/man/augment.Rd b/man/augment.Rd index 49ea151df..7e070f10f 100644 --- a/man/augment.Rd +++ b/man/augment.Rd @@ -99,15 +99,17 @@ augment(cls_xy, cls_tst[, -3]) # ------------------------------------------------------------------------------ -# Quantile regression example -qr_form <- - linear_reg() |> - set_engine("quantreg") |> - set_mode("quantile regression", quantile_levels = c(0.25, 0.5, 0.75)) |> - fit(mpg ~ ., data = car_trn) - -augment(qr_form, car_tst) -augment(qr_form, car_tst[, -1]) +if (rlang::is_installed("quantreg")) { + # Quantile regression example + qr_form <- + linear_reg() |> + set_engine("quantreg") |> + set_mode("quantile regression", quantile_levels = c(0.25, 0.5, 0.75)) |> + fit(mpg ~ ., data = car_trn) + + augment(qr_form, car_tst) + augment(qr_form, car_tst[, -1]) +} \dontshow{\}) # examplesIf} } \references{ From 264b7f286a5d581331038790772bffc0c0f43260 Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 29 Oct 2025 21:27:54 -0400 Subject: [PATCH 13/15] fix errors --- vignettes/articles/Examples.Rmd | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vignettes/articles/Examples.Rmd b/vignettes/articles/Examples.Rmd index d03c30109..bdb979936 100644 --- a/vignettes/articles/Examples.Rmd +++ b/vignettes/articles/Examples.Rmd @@ -1437,8 +1437,8 @@ The following examples use consistent data sets throughout. For regression, we u ```{r} mr_cls_spec <- - multinom_reg(penalty = 0.1) |> - set_engine("brulee") + multinom_reg() |> + set_engine("brulee", learn_rate = 0.01, optimizer = "SGD") mr_cls_spec ``` @@ -1759,7 +1759,7 @@ The following examples use consistent data sets throughout. For regression, we u ```{r} bind_cols( predict(rf_cls_fit, data_test), - predict(rf_cls_fit, data_test, type = "prob") + predict(rf_cls_fit, data_test, type = "prob"), predict(rf_cls_fit, data_test, type = "conf_int") ) ``` @@ -1775,7 +1775,7 @@ The following examples use consistent data sets throughout. 
For regression, we u ```{r} grf_quant_spec <- - linear_reg() |> + rand_forest() |> set_engine("grf") |> set_mode("quantile regression", quantile_levels = (1:3) / 4) grf_quant_spec From 7a2cadf12706d5336884b85638376e48f822f8ff Mon Sep 17 00:00:00 2001 From: topepo Date: Thu, 30 Oct 2025 08:08:32 -0400 Subject: [PATCH 14/15] add grf to Config/Needs/website --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 89f6e2b5a..a745e819f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -71,7 +71,7 @@ Suggests: VignetteBuilder: knitr ByteCompile: true -Config/Needs/website: brulee, C50, dbarts, earth, glmnet, keras, kernlab, +Config/Needs/website: brulee, C50, dbarts, earth, glmnet, grf, keras, kernlab, kknn, LiblineaR, mgcv, nnet, parsnip, quantreg, randomForest, ranger, rpart, rstanarm, tidymodels/tidymodels, tidyverse/tidytemplate, rstudio/reticulate, xgboost, rmarkdown From c4677f45fe6bcbcd24c1c001f89ece04e2f9a773 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= Date: Fri, 31 Oct 2025 10:03:04 -0400 Subject: [PATCH 15/15] changes based on user comments --- R/rand_forest_data.R | 2 +- R/rand_forest_grf.R | 4 ++-- man/details_rand_forest_grf.Rd | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index 9b695da97..10ee4a5c0 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -148,7 +148,7 @@ grf_conf_int <- function( res <- vctrs::vec_cbind(res, std_err) } } else { - rlang::abort( + cli::cli_abort( "No confidence interval implementation for objects with class(es) {.cls {class(object$fit)[1]}}" ) diff --git a/R/rand_forest_grf.R b/R/rand_forest_grf.R index f7b14388e..355c048d1 100644 --- a/R/rand_forest_grf.R +++ b/R/rand_forest_grf.R @@ -1,7 +1,7 @@ -#' Random forests via grf +#' Generalized random forests via grf #' #' The \pkg{grf} fits models that create a large number of decision -#' trees, each independent of the others. The final prediction uses all +#' trees, each independent of the others. The final prediction uses #' predictions from the individual trees and combines them. #' #' @includeRmd man/rmd/rand_forest_grf.md details diff --git a/man/details_rand_forest_grf.Rd b/man/details_rand_forest_grf.Rd index 67fc28804..e0a228df9 100644 --- a/man/details_rand_forest_grf.Rd +++ b/man/details_rand_forest_grf.Rd @@ -2,10 +2,10 @@ % Please edit documentation in R/rand_forest_grf.R \name{details_rand_forest_grf} \alias{details_rand_forest_grf} -\title{Random forests via grf} +\title{Generalized random forests via grf} \description{ The \pkg{grf} fits models that create a large number of decision -trees, each independent of the others. The final prediction uses all +trees, each independent of the others. The final prediction uses predictions from the individual trees and combines them. } \details{
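The vignette hunks above exercise three prediction types against a fitted random forest. A minimal sketch of that workflow on this branch follows. It assumes the grf engine registered by these patches is available; the `data_trn`/`data_test` objects are stand-ins (the vignette's own data setup is outside this diff, and the hunk does not show which engine `rf_cls_fit` uses, so grf is assumed here).

```r
library(parsnip)
library(dplyr)

# Stand-in data split; the vignette's actual setup is not part of this diff.
data(two_class_dat, package = "modeldata")
data_trn  <- two_class_dat[-(1:10), ]
data_test <- two_class_dat[1:10, ]

# Probability forest for classification via the grf engine added above
rf_cls_fit <-
  rand_forest(trees = 1000) |>
  set_engine("grf") |>
  set_mode("classification") |>
  fit(Class ~ ., data = data_trn)

# Class, probability, and interval predictions, mirroring the vignette
# chunk corrected in PATCH 13
bind_cols(
  predict(rf_cls_fit, data_test),
  predict(rf_cls_fit, data_test, type = "prob"),
  predict(rf_cls_fit, data_test, type = "conf_int")
)
```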
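Quantile regression is the mode where usage differs most: as the engine docs patched earlier state, the quantile levels must be specified a priori, and they are passed through `set_mode()`. A sketch under the same assumptions (this branch plus the grf package installed); `mtcars` is used only for illustration.

```r
library(parsnip)

# Quantile levels are fixed up front and flow to grf::quantile_forest()'s
# `quantiles` argument via the fit template shown in the engine docs
grf_quant_spec <-
  rand_forest(trees = 500) |>
  set_engine("grf") |>
  set_mode("quantile regression", quantile_levels = (1:3) / 4)

grf_quant_fit <- fit(grf_quant_spec, mpg ~ ., data = mtcars)

# Predictions return a `quantile_pred` column holding all three levels
predict(grf_quant_fit, mtcars[1:5, ])
```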