diff --git a/NEWS.md b/NEWS.md index 54eaa7b9c..0da233efc 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,14 +1,21 @@ # parsnip (development version) +## Model Specification Changes + +* A model function (`gen_additive_mod()`) was added for generalized additive models. + * Each model now has a default engine that is used when the model is defined. The default for each model is listed in the help documents. This also adds functionality to declare an engine in the model specification function. `set_engine()` is still required if engine-specific arguments need to be added. (#513) +* parsnip now checks for a valid combination of engine and mode (#529). + * The default engine for `multinom_reg()` was changed to `nnet`. +## Other Changes + * The helper functions `.convert_form_to_xy_fit()`, `.convert_form_to_xy_new()`, `.convert_xy_to_form_fit()`, and `.convert_xy_to_form_new()` for converting between formula and matrix interface are now exported for developer use (#508). * Fix bug in `augment()` when non-predictor, non-outcome variables are included in data (#510). -* A model function (`gen_additive_mod()`) was added for generalized additive models. 
# parsnip 0.1.6 diff --git a/R/aaa_models.R b/R/aaa_models.R index b29c62467..42bb2487a 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -1,5 +1,7 @@ # Initialize model environments +all_modes <- c("classification", "regression", "censored regression") + # ------------------------------------------------------------------------------ ## Rules about model-related information @@ -23,10 +25,9 @@ # ------------------------------------------------------------------------------ - parsnip <- rlang::new_environment() parsnip$models <- NULL -parsnip$modes <- c("regression", "classification", "unknown") +parsnip$modes <- c(all_modes, "unknown") # ------------------------------------------------------------------------------ @@ -134,25 +135,119 @@ check_mode_val <- function(mode) { } -stop_incompatible_mode <- function(spec_modes) { +stop_incompatible_mode <- function(spec_modes, eng = NULL, cls = NULL) { + if (is.null(eng) & is.null(cls)) { + msg <- "Available modes are: " + } + if (!is.null(eng) & is.null(cls)) { + msg <- glue::glue("Available modes for engine {eng} are: ") + } + if (is.null(eng) & !is.null(cls)) { + msg <- glue::glue("Available modes for model type {cls} are: ") + } + if (!is.null(eng) & !is.null(cls)) { + msg <- glue::glue("Available modes for model type {cls} with engine {eng} are: ") + } + msg <- glue::glue( - "Available modes are: ", + msg, glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ") ) rlang::abort(msg) } -# check if class and mode are compatible -check_spec_mode_val <- function(cls, mode) { - spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes")) +stop_incompatible_engine <- function(spec_engs, mode) { + msg <- glue::glue( + "Available engines for mode {mode} are: ", + glue::glue_collapse(glue::glue("'{spec_engs}'"), sep = ", ") + ) + rlang::abort(msg) +} + +stop_missing_engine <- function(cls) { + info <- + get_from_env(cls) %>% + dplyr::group_by(mode) %>% + dplyr::summarize(msg = paste0(unique(mode), " {", + 
paste0(unique(engine), collapse = ", "), + "}"), + .groups = "drop") + if (nrow(info) == 0) { + rlang::abort(paste0("No known engines for `", cls, "()`.")) + } + msg <- paste0(info$msg, collapse = ", ") + msg <- paste("Missing engine. Possible mode/engine combinations are:", msg) + rlang::abort(msg) +} + + +# check if class and mode and engine are compatible +check_spec_mode_engine_val <- function(cls, eng, mode) { + all_modes <- c("unknown", all_modes) + if (!(mode %in% all_modes)) { + rlang::abort(paste0("'", mode, "' is not a known mode.")) + } + + model_info <- rlang::env_get(get_model_env(), cls) + + # Cases where the model definition is in parsnip but all of the engines + # are contained in a different package + if (nrow(model_info) == 0) { + check_mode_with_no_engine(cls, mode) + return(invisible(NULL)) + } + + # ------------------------------------------------------------------------------ + # First check engine against any mode for the given model class + + spec_engs <- model_info$engine + # engine is allowed to be NULL + if (!is.null(eng) && !(eng %in% spec_engs)) { + rlang::abort( + paste0( + "Engine '", eng, "' is not supported for `", cls, "()`. See ", + "`show_engines('", cls, "')`." 
+ ) + ) + } + + # ---------------------------------------------------------------------------- + # Check modes based on model and engine + + spec_modes <- model_info$mode + if (!is.null(eng)) { + spec_modes <- spec_modes[model_info$engine == eng] + } + spec_modes <- unique(c("unknown", spec_modes)) + if (is.null(mode) || length(mode) > 1) { - stop_incompatible_mode(spec_modes) + stop_incompatible_mode(spec_modes, eng) } else if (!(mode %in% spec_modes)) { - stop_incompatible_mode(spec_modes) + stop_incompatible_mode(spec_modes, eng) } + + # ---------------------------------------------------------------------------- + # Check engine based on model and mode + + # Now check for compatibility with the chosen mode (if any) + if (!is.null(mode) && mode != "unknown") { + spec_engs <- spec_engs[model_info$mode == mode] + } + spec_engs <- unique(spec_engs) + if (!is.null(eng) && !(eng %in% spec_engs)) { + stop_incompatible_engine(spec_engs, mode) + } + invisible(NULL) } +check_mode_with_no_engine <- function(cls, mode) { + spec_modes <- get_from_env(paste0(cls, "_modes")) + if (!(mode %in% spec_modes)) { + stop_incompatible_mode(spec_modes, cls = cls) + } +} + check_engine_val <- function(eng) { if (rlang::is_missing(eng) || length(eng) != 1 || !is.character(eng)) rlang::abort("Please supply a character string for an engine (e.g. 
`'lm'`).") @@ -625,8 +720,7 @@ get_dependency <- function(model) { set_fit <- function(model, mode, eng, value) { check_model_exists(model) check_eng_val(eng) - check_mode_val(mode) - check_engine_val(eng) + check_spec_mode_engine_val(model, eng, mode) check_fit_info(value) current <- get_model_env() @@ -692,8 +786,7 @@ get_fit <- function(model) { set_pred <- function(model, mode, eng, type, value) { check_model_exists(model) check_eng_val(eng) - check_mode_val(mode) - check_engine_val(eng) + check_spec_mode_engine_val(model, eng, mode) check_pred_info(value, type) current <- get_model_env() diff --git a/R/arguments.R b/R/arguments.R index 8c7cb735b..76d99b626 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -79,9 +79,9 @@ set_mode <- function(object, mode) { cls <- class(object)[1] if (rlang::is_missing(mode)) { spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes")) - stop_incompatible_mode(spec_modes) + stop_incompatible_mode(spec_modes, cls = cls) } - check_spec_mode_val(cls, mode) + check_spec_mode_engine_val(cls, object$engine, mode) object$mode <- mode object } diff --git a/R/engines.R b/R/engines.R index 2054259ab..aad076e11 100644 --- a/R/engines.R +++ b/R/engines.R @@ -10,25 +10,6 @@ possible_engines <- function(object, ...) 
{ unique(engs$engine) } -stop_incompatible_engine <- function(avail_eng) { - msg <- glue::glue( - "Available engines are: ", - glue::glue_collapse(glue::glue("'{avail_eng}'"), sep = ", ") - ) - rlang::abort(msg) -} - -check_engine <- function(object) { - avail_eng <- possible_engines(object) - eng <- object$engine - if (is.null(eng) || length(eng) > 1) { - stop_incompatible_engine(avail_eng) - } else if (!(eng %in% avail_eng)) { - stop_incompatible_engine(avail_eng) - } - object -} - # ------------------------------------------------------------------------------ shhhh <- function(x) @@ -90,16 +71,16 @@ load_libs <- function(x, quiet, attach = FALSE) { #' translate(mod, engine = "glmnet") #' @export set_engine <- function(object, engine, ...) { + mod_type <- class(object)[1] if (!inherits(object, "model_spec")) { rlang::abort("`object` should have class 'model_spec'.") } if (rlang::is_missing(engine)) { - avail_eng <- possible_engines(object) - stop_incompatible_engine(avail_eng) + stop_missing_engine(mod_type) } object$engine <- engine - object <- check_engine(object) + check_spec_mode_engine_val(mod_type, object$engine, object$mode) if (object$engine == "liquidSVM") { lifecycle::deprecate_soft( @@ -109,7 +90,7 @@ set_engine <- function(object, engine, ...) { } new_model_spec( - cls = class(object)[1], + cls = mod_type, args = object$args, eng_args = enquos(...), mode = object$mode, diff --git a/R/misc.R b/R/misc.R index 430b4a9fe..274c066f0 100644 --- a/R/misc.R +++ b/R/misc.R @@ -23,9 +23,6 @@ check_empty_ellipse <- function (...) { terms } -all_modes <- c("classification", "regression", "censored regression") - - deparserizer <- function(x, limit = options()$width - 10) { x <- deparse(x, width.cutoff = limit) x <- gsub("^ ", "", x) @@ -192,7 +189,7 @@ update_dot_check <- function(...) 
{ #' @rdname add_on_exports new_model_spec <- function(cls, args, eng_args, mode, method, engine) { - check_spec_mode_val(cls, mode) + check_spec_mode_engine_val(cls, engine, mode) out <- list(args = args, eng_args = eng_args, mode = mode, method = method, engine = engine) diff --git a/R/translate.R b/R/translate.R index 4c2064db6..1172b3b29 100644 --- a/R/translate.R +++ b/R/translate.R @@ -59,14 +59,15 @@ translate.default <- function(x, engine = x$engine, ...) { mod_name <- specific_model(x) x$engine <- engine - x <- check_engine(x) - if (x$mode == "unknown") { rlang::abort("Model code depends on the mode; please specify one.") } - if (is.null(x$method)) + check_spec_mode_engine_val(class(x)[1], x$engine, x$mode) + + if (is.null(x$method)) { x$method <- get_model_spec(mod_name, x$mode, engine) + } arg_key <- get_args(mod_name, engine) @@ -174,7 +175,7 @@ deharmonize <- function(args, key) { add_methods <- function(x, engine) { x$engine <- engine - x <- check_engine(x) + check_spec_mode_engine_val(class(x)[1], x$engine, x$mode) x$method <- get_model_spec(specific_model(x), x$mode, x$engine) x } diff --git a/man/details_gen_additive_mod_mgcv.Rd b/man/details_gen_additive_mod_mgcv.Rd index 53ba10330..1c9d98f34 100644 --- a/man/details_gen_additive_mod_mgcv.Rd +++ b/man/details_gen_additive_mod_mgcv.Rd @@ -65,7 +65,7 @@ gen_additive_mod() \%>\% fit(mpg ~ wt + gear + cyl + s(disp, k = 10), data = mtcars) }\if{html}{\out{}}\preformatted{## parsnip model object ## -## Fit time: 21ms +## Fit time: 20ms ## ## Family: gaussian ## Link function: identity diff --git a/man/extract-parsnip.Rd b/man/extract-parsnip.Rd index f1e30a0ea..27a2bcf3b 100644 --- a/man/extract-parsnip.Rd +++ b/man/extract-parsnip.Rd @@ -26,7 +26,7 @@ not exist yet, an error is thrown. \item \code{extract_spec_parsnip()} returns the parsnip model specification. \item \code{extract_fit_engine()} returns the engine specific fit embedded within a parsnip model fit. 
For example, when using \code{\link[=linear_reg]{linear_reg()}} -with the \code{"lm"} engine, this would return the underlying \code{lm} object. +with the \code{"lm"} engine, this returns the underlying \code{lm} object. } } \details{ diff --git a/man/rmd/boost_tree_C5.0.Rmd b/man/rmd/boost_tree_C5.0.Rmd index cf43aa369..19c772570 100644 --- a/man/rmd/boost_tree_C5.0.Rmd +++ b/man/rmd/boost_tree_C5.0.Rmd @@ -13,7 +13,7 @@ defaults <- param <- boost_tree() %>% set_engine("C5.0") %>% - set_mode("regression") %>% + set_mode("classification") %>% tunable() %>% dplyr::select(-source, -component, -component_id, parsnip = name) %>% dplyr::mutate( diff --git a/man/rmd/decision_tree_C5.0.Rmd b/man/rmd/decision_tree_C5.0.Rmd index 7147fde64..99b36c20d 100644 --- a/man/rmd/decision_tree_C5.0.Rmd +++ b/man/rmd/decision_tree_C5.0.Rmd @@ -13,7 +13,7 @@ defaults <- param <- decision_tree() %>% set_engine("C5.0") %>% - set_mode("regression") %>% + set_mode("classification") %>% tunable() %>% dplyr::select(-source, -component, -component_id, parsnip = name) %>% dplyr::mutate( diff --git a/tests/testthat/test_args_and_modes.R b/tests/testthat/test_args_and_modes.R index be58c7933..80c25d73c 100644 --- a/tests/testthat/test_args_and_modes.R +++ b/tests/testthat/test_args_and_modes.R @@ -49,6 +49,59 @@ test_that('pipe engine', { test_that("can't set a mode that isn't allowed by the model spec", { expect_error( set_mode(linear_reg(), "classification"), - "Available modes are:" + "Available modes" ) }) + + + +test_that("unavailable modes for an engine and vice-versa", { + expect_error( + decision_tree() %>% + set_mode("regression") %>% + set_engine("C5.0"), + "Available modes for engine C5" + ) + expect_error( + decision_tree() %>% + set_engine("C5.0") %>% + set_mode("regression"), + "Available modes for engine C5" + ) + + expect_error( + decision_tree(engine = NULL) %>% + set_engine("C5.0") %>% + set_mode("regression"), + "Available modes for engine C5" + ) + + expect_error( + 
decision_tree(engine = NULL)%>% + set_mode("regression") %>% + set_engine("C5.0"), + "Available modes for engine C5" + ) + + expect_error( + proportional_hazards() %>% set_mode("regression"), + "Available modes for model type proportional_hazards" + ) + + expect_error( + linear_reg() %>% set_mode(), + "Available modes for model type linear_reg" + ) + + expect_error( + linear_reg() %>% set_engine(), + "Missing engine" + ) + + expect_error( + proportional_hazards() %>% set_engine(), + "No known engines for" + ) +}) + +