From c1dbac0f4efebd2a49120cb0d8f5c49265a480a4 Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 29 Apr 2020 21:29:58 -0400 Subject: [PATCH 01/25] initial work on #290 --- NAMESPACE | 2 + R/aaa_models.R | 79 +++++++++++++++++++++++++++++++++ R/linear_reg_data.R | 102 ++++++++++++++++++++++++++++--------------- man/set_new_model.Rd | 6 +++ 4 files changed, 155 insertions(+), 34 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index ab16d61d5..8ad392628 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -108,6 +108,7 @@ export(fit_control) export(fit_xy) export(fit_xy.model_spec) export(get_dependency) +export(get_encoding) export(get_fit) export(get_from_env) export(get_model_env) @@ -143,6 +144,7 @@ export(rand_forest) export(rpart_train) export(set_args) export(set_dependency) +export(set_encoding) export(set_engine) export(set_env_val) export(set_fit) diff --git a/R/aaa_models.R b/R/aaa_models.R index 6bc1e1f72..739453308 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -302,6 +302,11 @@ check_interface_val <- function(x) { #' below, depending on context. #' @param pre,post Optional functions for pre- and post-processing of prediction #' results. +#' @param options A list of options for encodings. The current option is +#' `predictor_indicators` which tells `parsnip` whether the pre-processing +#' should make dummy variables from factor predictors. This only affects cases +#' when [fit.model_spec()] is used and the underlying model has the x/y +#' interface. #' @param ... Optional arguments that should be passed into the `args` slot for #' prediction objects. #' @keywords internal @@ -759,3 +764,77 @@ pred_value_template <- function(pre = NULL, post = NULL, func, ...) { list(pre = pre, post = post, func = func, args = list(...)) } +# ------------------------------------------------------------------------------ + +check_encodings <- function(x) { + if (!is.list(x)) { + rlang::abort("`values` should be a list.") + } + req_args <- list(predictor_indicators = TRUE) + + missing_args <- setdiff(names(req_args), names(x)) + if (length(missing_args) > 0) { + rlang::abort( + glue::glue( + "The values passed to `set_encoding()` are missing arguments: ", + paste0("'", missing_args, "'", collapse = ", ") + ) + ) + } + extra_args <- setdiff(names(x), names(req_args)) + if (length(extra_args) > 0) { + rlang::abort( + glue::glue( + "The values passed to `set_encoding()` had extra arguments: ", + paste0("'", extra_args, "'", collapse = ", ") + ) + ) + } + invisible(x) +} + +#' @export +#' @rdname set_new_model +#' @keywords internal +set_encoding <- function(model, mode, eng, options) { + check_model_exists(model) + check_eng_val(eng) + check_mode_val(mode) + check_encodings(options) + + keys <- tibble::tibble(model = model, engine = eng, mode = mode) + options <- tibble::as_tibble(options) + new_values <- dplyr::bind_cols(keys, options) + + + current_db_list <- ls(envir = get_model_env()) + nm <- paste(model, "encoding", sep = "_") + if (any(current_db_list == nm)) { + current <- get_from_env(nm) + dup_check <- + current %>% + dplyr::inner_join(new_values, by = c("model", "engine", "mode", "predictor_indicators")) + if (nrow(dup_check)) { + rlang::abort(glue::glue("Engine '{eng}' and mode '{mode}' already have defined encodings.")) + } + + } else { + current <- NULL + } + + db_values <- dplyr::bind_rows(current, new_values) + set_env_val(nm, db_values) + + invisible(NULL) +} + + +#' @rdname set_new_model +#' @keywords internal +#' @export +get_encoding <- function(model) { + check_model_exists(model) + nm <- paste0(model, "_encoding") + rlang::env_get(get_model_env(), nm) +} + diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index 9d8fa143a..7418e98df 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -19,6 +19,13 @@ set_fit( ) ) +set_encoding( + model = "linear_reg", + eng = "lm", + mode = "regression", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "linear_reg", eng = "lm", @@ -102,6 +109,25 @@ set_pred( set_model_engine("linear_reg", "regression", "glmnet") set_dependency("linear_reg", "glmnet", "glmnet") +set_fit( + model = "linear_reg", + eng = "glmnet", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y", "weights"), + func = c(pkg = "glmnet", fun = "glmnet"), + defaults = list(family = "gaussian") + ) +) + +set_encoding( + model = "linear_reg", + eng = "glmnet", + mode = "regression", + options = list(predictor_indicators = TRUE) +) + set_model_arg( model = "linear_reg", eng = "glmnet", @@ -120,18 +146,6 @@ set_model_arg( has_submodel = FALSE ) -set_fit( - model = "linear_reg", - eng = "glmnet", - mode = "regression", - value = list( - interface = "matrix", - protect = c("x", "y", "weights"), - func = c(pkg = "glmnet", fun = "glmnet"), - defaults = list(family = "gaussian") - ) -) - set_pred( model = "linear_reg", eng = "glmnet", @@ -183,6 +197,13 @@ set_fit( ) ) +set_encoding( + model = "linear_reg", + eng = "stan", + mode = "regression", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "linear_reg", eng = "stan", @@ -287,6 +308,25 @@ set_pred( set_model_engine("linear_reg", "regression", "spark") set_dependency("linear_reg", "spark", "sparklyr") +set_fit( + model = "linear_reg", + eng = "spark", + mode = "regression", + value = list( + interface = "formula", + protect = c("x", "formula", "weight_col"), + func = c(pkg = "sparklyr", fun = "ml_linear_regression"), + defaults = list() + ) +) + +set_encoding( + model = "linear_reg", + eng = "spark", + mode = "regression", + options = list(predictor_indicators = TRUE) +) + set_model_arg( model = "linear_reg", eng = "spark", @@ -305,19 +345,6 @@ set_model_arg( has_submodel = FALSE ) - -set_fit( - model = "linear_reg", - eng = "spark", - mode = "regression", - value = list( - interface = "formula", - protect = c("x", "formula", "weight_col"), - func = c(pkg = "sparklyr", fun = "ml_linear_regression"), - defaults = list() - ) -) - set_pred( model = "linear_reg", eng = "spark", @@ -342,15 +369,6 @@ set_model_engine("linear_reg", "regression", "keras") set_dependency("linear_reg", "keras", "keras") set_dependency("linear_reg", "keras", "magrittr") -set_model_arg( - model = "linear_reg", - eng = "keras", - parsnip = "penalty", - original = "penalty", - func = list(pkg = "dials", fun = "penalty"), - has_submodel = FALSE -) - set_fit( model = "linear_reg", eng = "keras", @@ -363,6 +381,22 @@ set_fit( ) ) +set_encoding( + model = "linear_reg", + eng = "keras", + mode = "regression", + options = list(predictor_indicators = TRUE) +) + +set_model_arg( + model = "linear_reg", + eng = "keras", + parsnip = "penalty", + original = "penalty", + func = list(pkg = "dials", fun = "penalty"), + has_submodel = FALSE +) + set_pred( model = "linear_reg", eng = "keras", diff --git a/man/set_new_model.Rd b/man/set_new_model.Rd index 385c3f83b..fdb431e0d 100644 --- a/man/set_new_model.Rd +++ b/man/set_new_model.Rd @@ -13,6 +13,8 @@ \alias{get_pred_type} \alias{show_model_info} \alias{pred_value_template} +\alias{set_encoding} +\alias{get_encoding} \title{Tools to Register Models} \usage{ set_new_model(model) @@ -38,6 +40,10 @@ get_pred_type(model, type) show_model_info(model) pred_value_template(pre = NULL, post = NULL, func, ...) + +set_encoding(model, mode, eng, options) + +get_encoding(model) } \arguments{ \item{model}{A single character string for the model type (e.g. From eeac337a6919dcd35c5fb635bfd94b64cfb49b17 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Mon, 18 May 2020 12:56:30 -0600 Subject: [PATCH 02/25] Fix more tests for tidyr and tibble --- tests/testthat/test_logistic_reg_glmnet.R | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test_logistic_reg_glmnet.R b/tests/testthat/test_logistic_reg_glmnet.R index 47b280de4..77cf2c06b 100644 --- a/tests/testthat/test_logistic_reg_glmnet.R +++ b/tests/testthat/test_logistic_reg_glmnet.R @@ -261,7 +261,8 @@ test_that('glmnet probabilities, one lambda', { predict(res_form, lending_club[1:7, c("funded_amnt", "int_rate")], type = "prob") ) - one_row <- predict(res_form, lending_club[1, c("funded_amnt", "int_rate")], type = "prob") + one_row <- predict(res_form, lending_club[1, c("funded_amnt", "int_rate")], type = "prob") %>% + mutate_all(set_names, 1) expect_equivalent(form_pred[1,], one_row) }) @@ -358,7 +359,8 @@ test_that('glmnet probabilities, no lambda', { expect_equal( mult_pred, - multi_predict(xy_fit, lending_club[1:7, num_pred], type = "prob") %>% unnest() + multi_predict(xy_fit, lending_club[1:7, num_pred], type = "prob") %>% + unnest(cols = c(.pred)) ) res_form <- fit( From 16a0c85353b6707e9c5a465f500da29db80255a5 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Mon, 18 May 2020 12:57:43 -0600 Subject: [PATCH 03/25] Fine-tune documentation --- R/aaa_models.R | 10 +++++----- man/set_new_model.Rd | 6 ++++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/R/aaa_models.R b/R/aaa_models.R index 739453308..46f0bd497 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -302,11 +302,11 @@ check_interface_val <- function(x) { #' below, depending on context. #' @param pre,post Optional functions for pre- and post-processing of prediction #' results. -#' @param options A list of options for encodings. The current option is -#' `predictor_indicators` which tells `parsnip` whether the pre-processing -#' should make dummy variables from factor predictors. This only affects cases -#' when [fit.model_spec()] is used and the underlying model has the x/y -#' interface. +#' @param options A list of options for engine-specific encodings. Currently, +#' the option implemented is `predictor_indicators` which tells `parsnip` +#' whether the pre-processing should make indicator/dummy variables from factor +#' predictors. This only affects cases when [fit.model_spec()] is used and the +#' underlying model has an x/y interface. #' @param ... Optional arguments that should be passed into the `args` slot for #' prediction objects. #' @keywords internal diff --git a/man/set_new_model.Rd b/man/set_new_model.Rd index fdb431e0d..8d9642cf0 100644 --- a/man/set_new_model.Rd +++ b/man/set_new_model.Rd @@ -85,6 +85,12 @@ results.} \item{...}{Optional arguments that should be passed into the \code{args} slot for prediction objects.} +\item{options}{A list of options for engine-specific encodings. Currently, +the option implemented is \code{predictor_indicators} which tells \code{parsnip} +whether the pre-processing should make indicator/dummy variables from factor +predictors. This only affects cases when \code{\link[=fit.model_spec]{fit.model_spec()}} is used and the +underlying model has an x/y interface.} + \item{arg}{A single character string for the model argument name.} \item{fit_obj}{A list with elements \code{interface}, \code{protect}, From 99137acfd7e5d552c70c14d548ea9b4f0e7b522c Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Mon, 18 May 2020 13:26:06 -0600 Subject: [PATCH 04/25] Engine encoding for logistic_reg() --- R/logistic_reg_data.R | 60 +++++++++++++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 13 deletions(-) diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index ec5228b85..e6ab37155 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -19,6 +19,13 @@ set_fit( ) ) +set_encoding( + model = "logistic_reg", + eng = "glm", + mode = "classification", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "logistic_reg", eng = "glm", @@ -121,6 +128,25 @@ set_pred( set_model_engine("logistic_reg", "classification", "glmnet") set_dependency("logistic_reg", "glmnet", "glmnet") +set_fit( + model = "logistic_reg", + eng = "glmnet", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y", "weights"), + func = c(pkg = "glmnet", fun = "glmnet"), + defaults = list(family = "binomial") + ) +) + +set_encoding( + model = "logistic_reg", + eng = "glmnet", + mode = "classification", + options = list(predictor_indicators = TRUE) +) + set_model_arg( model = "logistic_reg", eng = "glmnet", @@ -139,19 +165,6 @@ set_model_arg( has_submodel = FALSE ) -set_fit( - model = "logistic_reg", - eng = "glmnet", - mode = "classification", - value = list( - interface = "matrix", - protect = c("x", "y", "weights"), - func = c(pkg = "glmnet", fun = "glmnet"), - defaults = list(family = "binomial") - ) -) - - set_pred( model = "logistic_reg", eng = "glmnet", @@ -245,6 +258,13 @@ set_fit( ) ) +set_encoding( + model = "logistic_reg", + eng = "spark", + mode = "classification", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "logistic_reg", eng = "spark", @@ -306,6 +326,13 @@ set_fit( ) ) +set_encoding( + model = "logistic_reg", + eng = "keras", + mode = "classification", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "logistic_reg", eng = "keras", @@ -363,6 +390,13 @@ set_fit( ) ) +set_encoding( + model = "logistic_reg", + eng = "stan", + mode = "classification", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "logistic_reg", eng = "stan", From 27246a6cc938c5a23dabb6b382df3aa8f26bcd7c Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Mon, 18 May 2020 15:09:43 -0600 Subject: [PATCH 05/25] Do not want inner names --- tests/testthat/test_logistic_reg_glmnet.R | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/testthat/test_logistic_reg_glmnet.R b/tests/testthat/test_logistic_reg_glmnet.R index 77cf2c06b..f03cedc0a 100644 --- a/tests/testthat/test_logistic_reg_glmnet.R +++ b/tests/testthat/test_logistic_reg_glmnet.R @@ -261,8 +261,7 @@ test_that('glmnet probabilities, one lambda', { predict(res_form, lending_club[1:7, c("funded_amnt", "int_rate")], type = "prob") ) - one_row <- predict(res_form, lending_club[1, c("funded_amnt", "int_rate")], type = "prob") %>% - mutate_all(set_names, 1) + one_row <- predict(res_form, lending_club[1, c("funded_amnt", "int_rate")], type = "prob") expect_equivalent(form_pred[1,], one_row) }) From 972a7964ba3a0c711f694cf53ee00c6e8c68a84f Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Mon, 18 May 2020 17:11:28 -0600 Subject: [PATCH 06/25] Indicator variables for MARS --- R/mars_data.R | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/R/mars_data.R b/R/mars_data.R index 29d8c97f7..be312d7de 100644 --- a/R/mars_data.R +++ b/R/mars_data.R @@ -47,6 +47,13 @@ set_fit( ) ) +set_encoding( + model = "mars", + eng = "earth", + mode = "regression", + options = list(predictor_indicators = TRUE) +) + set_fit( model = "mars", eng = "earth", @@ -59,6 +66,13 @@ set_fit( ) ) +set_encoding( + model = "mars", + eng = "earth", + mode = "classification", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "mars", eng = "earth", From cd4ed7731d27a1a03f1556697f00e2aafd589701 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Mon, 18 May 2020 17:13:00 -0600 Subject: [PATCH 07/25] Look up predictor indicator; use in convert_form_to_xy_fit() --- R/convert_data.R | 2 +- R/fit_helpers.R | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/R/convert_data.R b/R/convert_data.R index 6496a1ce9..fe02afce4 100644 --- a/R/convert_data.R +++ b/R/convert_data.R @@ -15,7 +15,7 @@ #' @importFrom stats .checkMFClasses .getXlevels delete.response #' @importFrom stats model.offset model.weights na.omit na.pass -convert_form_to_xy_fit <-function( +convert_form_to_xy_fit <- function( formula, data, ..., diff --git a/R/fit_helpers.R b/R/fit_helpers.R index 9836e8193..18fbfc1b5 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -131,12 +131,16 @@ xy_xy <- function(object, env, control, target = "none", ...) { form_xy <- function(object, control, env, target = "none", ...) { + indicators <- get_encoding(class(object)[1]) %>% + dplyr::filter(mode == object$mode) %>% + dplyr::pull(predictor_indicators) + data_obj <- convert_form_to_xy_fit( formula = env$formula, data = env$data, ..., - composition = target - # indicators + composition = target, + indicators = indicators ) env$x <- data_obj$x env$y <- data_obj$y From 64bbde0a9b9dfdfc95b60cb8d58f03d5ddd39ade Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Mon, 18 May 2020 17:22:53 -0600 Subject: [PATCH 08/25] Test indicators = FALSE compared to a model that does not create indicator variables --- tests/testthat/test_convert_data.R | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/testthat/test_convert_data.R b/tests/testthat/test_convert_data.R index 51cb4e221..c58364be2 100644 --- a/tests/testthat/test_convert_data.R +++ b/tests/testthat/test_convert_data.R @@ -168,6 +168,21 @@ test_that("numeric y and mixed x", { ) }) +test_that("mixed x, no dummies, compare to a model that does not create dummies", { + expected <- rpart::rpart(rate ~ ., data = Puromycin) + data_classes <- attr(expected$terms, "dataClasses")[2:3] + observed <- parsnip:::convert_form_to_xy_fit(rate ~ ., data = Puromycin, indicators = FALSE) + expect_equal(names(data_classes), names(observed$x)) + expect_equal(unname(data_classes), c("numeric", "factor")) + expect_s3_class(observed$x$state, "factor") + expect_equivalent(Puromycin$rate, observed$y) + expect_equal(expected$terms, observed$terms) + + expect_null(observed$weights) + expect_null(observed$offset) +}) + + test_that("numeric y and mixed x, omit missing data", { expected <- lm(rate ~ ., data = Puromycin_miss, x = TRUE, y = TRUE) observed <- parsnip:::convert_form_to_xy_fit(rate ~ ., data = Puromycin_miss) From e02573f9cc4e64a6fc00950ebac477faf8e1cdfe Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Mon, 18 May 2020 17:47:49 -0600 Subject: [PATCH 09/25] Set predictor indicators for xgboost (TRUE) and C5.0 (FALSE) --- R/boost_tree_data.R | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/R/boost_tree_data.R b/R/boost_tree_data.R index 24954bd4d..580511d48 100644 --- a/R/boost_tree_data.R +++ b/R/boost_tree_data.R @@ -87,6 +87,13 @@ set_fit( ) ) +set_encoding( + model = "boost_tree", + eng = "xgboost", + mode = "regression", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "boost_tree", eng = "xgboost", @@ -125,6 +132,13 @@ set_fit( ) ) +set_encoding( + model = "boost_tree", + eng = "xgboost", + mode = "classification", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "boost_tree", eng = "xgboost", @@ -221,6 +235,13 @@ set_fit( ) ) +set_encoding( + model = "boost_tree", + eng = "C5.0", + mode = "classification", + options = list(predictor_indicators = FALSE) +) + set_pred( model = "boost_tree", eng = "C5.0", From ad6ac8c26f3faf11df17aa903f35537aefeb6769 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Mon, 18 May 2020 17:51:27 -0600 Subject: [PATCH 10/25] Set predictor encodings for Spark (TRUE). --- R/boost_tree_data.R | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/R/boost_tree_data.R b/R/boost_tree_data.R index 580511d48..2aebcc7ef 100644 --- a/R/boost_tree_data.R +++ b/R/boost_tree_data.R @@ -364,6 +364,13 @@ set_fit( ) ) +set_encoding( + model = "boost_tree", + eng = "spark", + mode = "regression", + options = list(predictor_indicators = TRUE) +) + set_fit( model = "boost_tree", eng = "spark", @@ -376,6 +383,13 @@ set_fit( ) ) +set_encoding( + model = "boost_tree", + eng = "spark", + mode = "classification", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "boost_tree", eng = "spark", From 2c52d2c88f718580df666c2034fd51e6b734d9a8 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Mon, 18 May 2020 18:15:42 -0600 Subject: [PATCH 11/25] Add glue. Closes #296. --- R/fit.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/fit.R b/R/fit.R index 0967026cc..6b1acc10f 100644 --- a/R/fit.R +++ b/R/fit.R @@ -103,7 +103,7 @@ fit.model_spec <- eng_vals <- possible_engines(object) object$engine <- eng_vals[1] if (control$verbosity > 0) { - rlang::warn("Engine set to `{object$engine}`.") + rlang::warn(glue::glue("Engine set to `{object$engine}`.")) } } From c6d0d3597a0af924dc8a2fe83238702b44682a78 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Mon, 18 May 2020 18:17:44 -0600 Subject: [PATCH 12/25] For null model, set predictor indicators to... FALSE? :thinking: --- R/nullmodel_data.R | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/R/nullmodel_data.R b/R/nullmodel_data.R index c8a29c41c..aa6345879 100644 --- a/R/nullmodel_data.R +++ b/R/nullmodel_data.R @@ -21,6 +21,13 @@ set_fit( ) ) +set_encoding( + model = "null_model", + eng = "parsnip", + mode = "regression", + options = list(predictor_indicators = FALSE) +) + set_fit( model = "null_model", eng = "parsnip", @@ -33,6 +40,13 @@ set_fit( ) ) +set_encoding( + model = "null_model", + eng = "parsnip", + mode = "classification", + options = list(predictor_indicators = FALSE) +) + set_pred( model = "null_model", eng = "parsnip", From d3ee6de918c671dd69c424be4920343c764e83db Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Mon, 18 May 2020 18:18:18 -0600 Subject: [PATCH 13/25] Also need engine to find the indicator encoding --- R/fit_helpers.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/fit_helpers.R b/R/fit_helpers.R index 18fbfc1b5..ac7221d6a 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -132,7 +132,8 @@ form_xy <- function(object, control, env, target = "none", ...) { indicators <- get_encoding(class(object)[1]) %>% - dplyr::filter(mode == object$mode) %>% + dplyr::filter(mode == object$mode, + engine == object$engine) %>% dplyr::pull(predictor_indicators) data_obj <- convert_form_to_xy_fit( From 47544b47efe95e0467b3f1c08ddd7bfdd7d62879 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 19 May 2020 08:12:12 -0600 Subject: [PATCH 14/25] Decision tree predictors = FALSE --- R/decision_tree_data.R | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/R/decision_tree_data.R b/R/decision_tree_data.R index 729d6262c..841a2aa23 100644 --- a/R/decision_tree_data.R +++ b/R/decision_tree_data.R @@ -48,6 +48,13 @@ set_fit( ) ) +set_encoding( + model = "decision_tree", + eng = "rpart", + mode = "regression", + options = list(predictor_indicators = FALSE) +) + set_fit( model = "decision_tree", eng = "rpart", @@ -60,6 +67,13 @@ set_fit( ) ) +set_encoding( + model = "decision_tree", + eng = "rpart", + mode = "classification", + options = list(predictor_indicators = FALSE) +) + set_pred( model = "decision_tree", eng = "rpart", @@ -158,6 +172,13 @@ set_fit( ) ) +set_encoding( + model = "decision_tree", + eng = "C5.0", + mode = "classification", + options = list(predictor_indicators = FALSE) +) + set_pred( model = "decision_tree", eng = "C5.0", @@ -244,6 +265,13 @@ set_fit( ) ) +set_encoding( + model = "decision_tree", + eng = "spark", + mode = "regression", + options = list(predictor_indicators = FALSE) +) + set_fit( model = "decision_tree", eng = "spark", @@ -257,6 +285,13 @@ set_fit( ) ) +set_encoding( + model = "decision_tree", + eng = "spark", + mode = "classification", + options = list(predictor_indicators = FALSE) +) + set_pred( model = "decision_tree", eng = "spark", From ccee2131175f7bf63c4276bf6c7e22a1244f64f1 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 19 May 2020 08:21:45 -0600 Subject: [PATCH 15/25] Neural nets all TRUE for indicators --- R/mlp_data.R | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/R/mlp_data.R b/R/mlp_data.R index 81ad6dc83..73bf4df07 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -65,6 +65,13 @@ set_fit( ) ) +set_encoding( + model = "mlp", + eng = "keras", + mode = "regression", + options = list(predictor_indicators = TRUE) +) + set_fit( model = "mlp", eng = "keras", @@ -77,6 +84,13 @@ set_fit( ) ) +set_encoding( + model = "mlp", + eng = "keras", + mode = "classification", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "mlp", eng = "keras", @@ -191,6 +205,7 @@ set_model_arg( func = list(pkg = "dials", fun = "penalty"), has_submodel = FALSE ) + set_model_arg( model = "mlp", eng = "nnet", @@ -199,6 +214,7 @@ set_model_arg( func = list(pkg = "dials", fun = "epochs"), has_submodel = FALSE ) + set_fit( model = "mlp", eng = "nnet", @@ -211,6 +227,13 @@ set_fit( ) ) +set_encoding( + model = "mlp", + eng = "nnet", + mode = "regression", + options = list(predictor_indicators = TRUE) +) + set_fit( model = "mlp", eng = "nnet", @@ -223,6 +246,13 @@ set_fit( ) ) +set_encoding( + model = "mlp", + eng = "nnet", + mode = "classification", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "mlp", eng = "nnet", From 378045421d0231ca8cd4f316309ed566bd17bfc7 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 19 May 2020 08:37:25 -0600 Subject: [PATCH 16/25] Predictor indicators for kknn --- R/nearest_neighbor_data.R | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/R/nearest_neighbor_data.R b/R/nearest_neighbor_data.R index 61cad97c8..695c09cdc 100644 --- a/R/nearest_neighbor_data.R +++ b/R/nearest_neighbor_data.R @@ -47,6 +47,13 @@ set_fit( ) ) +set_encoding( + model = "nearest_neighbor", + eng = "kknn", + mode = "regression", + options = list(predictor_indicators = TRUE) +) + set_fit( model = "nearest_neighbor", eng = "kknn", @@ -59,6 +66,13 @@ set_fit( ) ) +set_encoding( + model = "nearest_neighbor", + eng = "kknn", + mode = "classification", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "nearest_neighbor", eng = "kknn", From 177977737651cb2137b72256fd1a8ede75c8ed1a Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 19 May 2020 08:37:50 -0600 Subject: [PATCH 17/25] Predictor indicators for multinomial classification --- R/multinom_reg_data.R | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index 3f2889cc3..3b9add328 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -37,6 +37,12 @@ set_fit( ) ) +set_encoding( + model = "multinom_reg", + eng = "glmnet", + mode = "classification", + options = list(predictor_indicators = TRUE) +) set_pred( model = "multinom_reg", @@ -128,6 +134,13 @@ set_fit( ) ) +set_encoding( + model = "multinom_reg", + eng = "spark", + mode = "classification", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "multinom_reg", eng = "spark", @@ -191,6 +204,13 @@ set_fit( ) ) +set_encoding( + model = "multinom_reg", + eng = "keras", + mode = "classification", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "multinom_reg", eng = "keras", @@ -254,6 +274,12 @@ set_fit( ) ) +set_encoding( + model = "multinom_reg", + eng = "nnet", + mode = "classification", + options = list(predictor_indicators = TRUE) +) set_pred( model = "multinom_reg", From 7d87e454514aefe4b2e692ab33697f5c8dd7b71a Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 19 May 2020 09:29:50 -0600 Subject: [PATCH 18/25] Random forest predictor indicators --- R/rand_forest_data.R | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index d054b1b66..5465f5d92 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -132,6 +132,13 @@ set_fit( ) ) +set_encoding( + model = "rand_forest", + eng = "ranger", + mode = "classification", + options = list(predictor_indicators = FALSE) +) + set_fit( model = "rand_forest", eng = "ranger", @@ -149,6 +156,13 @@ set_fit( ) ) +set_encoding( + model = "rand_forest", + eng = "ranger", + mode = "regression", + options = list(predictor_indicators = FALSE) +) + set_pred( model = "rand_forest", eng = "ranger", @@ -338,6 +352,13 @@ set_fit( ) ) +set_encoding( + model = "rand_forest", + eng = "randomForest", + mode = "classification", + options = list(predictor_indicators = FALSE) +) + set_fit( model = "rand_forest", eng = "randomForest", @@ -351,6 +372,13 @@ set_fit( ) ) +set_encoding( + model = "rand_forest", + eng = "randomForest", + mode = "regression", + options = list(predictor_indicators = FALSE) +) + set_pred( model = "rand_forest", eng = "randomForest", @@ -474,6 +502,13 @@ set_fit( ) ) +set_encoding( + model = "rand_forest", + eng = "spark", + mode = "classification", + options = list(predictor_indicators = FALSE) +) + set_fit( model = "rand_forest", eng = "spark", @@ -486,6 +521,13 @@ set_fit( ) ) +set_encoding( + model = "rand_forest", + eng = "spark", + mode = "regression", + options = list(predictor_indicators = FALSE) +) + set_pred( model = "rand_forest", eng = "spark", From 172541ca0345606150ded4fe89f91136096b63a3 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 19 May 2020 09:37:28 -0600 Subject: [PATCH 19/25] Survival models make indicators --- R/surv_reg_data.R | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/R/surv_reg_data.R b/R/surv_reg_data.R index 9cd243e33..f333766d2 100644 --- a/R/surv_reg_data.R +++ b/R/surv_reg_data.R @@ -29,6 +29,13 @@ set_fit( ) ) +set_encoding( + model = "surv_reg", + eng = "flexsurv", + mode = "regression", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "surv_reg", eng = "flexsurv", @@ -92,6 +99,13 @@ set_fit( ) ) +set_encoding( + model = "surv_reg", + eng = "survival", + mode = "regression", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "surv_reg", eng = "survival", From 115292dfb650ff58c1c73c3f93fad8d9eb73b4dc Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 19 May 2020 13:41:16 -0600 Subject: [PATCH 20/25] Change kernlab to use formula interface, add indicator encoding --- R/svm_poly_data.R | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/R/svm_poly_data.R b/R/svm_poly_data.R index f72ee8dc4..360deabe2 100644 --- a/R/svm_poly_data.R +++ b/R/svm_poly_data.R @@ -49,25 +49,39 @@ set_fit( eng = "kernlab", mode = "regression", value = list( - interface = "matrix", - protect = c("x", "y"), + interface = "formula", + protect = c("formula", "data"), func = c(pkg = "kernlab", fun = "ksvm"), defaults = list(kernel = "polydot") ) ) +set_encoding( + model = "svm_poly", + eng = "kernlab", + mode = "regression", + options = list(predictor_indicators = FALSE) +) + set_fit( model = "svm_poly", eng = "kernlab", mode = "classification", value = list( - interface = "matrix", - protect = c("x", "y"), + interface = "formula", + protect = c("formula", "data"), func = c(pkg = "kernlab", fun = "ksvm"), defaults = list(kernel = "polydot") ) ) +set_encoding( + model = "svm_poly", + eng = "kernlab", + mode = "classification", + options = list(predictor_indicators = FALSE) +) + set_pred( model = "svm_poly", eng = "kernlab", From 2e8d113ee8f889f92fbd8044086439575318aac2 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 19 May 2020 14:18:13 -0600 Subject: [PATCH 21/25] Change svm_rbf (kernlab) to formula interface, add indicator encodings --- R/svm_rbf_data.R | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/R/svm_rbf_data.R b/R/svm_rbf_data.R index ae58c6db6..d1c24de8a 100644 --- a/R/svm_rbf_data.R +++ b/R/svm_rbf_data.R @@ -41,25 +41,39 @@ set_fit( eng = "kernlab", mode = "regression", value = list( - interface = "matrix", - protect = c("x", "y"), + interface = "formula", + protect = c("formula", "data"), func = c(pkg = "kernlab", fun = "ksvm"), defaults = list(kernel = "rbfdot") ) ) +set_encoding( + model = "svm_rbf", + eng = "kernlab", + mode = "regression", + options = list(predictor_indicators = FALSE) +) + set_fit( model = "svm_rbf", eng = "kernlab", mode = "classification", value = list( - interface = "matrix", - protect = c("x", "y"), + interface = "formula", + protect = c("formula", "data"), func = c(pkg = "kernlab", fun = "ksvm"), defaults = list(kernel = "rbfdot") ) ) +set_encoding( + model = "svm_rbf", + eng = "kernlab", + mode = "classification", + options = list(predictor_indicators = FALSE) +) + set_pred( model = "svm_rbf", eng = "kernlab", From a42074956734166041b9b9c119e630765ed1f95b Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 19 May 2020 14:18:48 -0600 Subject: [PATCH 22/25] Update tests for kernlab formula interface --- tests/testthat/test_svm_poly.R | 16 ++++++++-------- tests/testthat/test_svm_rbf.R | 12 ++++++------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/testthat/test_svm_poly.R b/tests/testthat/test_svm_poly.R index 85c1b95db..c237dda15 100644 --- a/tests/testthat/test_svm_poly.R +++ b/tests/testthat/test_svm_poly.R @@ -17,8 +17,8 @@ test_that('primary arguments', { expect_equal( object = basic_kernlab$method$fit$args, expected = list( - x = expr(missing_arg()), - y = expr(missing_arg()), + formula = expr(missing_arg()), + data = expr(missing_arg()), kernel = "polydot" ) ) @@ -31,8 +31,8 @@ test_that('primary arguments', { expect_equal( object = degree_kernlab$method$fit$args, expected = list( - x = expr(missing_arg()), - y = expr(missing_arg()), + formula = expr(missing_arg()), + data = expr(missing_arg()), kernel = "polydot", kpar = degree_obj ) @@ -47,8 +47,8 @@ test_that('primary arguments', { expect_equal( object = degree_scale_kernlab$method$fit$args, expected = list( - x = expr(missing_arg()), - y = expr(missing_arg()), + formula = expr(missing_arg()), + data = expr(missing_arg()), kernel = "polydot", kpar = degree_scale_obj ) @@ -63,8 +63,8 @@ test_that('engine arguments', { expect_equal( object = translate(kernlab_cv, "kernlab")$method$fit$args, expected = list( - x = expr(missing_arg()), - y = expr(missing_arg()), + formula = expr(missing_arg()), + data = expr(missing_arg()), cross = new_empty_quosure(10), kernel = "polydot" ) diff --git a/tests/testthat/test_svm_rbf.R b/tests/testthat/test_svm_rbf.R index e9d4a8d29..2b71bb942 100644 --- a/tests/testthat/test_svm_rbf.R +++ b/tests/testthat/test_svm_rbf.R @@ -16,8 +16,8 @@ test_that('primary arguments', { expect_equal( object = basic_kernlab$method$fit$args, expected = list( - x = expr(missing_arg()), - y = expr(missing_arg()), + formula = expr(missing_arg()), + data = expr(missing_arg()), kernel = "rbfdot" ) ) @@ -30,8 +30,8 @@ test_that('primary arguments', { expect_equal( object = rbf_sigma_kernlab$method$fit$args, expected = list( - x = expr(missing_arg()), - y = expr(missing_arg()), + formula = expr(missing_arg()), + data = expr(missing_arg()), kernel = "rbfdot", kpar = rbf_sigma_obj ) @@ -46,8 +46,8 @@ test_that('engine arguments', { expect_equal( object = translate(kernlab_cv, "kernlab")$method$fit$args, expected = list( - x = expr(missing_arg()), - y = expr(missing_arg()), + formula = expr(missing_arg()), + data = expr(missing_arg()), cross = new_empty_quosure(10), kernel = "rbfdot" ) From 3c8481e853ce3a0a8d22b239556f52e96b864fc1 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 26 May 2020 12:15:19 -0600 Subject: [PATCH 23/25] Spark *always* makes indicator variables, fix dependency for Spark + decision tree --- R/decision_tree_data.R | 6 +++--- R/rand_forest_data.R | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/R/decision_tree_data.R b/R/decision_tree_data.R index 40ec34d75..3e3ce898a 100644 --- a/R/decision_tree_data.R +++ b/R/decision_tree_data.R @@ -232,7 +232,7 @@ set_pred( set_model_engine("decision_tree", "classification", "spark") set_model_engine("decision_tree", "regression", "spark") -set_dependency("decision_tree", "spark", "spark") +set_dependency("decision_tree", "spark", "sparklyr") set_model_arg( model = "decision_tree", @@ -270,7 +270,7 @@ set_encoding( model = "decision_tree", eng = "spark", mode = "regression", - options = list(predictor_indicators = FALSE) + options = list(predictor_indicators = TRUE) ) set_fit( @@ -291,7 +291,7 @@ set_encoding( model = "decision_tree", eng = "spark", mode = "classification", - options = list(predictor_indicators = FALSE) + options = list(predictor_indicators = TRUE) ) set_pred( diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index a69b312ee..a4c81b669 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -507,7 +507,7 @@ set_encoding( model = "rand_forest", eng = "spark", mode = "classification", - options = list(predictor_indicators = FALSE) + options = list(predictor_indicators = TRUE) ) set_fit( @@ -527,7 +527,7 @@ set_encoding( model = "rand_forest", eng = "spark", mode = "regression", - options = list(predictor_indicators = FALSE) + options = list(predictor_indicators = TRUE) ) set_pred( From 1160e1e11170e2df9622f1485f322a2baac09590 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 26 May 2020 12:24:24 -0600 Subject: [PATCH 24/25] Fix function used with Spark decision tree for regression --- R/decision_tree_data.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/decision_tree_data.R b/R/decision_tree_data.R index 3e3ce898a..9eaf8bc69 100644 --- a/R/decision_tree_data.R +++ b/R/decision_tree_data.R @@ -260,7 +260,7 @@ set_fit( interface = "formula", data = c(formula = "formula", data = "x"), protect = c("x", "formula"), - func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"), + func = c(pkg = "sparklyr", fun = "ml_decision_tree_regressor"), defaults = list(seed = expr(sample.int(10 ^ 5, 1))) ) From 534987e96245314889cdfe873d8e73224a0e158d Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 26 May 2020 19:40:55 -0600 Subject: [PATCH 25/25] Change to predictor_indicators = FALSE for MARS models --- R/mars_data.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/mars_data.R b/R/mars_data.R index 92154552f..401f04cb7 100644 --- a/R/mars_data.R +++ b/R/mars_data.R @@ -51,7 +51,7 @@ set_encoding( model = "mars", eng = "earth", mode = "regression", - options = list(predictor_indicators = TRUE) + options = list(predictor_indicators = FALSE) ) set_fit( @@ -70,7 +70,7 @@ set_encoding( model = "mars", eng = "earth", mode = "classification", - options = list(predictor_indicators = TRUE) + options = list(predictor_indicators = FALSE) ) set_pred(