diff --git a/NAMESPACE b/NAMESPACE
index 8f388d332..3bd93bfeb 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -110,6 +110,7 @@ export(fit_control)
 export(fit_xy)
 export(fit_xy.model_spec)
 export(get_dependency)
+export(get_encoding)
 export(get_fit)
 export(get_from_env)
 export(get_model_env)
@@ -146,6 +147,7 @@ export(repair_call)
 export(rpart_train)
 export(set_args)
 export(set_dependency)
+export(set_encoding)
 export(set_engine)
 export(set_env_val)
 export(set_fit)
diff --git a/R/aaa_models.R b/R/aaa_models.R
index ad9235d8e..65b74d3f7 100644
--- a/R/aaa_models.R
+++ b/R/aaa_models.R
@@ -323,6 +323,11 @@ check_interface_val <- function(x) {
 #' below, depending on context.
 #' @param pre,post Optional functions for pre- and post-processing of prediction
 #' results.
+#' @param options A list of options for engine-specific encodings. Currently,
+#' the only option implemented is `predictor_indicators`, which tells `parsnip`
+#' whether the pre-processing should make indicator/dummy variables from factor
+#' predictors. This only affects cases where [fit.model_spec()] is used and the
+#' underlying model has an x/y interface.
 #' @param ... Optional arguments that should be passed into the `args` slot for
 #' prediction objects.
 #' @keywords internal
@@ -780,3 +785,77 @@ pred_value_template <- function(pre = NULL, post = NULL, func, ...) {
   list(pre = pre, post = post, func = func, args = list(...))
 }
 
+# ------------------------------------------------------------------------------
+
+check_encodings <- function(x) {
+  if (!is.list(x)) {
+    rlang::abort("`values` should be a list.")
+  }
+  req_args <- list(predictor_indicators = TRUE)
+
+  missing_args <- setdiff(names(req_args), names(x))
+  if (length(missing_args) > 0) {
+    rlang::abort(
+      glue::glue(
+        "The values passed to `set_encoding()` are missing arguments: ",
+        paste0("'", missing_args, "'", collapse = ", ")
+      )
+    )
+  }
+  extra_args <- setdiff(names(x), names(req_args))
+  if (length(extra_args) > 0) {
+    rlang::abort(
+      glue::glue(
+        "The values passed to `set_encoding()` had extra arguments: ",
+        paste0("'", extra_args, "'", collapse = ", ")
+      )
+    )
+  }
+  invisible(x)
+}
+
+#' @export
+#' @rdname set_new_model
+#' @keywords internal
+set_encoding <- function(model, mode, eng, options) {
+  check_model_exists(model)
+  check_eng_val(eng)
+  check_mode_val(mode)
+  check_encodings(options)
+
+  keys <- tibble::tibble(model = model, engine = eng, mode = mode)
+  options <- tibble::as_tibble(options)
+  new_values <- dplyr::bind_cols(keys, options)
+
+
+  current_db_list <- ls(envir = get_model_env())
+  nm <- paste(model, "encoding", sep = "_")
+  if (any(current_db_list == nm)) {
+    current <- get_from_env(nm)
+    dup_check <-
+      current %>%
+      dplyr::inner_join(new_values, by = c("model", "engine", "mode", "predictor_indicators"))
+    if (nrow(dup_check) > 0) {
+      rlang::abort(glue::glue("Engine '{eng}' and mode '{mode}' already have defined encodings."))
+    }
+
+  } else {
+    current <- NULL
+  }
+
+  db_values <- dplyr::bind_rows(current, new_values)
+  set_env_val(nm, db_values)
+
+  invisible(NULL)
+}
+
+
+#' @rdname set_new_model
+#' @keywords internal
+#' @export
+get_encoding <- function(model) {
+  check_model_exists(model)
+  nm <- paste0(model, "_encoding")
+  rlang::env_get(get_model_env(), nm)
+}
+
diff --git a/R/boost_tree_data.R b/R/boost_tree_data.R
index 25924454d..368b1cba5 100644
--- a/R/boost_tree_data.R
+++ b/R/boost_tree_data.R
@@ -87,6 +87,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "boost_tree",
+  eng = "xgboost",
+  mode = "regression",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_pred(
   model = "boost_tree",
   eng = "xgboost",
@@ -125,6 +132,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "boost_tree",
+  eng = "xgboost",
+  mode = "classification",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_pred(
   model = "boost_tree",
   eng = "xgboost",
@@ -221,6 +235,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "boost_tree",
+  eng = "C5.0",
+  mode = "classification",
+  options = list(predictor_indicators = FALSE)
+)
+
 set_pred(
   model = "boost_tree",
   eng = "C5.0",
@@ -344,6 +365,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "boost_tree",
+  eng = "spark",
+  mode = "regression",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_fit(
   model = "boost_tree",
   eng = "spark",
@@ -357,6 +385,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "boost_tree",
+  eng = "spark",
+  mode = "classification",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_pred(
   model = "boost_tree",
   eng = "spark",
diff --git a/R/convert_data.R b/R/convert_data.R
index 6496a1ce9..fe02afce4 100644
--- a/R/convert_data.R
+++ b/R/convert_data.R
@@ -15,7 +15,7 @@
 #' @importFrom stats .checkMFClasses .getXlevels delete.response
 #' @importFrom stats model.offset model.weights na.omit na.pass
 
-convert_form_to_xy_fit <-function(
+convert_form_to_xy_fit <- function(
   formula,
   data,
   ...,
diff --git a/R/decision_tree_data.R b/R/decision_tree_data.R
index 4f0d46160..9eaf8bc69 100644
--- a/R/decision_tree_data.R
+++ b/R/decision_tree_data.R
@@ -48,6 +48,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "decision_tree",
+  eng = "rpart",
+  mode = "regression",
+  options = list(predictor_indicators = FALSE)
+)
+
 set_fit(
   model = "decision_tree",
   eng = "rpart",
@@ -60,6 +67,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "decision_tree",
+  eng = "rpart",
+  mode = "classification",
+  options = list(predictor_indicators = FALSE)
+)
+
 set_pred(
   model = "decision_tree",
   eng = "rpart",
@@ -158,6 +172,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "decision_tree",
+  eng = "C5.0",
+  mode = "classification",
+  options = list(predictor_indicators = FALSE)
+)
+
 set_pred(
   model = "decision_tree",
   eng = "C5.0",
@@ -211,7 +232,7 @@ set_pred(
 set_model_engine("decision_tree", "classification", "spark")
 set_model_engine("decision_tree", "regression", "spark")
 
-set_dependency("decision_tree", "spark", "spark")
+set_dependency("decision_tree", "spark", "sparklyr")
 
 set_model_arg(
   model = "decision_tree",
@@ -239,12 +260,19 @@ set_fit(
     interface = "formula",
     data = c(formula = "formula", data = "x"),
     protect = c("x", "formula"),
-    func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"),
+    func = c(pkg = "sparklyr", fun = "ml_decision_tree_regressor"),
     defaults = list(seed = expr(sample.int(10 ^ 5, 1)))
   )
 )
 
+set_encoding(
+  model = "decision_tree",
+  eng = "spark",
+  mode = "regression",
+  options = list(predictor_indicators = TRUE)
+)
+
 
 set_fit(
   model = "decision_tree",
   eng = "spark",
@@ -259,6 +287,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "decision_tree",
+  eng = "spark",
+  mode = "classification",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_pred(
   model = "decision_tree",
   eng = "spark",
diff --git a/R/fit.R b/R/fit.R
index 0967026cc..6b1acc10f 100644
--- a/R/fit.R
+++ b/R/fit.R
@@ -103,7 +103,7 @@ fit.model_spec <-
       eng_vals <- possible_engines(object)
       object$engine <- eng_vals[1]
       if (control$verbosity > 0) {
-        rlang::warn("Engine set to `{object$engine}`.")
+        rlang::warn(glue::glue("Engine set to `{object$engine}`."))
       }
     }
 
diff --git a/R/fit_helpers.R b/R/fit_helpers.R
index 49a069cf8..f6940460e 100644
--- a/R/fit_helpers.R
+++ b/R/fit_helpers.R
@@ -103,12 +103,17 @@ xy_xy <- function(object, env, control, target = "none", ...) {
 
 form_xy <- function(object, control, env,
                     target = "none", ...) {
+  indicators <- get_encoding(class(object)[1]) %>%
+    dplyr::filter(mode == object$mode,
+                  engine == object$engine) %>%
+    dplyr::pull(predictor_indicators)
+
   data_obj <- convert_form_to_xy_fit(
     formula = env$formula,
     data = env$data,
     ...,
-    composition = target
-    # indicators
+    composition = target,
+    indicators = indicators
   )
   env$x <- data_obj$x
   env$y <- data_obj$y
diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R
index b5460e66e..2e33efe13 100644
--- a/R/linear_reg_data.R
+++ b/R/linear_reg_data.R
@@ -19,6 +19,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "linear_reg",
+  eng = "lm",
+  mode = "regression",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_pred(
   model = "linear_reg",
   eng = "lm",
@@ -102,6 +109,25 @@ set_pred(
 set_model_engine("linear_reg", "regression", "glmnet")
 set_dependency("linear_reg", "glmnet", "glmnet")
 
+set_fit(
+  model = "linear_reg",
+  eng = "glmnet",
+  mode = "regression",
+  value = list(
+    interface = "matrix",
+    protect = c("x", "y", "weights"),
+    func = c(pkg = "glmnet", fun = "glmnet"),
+    defaults = list(family = "gaussian")
+  )
+)
+
+set_encoding(
+  model = "linear_reg",
+  eng = "glmnet",
+  mode = "regression",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_model_arg(
   model = "linear_reg",
   eng = "glmnet",
@@ -120,18 +146,6 @@ set_model_arg(
   has_submodel = FALSE
 )
 
-set_fit(
-  model = "linear_reg",
-  eng = "glmnet",
-  mode = "regression",
-  value = list(
-    interface = "matrix",
-    protect = c("x", "y", "weights"),
-    func = c(pkg = "glmnet", fun = "glmnet"),
-    defaults = list(family = "gaussian")
-  )
-)
-
 set_pred(
   model = "linear_reg",
   eng = "glmnet",
@@ -183,6 +197,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "linear_reg",
+  eng = "stan",
+  mode = "regression",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_pred(
   model = "linear_reg",
   eng = "stan",
@@ -287,6 +308,26 @@ set_pred(
 set_model_engine("linear_reg", "regression", "spark")
 set_dependency("linear_reg", "spark", "sparklyr")
 
+set_fit(
+  model = "linear_reg",
+  eng = "spark",
+  mode = "regression",
+  value = list(
+    interface = "formula",
+    data = c(formula = "formula", data = "x"),
+    protect = c("x", "formula", "weight_col"),
+    func = c(pkg = "sparklyr", fun = "ml_linear_regression"),
+    defaults = list()
+  )
+)
+
+set_encoding(
+  model = "linear_reg",
+  eng = "spark",
+  mode = "regression",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_model_arg(
   model = "linear_reg",
   eng = "spark",
@@ -305,20 +346,6 @@ set_model_arg(
   has_submodel = FALSE
 )
 
-
-set_fit(
-  model = "linear_reg",
-  eng = "spark",
-  mode = "regression",
-  value = list(
-    interface = "formula",
-    data = c(formula = "formula", data = "x"),
-    protect = c("x", "formula", "weight_col"),
-    func = c(pkg = "sparklyr", fun = "ml_linear_regression"),
-    defaults = list()
-  )
-)
-
 set_pred(
   model = "linear_reg",
   eng = "spark",
@@ -343,15 +370,6 @@ set_model_engine("linear_reg", "regression", "keras")
 set_dependency("linear_reg", "keras", "keras")
 set_dependency("linear_reg", "keras", "magrittr")
 
-set_model_arg(
-  model = "linear_reg",
-  eng = "keras",
-  parsnip = "penalty",
-  original = "penalty",
-  func = list(pkg = "dials", fun = "penalty"),
-  has_submodel = FALSE
-)
-
 set_fit(
   model = "linear_reg",
   eng = "keras",
@@ -364,6 +382,22 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "linear_reg",
+  eng = "keras",
+  mode = "regression",
+  options = list(predictor_indicators = TRUE)
+)
+
+set_model_arg(
+  model = "linear_reg",
+  eng = "keras",
+  parsnip = "penalty",
+  original = "penalty",
+  func = list(pkg = "dials", fun = "penalty"),
+  has_submodel = FALSE
+)
+
 set_pred(
   model = "linear_reg",
   eng = "keras",
diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R
index 306bae8b6..0aa4b3e74 100644
--- a/R/logistic_reg_data.R
+++ b/R/logistic_reg_data.R
@@ -19,6 +19,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "logistic_reg",
+  eng = "glm",
+  mode = "classification",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_pred(
   model = "logistic_reg",
   eng = "glm",
@@ -121,6 +128,25 @@ set_pred(
 set_model_engine("logistic_reg", "classification", "glmnet")
 set_dependency("logistic_reg", "glmnet", "glmnet")
 
+set_fit(
+  model = "logistic_reg",
+  eng = "glmnet",
+  mode = "classification",
+  value = list(
+    interface = "matrix",
+    protect = c("x", "y", "weights"),
+    func = c(pkg = "glmnet", fun = "glmnet"),
+    defaults = list(family = "binomial")
+  )
+)
+
+set_encoding(
+  model = "logistic_reg",
+  eng = "glmnet",
+  mode = "classification",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_model_arg(
   model = "logistic_reg",
   eng = "glmnet",
@@ -139,19 +165,6 @@ set_model_arg(
   has_submodel = FALSE
 )
 
-set_fit(
-  model = "logistic_reg",
-  eng = "glmnet",
-  mode = "classification",
-  value = list(
-    interface = "matrix",
-    protect = c("x", "y", "weights"),
-    func = c(pkg = "glmnet", fun = "glmnet"),
-    defaults = list(family = "binomial")
-  )
-)
-
-
 set_pred(
   model = "logistic_reg",
   eng = "glmnet",
@@ -246,6 +259,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "logistic_reg",
+  eng = "spark",
+  mode = "classification",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_pred(
   model = "logistic_reg",
   eng = "spark",
@@ -307,6 +327,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "logistic_reg",
+  eng = "keras",
+  mode = "classification",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_pred(
   model = "logistic_reg",
   eng = "keras",
@@ -364,6 +391,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "logistic_reg",
+  eng = "stan",
+  mode = "classification",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_pred(
   model = "logistic_reg",
   eng = "stan",
diff --git a/R/mars_data.R b/R/mars_data.R
index 7ec4b363d..401f04cb7 100644
--- a/R/mars_data.R
+++ b/R/mars_data.R
@@ -47,6 +47,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "mars",
+  eng = "earth",
+  mode = "regression",
+  options = list(predictor_indicators = FALSE)
+)
+
 set_fit(
   model = "mars",
   eng = "earth",
@@ -59,6 +66,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "mars",
+  eng = "earth",
+  mode = "classification",
+  options = list(predictor_indicators = FALSE)
+)
+
 set_pred(
   model = "mars",
   eng = "earth",
diff --git a/R/mlp_data.R b/R/mlp_data.R
index 81ad6dc83..73bf4df07 100644
--- a/R/mlp_data.R
+++ b/R/mlp_data.R
@@ -65,6 +65,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "mlp",
+  eng = "keras",
+  mode = "regression",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_fit(
   model = "mlp",
   eng = "keras",
@@ -77,6 +84,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "mlp",
+  eng = "keras",
+  mode = "classification",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_pred(
   model = "mlp",
   eng = "keras",
@@ -191,6 +205,7 @@ set_model_arg(
   func = list(pkg = "dials", fun = "penalty"),
   has_submodel = FALSE
 )
+
 set_model_arg(
   model = "mlp",
   eng = "nnet",
@@ -199,6 +214,7 @@ set_model_arg(
   func = list(pkg = "dials", fun = "epochs"),
   has_submodel = FALSE
 )
+
 set_fit(
   model = "mlp",
   eng = "nnet",
@@ -211,6 +227,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "mlp",
+  eng = "nnet",
+  mode = "regression",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_fit(
   model = "mlp",
   eng = "nnet",
@@ -223,6 +246,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "mlp",
+  eng = "nnet",
+  mode = "classification",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_pred(
   model = "mlp",
   eng = "nnet",
diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R
index 5ef051c9c..0eb392c53 100644
--- a/R/multinom_reg_data.R
+++ b/R/multinom_reg_data.R
@@ -37,6 +37,12 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "multinom_reg",
+  eng = "glmnet",
+  mode = "classification",
+  options = list(predictor_indicators = TRUE)
+)
 
 set_pred(
   model = "multinom_reg",
@@ -129,6 +135,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "multinom_reg",
+  eng = "spark",
+  mode = "classification",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_pred(
   model = "multinom_reg",
   eng = "spark",
@@ -192,6 +205,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "multinom_reg",
+  eng = "keras",
+  mode = "classification",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_pred(
   model = "multinom_reg",
   eng = "keras",
@@ -255,6 +275,12 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "multinom_reg",
+  eng = "nnet",
+  mode = "classification",
+  options = list(predictor_indicators = TRUE)
+)
 
 set_pred(
   model = "multinom_reg",
diff --git a/R/nearest_neighbor_data.R b/R/nearest_neighbor_data.R
index 61cad97c8..695c09cdc 100644
--- a/R/nearest_neighbor_data.R
+++ b/R/nearest_neighbor_data.R
@@ -47,6 +47,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "nearest_neighbor",
+  eng = "kknn",
+  mode = "regression",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_fit(
   model = "nearest_neighbor",
   eng = "kknn",
@@ -59,6 +66,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "nearest_neighbor",
+  eng = "kknn",
+  mode = "classification",
+  options = list(predictor_indicators = TRUE)
+)
+
 set_pred(
   model = "nearest_neighbor",
   eng = "kknn",
diff --git a/R/nullmodel_data.R b/R/nullmodel_data.R
index c8a29c41c..aa6345879 100644
--- a/R/nullmodel_data.R
+++ b/R/nullmodel_data.R
@@ -21,6 +21,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "null_model",
+  eng = "parsnip",
+  mode = "regression",
+  options = list(predictor_indicators = FALSE)
+)
+
 set_fit(
   model = "null_model",
   eng = "parsnip",
@@ -33,6 +40,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "null_model",
+  eng = "parsnip",
+  mode = "classification",
+  options = list(predictor_indicators = FALSE)
+)
+
 set_pred(
   model = "null_model",
   eng = "parsnip",
diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R
index 70b872ae4..a4c81b669 100644
--- a/R/rand_forest_data.R
+++ b/R/rand_forest_data.R
@@ -132,6 +132,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "rand_forest",
+  eng = "ranger",
+  mode = "classification",
+  options = list(predictor_indicators = FALSE)
+)
+
 set_fit(
   model = "rand_forest",
   eng = "ranger",
@@ -149,6 +156,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "rand_forest",
+  eng = "ranger",
+  mode = "regression",
+  options = list(predictor_indicators = FALSE)
+)
+
 set_pred(
   model = "rand_forest",
   eng = "ranger",
@@ -338,6 +352,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "rand_forest",
+  eng = "randomForest",
+  mode = "classification",
+  options = list(predictor_indicators = FALSE)
+)
+
 set_fit(
   model = "rand_forest",
   eng = "randomForest",
@@ -351,6 +372,13 @@ set_fit(
   )
 )
 
+set_encoding(
+  model = "rand_forest",
+  eng = "randomForest",
+  mode = "regression",
+  options = list(predictor_indicators = FALSE)
+)
+
set_pred( model = "rand_forest", eng = "randomForest", @@ -475,6 +503,13 @@ set_fit( ) ) +set_encoding( + model = "rand_forest", + eng = "spark", + mode = "classification", + options = list(predictor_indicators = TRUE) +) + set_fit( model = "rand_forest", eng = "spark", @@ -488,6 +523,13 @@ set_fit( ) ) +set_encoding( + model = "rand_forest", + eng = "spark", + mode = "regression", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "rand_forest", eng = "spark", diff --git a/R/surv_reg_data.R b/R/surv_reg_data.R index 9cd243e33..f333766d2 100644 --- a/R/surv_reg_data.R +++ b/R/surv_reg_data.R @@ -29,6 +29,13 @@ set_fit( ) ) +set_encoding( + model = "surv_reg", + eng = "flexsurv", + mode = "regression", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "surv_reg", eng = "flexsurv", @@ -92,6 +99,13 @@ set_fit( ) ) +set_encoding( + model = "surv_reg", + eng = "survival", + mode = "regression", + options = list(predictor_indicators = TRUE) +) + set_pred( model = "surv_reg", eng = "survival", diff --git a/R/svm_poly_data.R b/R/svm_poly_data.R index cf6b85ed4..e0eb3a867 100644 --- a/R/svm_poly_data.R +++ b/R/svm_poly_data.R @@ -57,6 +57,13 @@ set_fit( ) ) +set_encoding( + model = "svm_poly", + eng = "kernlab", + mode = "regression", + options = list(predictor_indicators = FALSE) +) + set_fit( model = "svm_poly", eng = "kernlab", @@ -70,6 +77,13 @@ set_fit( ) ) +set_encoding( + model = "svm_poly", + eng = "kernlab", + mode = "classification", + options = list(predictor_indicators = FALSE) +) + set_pred( model = "svm_poly", eng = "kernlab", diff --git a/R/svm_rbf_data.R b/R/svm_rbf_data.R index c6eb11a9d..a210d5703 100644 --- a/R/svm_rbf_data.R +++ b/R/svm_rbf_data.R @@ -49,6 +49,13 @@ set_fit( ) ) +set_encoding( + model = "svm_rbf", + eng = "kernlab", + mode = "regression", + options = list(predictor_indicators = FALSE) +) + set_fit( model = "svm_rbf", eng = "kernlab", @@ -62,6 +69,13 @@ set_fit( ) ) +set_encoding( + model = "svm_rbf", + eng = "kernlab", + mode = "classification", + options = list(predictor_indicators = FALSE) +) + set_pred( model = "svm_rbf", eng = "kernlab", @@ -178,6 +192,14 @@ set_fit( ) ) ) + +set_encoding( + model = "svm_rbf", + eng = "liquidSVM", + mode = "regression", + options = list(predictor_indicators = FALSE) +) + set_fit( model = "svm_rbf", eng = "liquidSVM", @@ -192,6 +214,14 @@ set_fit( ) ) ) + +set_encoding( + model = "svm_rbf", + eng = "liquidSVM", + mode = "classification", + options = list(predictor_indicators = FALSE) +) + set_pred( model = "svm_rbf", eng = "liquidSVM", diff --git a/man/set_new_model.Rd b/man/set_new_model.Rd index 385c3f83b..8d9642cf0 100644 --- a/man/set_new_model.Rd +++ b/man/set_new_model.Rd @@ -13,6 +13,8 @@ \alias{get_pred_type} \alias{show_model_info} \alias{pred_value_template} +\alias{set_encoding} +\alias{get_encoding} \title{Tools to Register Models} \usage{ set_new_model(model) @@ -38,6 +40,10 @@ get_pred_type(model, type) show_model_info(model) pred_value_template(pre = NULL, post = NULL, func, ...) + +set_encoding(model, mode, eng, options) + +get_encoding(model) } \arguments{ \item{model}{A single character string for the model type (e.g. @@ -79,6 +85,12 @@ results.} \item{...}{Optional arguments that should be passed into the \code{args} slot for prediction objects.} +\item{options}{A list of options for engine-specific encodings. 
Currently, +the option implemented is \code{predictor_indicators} which tells \code{parsnip} +whether the pre-processing should make indicator/dummy variables from factor +predictors. This only affects cases when \code{\link[=fit.model_spec]{fit.model_spec()}} is used and the +underlying model has an x/y interface.} + \item{arg}{A single character string for the model argument name.} \item{fit_obj}{A list with elements \code{interface}, \code{protect}, diff --git a/tests/testthat/test_convert_data.R b/tests/testthat/test_convert_data.R index 51cb4e221..c58364be2 100644 --- a/tests/testthat/test_convert_data.R +++ b/tests/testthat/test_convert_data.R @@ -168,6 +168,21 @@ test_that("numeric y and mixed x", { ) }) +test_that("mixed x, no dummies, compare to a model that does not create dummies", { + expected <- rpart::rpart(rate ~ ., data = Puromycin) + data_classes <- attr(expected$terms, "dataClasses")[2:3] + observed <- parsnip:::convert_form_to_xy_fit(rate ~ ., data = Puromycin, indicators = FALSE) + expect_equal(names(data_classes), names(observed$x)) + expect_equal(unname(data_classes), c("numeric", "factor")) + expect_s3_class(observed$x$state, "factor") + expect_equivalent(Puromycin$rate, observed$y) + expect_equal(expected$terms, observed$terms) + + expect_null(observed$weights) + expect_null(observed$offset) +}) + + test_that("numeric y and mixed x, omit missing data", { expected <- lm(rate ~ ., data = Puromycin_miss, x = TRUE, y = TRUE) observed <- parsnip:::convert_form_to_xy_fit(rate ~ ., data = Puromycin_miss)
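
Reviewer note, not part of the patch: a minimal sketch of how the new
registration hooks fit together. The model name "my_model" and engine name
"my_engine" below are hypothetical placeholders, and the snippet assumes
dplyr is attached for the filter/pull chain, mirroring what form_xy() does.

library(parsnip)
library(dplyr)

# Register a throwaway model so the encoding has somewhere to live.
set_new_model("my_model")                                # hypothetical name
set_model_mode("my_model", "regression")
set_model_engine("my_model", "regression", "my_engine")  # hypothetical engine

# Declare that this engine wants factor predictors passed through as-is.
set_encoding(
  model = "my_model",
  eng = "my_engine",
  mode = "regression",
  options = list(predictor_indicators = FALSE)
)

# get_encoding() returns the registered tibble, one row per engine/mode;
# form_xy() filters it the same way to pick the `indicators` value.
get_encoding("my_model") %>%
  filter(mode == "regression", engine == "my_engine") %>%
  pull(predictor_indicators)
#> [1] FALSE

# Calling set_encoding() again with the same model/engine/mode aborts,
# courtesy of the inner-join duplicate check in set_encoding().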
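
A second sketch, showing the behavior the new unit test pins down: with
`indicators = FALSE` the factor predictor survives conversion intact, while
the default path (assumed here to remain `indicators = TRUE`) expands it
into indicator columns. convert_form_to_xy_fit() is internal, hence the
triple colon.

# Puromycin ships with base R: one numeric predictor (conc), one factor (state).
no_dummies <- parsnip:::convert_form_to_xy_fit(
  rate ~ ., data = Puromycin, indicators = FALSE
)
class(no_dummies$x$state)
#> [1] "factor"

# With the assumed default, `state` is expanded before x is returned.
with_dummies <- parsnip:::convert_form_to_xy_fit(rate ~ ., data = Puromycin)
names(with_dummies$x)  # "conc" plus indicator column(s) derived from `state`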