Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
# parsnip (development version)

## Model Specification Changes

* A model function (`gen_additive_mod()`) was added for generalized additive models.

* Each model now has a default engine that is used when the model is defined. The default for each model is listed in the help documents. This also adds functionality to declare an engine in the model specification function. `set_engine()` is still required if engine-specific arguments need to be added. (#513)

* parsnip now checks for a valid combination of engine and mode (#529)

* The default engine for `multinom_reg()` was changed to `nnet`.

## Other Changes

* The helper functions `.convert_form_to_xy_fit()`, `.convert_form_to_xy_new()`, `.convert_xy_to_form_fit()`, and `.convert_xy_to_form_new()` for converting between formula and matrix interface are now exported for developer use (#508).

* Fix bug in `augment()` when non-predictor, non-outcome variables are included in data (#510).

* A model function (`gen_additive_mod()`) was added for generalized additive models.

# parsnip 0.1.6

Expand Down
119 changes: 106 additions & 13 deletions R/aaa_models.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Initialize model environments

all_modes <- c("classification", "regression", "censored regression")

# ------------------------------------------------------------------------------

## Rules about model-related information
Expand All @@ -23,10 +25,9 @@

# ------------------------------------------------------------------------------


parsnip <- rlang::new_environment()
parsnip$models <- NULL
parsnip$modes <- c("regression", "classification", "unknown")
parsnip$modes <- c(all_modes, "unknown")

# ------------------------------------------------------------------------------

Expand Down Expand Up @@ -134,25 +135,119 @@ check_mode_val <- function(mode) {
}


stop_incompatible_mode <- function(spec_modes) {
stop_incompatible_mode <- function(spec_modes, eng = NULL, cls = NULL) {
if (is.null(eng) & is.null(cls)) {
msg <- "Available modes are: "
}
if (!is.null(eng) & is.null(cls)) {
msg <- glue::glue("Available modes for engine {eng} are: ")
}
if (is.null(eng) & !is.null(cls)) {
msg <- glue::glue("Available modes for model type {cls} are: ")
}
if (!is.null(eng) & !is.null(cls)) {
msg <- glue::glue("Available modes for model type {cls} with engine {eng} are: ")
}

msg <- glue::glue(
"Available modes are: ",
msg,
glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ")
)
rlang::abort(msg)
}

# check if class and mode are compatible
check_spec_mode_val <- function(cls, mode) {
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
stop_incompatible_engine <- function(spec_engs, mode) {
msg <- glue::glue(
"Available engines for mode {mode} are: ",
glue::glue_collapse(glue::glue("'{spec_engs}'"), sep = ", ")
)
rlang::abort(msg)
}

stop_missing_engine <- function(cls) {
info <-
get_from_env(cls) %>%
dplyr::group_by(mode) %>%
dplyr::summarize(msg = paste0(unique(mode), " {",
paste0(unique(engine), collapse = ", "),
"}"),
.groups = "drop")
if (nrow(info) == 0) {
rlang::abort(paste0("No known engines for `", cls, "()`."))
}
msg <- paste0(info$msg, collapse = ", ")
msg <- paste("Missing engine. Possible mode/engine combinations are:", msg)
rlang::abort(msg)
}


# check if class and mode and engine are compatible
check_spec_mode_engine_val <- function(cls, eng, mode) {
all_modes <- c("unknown", all_modes)
if (!(mode %in% all_modes)) {
rlang::abort(paste0("'", mode, "' is not a known mode."))
}

model_info <- rlang::env_get(get_model_env(), cls)

# Cases where the model definition is in parsnip but all of the engines
# are contained in a different package
if (nrow(model_info) == 0) {
check_mode_with_no_engine(cls, mode)
return(invisible(NULL))
}

# ------------------------------------------------------------------------------
# First check engine against any mode for the given model class

spec_engs <- model_info$engine
# engine is allowed to be NULL
if (!is.null(eng) && !(eng %in% spec_engs)) {
rlang::abort(
paste0(
"Engine '", eng, "' is not supported for `", cls, "()`. See ",
"`show_engines('", cls, "')`."
)
)
}

# ----------------------------------------------------------------------------
# Check modes based on model and engine

spec_modes <- model_info$mode
if (!is.null(eng)) {
spec_modes <- spec_modes[model_info$engine == eng]
}
spec_modes <- unique(c("unknown", spec_modes))

if (is.null(mode) || length(mode) > 1) {
stop_incompatible_mode(spec_modes)
stop_incompatible_mode(spec_modes, eng)
} else if (!(mode %in% spec_modes)) {
stop_incompatible_mode(spec_modes)
stop_incompatible_mode(spec_modes, eng)
}

# ----------------------------------------------------------------------------
# Check engine based on model and model

# How check for compatibility with the chosen mode (if any)
if (!is.null(mode) && mode != "unknown") {
spec_engs <- spec_engs[model_info$mode == mode]
}
spec_engs <- unique(spec_engs)
if (!is.null(eng) && !(eng %in% spec_engs)) {
stop_incompatible_engine(spec_engs, mode)
}

invisible(NULL)
}

check_mode_with_no_engine <- function(cls, mode) {
spec_modes <- get_from_env(paste0(cls, "_modes"))
if (!(mode %in% spec_modes)) {
stop_incompatible_mode(spec_modes, cls = cls)
}
}

check_engine_val <- function(eng) {
if (rlang::is_missing(eng) || length(eng) != 1 || !is.character(eng))
rlang::abort("Please supply a character string for an engine (e.g. `'lm'`).")
Expand Down Expand Up @@ -625,8 +720,7 @@ get_dependency <- function(model) {
set_fit <- function(model, mode, eng, value) {
check_model_exists(model)
check_eng_val(eng)
check_mode_val(mode)
check_engine_val(eng)
check_spec_mode_engine_val(model, eng, mode)
check_fit_info(value)

current <- get_model_env()
Expand Down Expand Up @@ -692,8 +786,7 @@ get_fit <- function(model) {
set_pred <- function(model, mode, eng, type, value) {
check_model_exists(model)
check_eng_val(eng)
check_mode_val(mode)
check_engine_val(eng)
check_spec_mode_engine_val(model, eng, mode)
check_pred_info(value, type)

current <- get_model_env()
Expand Down
4 changes: 2 additions & 2 deletions R/arguments.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ set_mode <- function(object, mode) {
cls <- class(object)[1]
if (rlang::is_missing(mode)) {
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
stop_incompatible_mode(spec_modes)
stop_incompatible_mode(spec_modes, cls = cls)
}
check_spec_mode_val(cls, mode)
check_spec_mode_engine_val(cls, object$engine, mode)
object$mode <- mode
object
}
Expand Down
27 changes: 4 additions & 23 deletions R/engines.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,6 @@ possible_engines <- function(object, ...) {
unique(engs$engine)
}

stop_incompatible_engine <- function(avail_eng) {
msg <- glue::glue(
"Available engines are: ",
glue::glue_collapse(glue::glue("'{avail_eng}'"), sep = ", ")
)
rlang::abort(msg)
}

check_engine <- function(object) {
avail_eng <- possible_engines(object)
eng <- object$engine
if (is.null(eng) || length(eng) > 1) {
stop_incompatible_engine(avail_eng)
} else if (!(eng %in% avail_eng)) {
stop_incompatible_engine(avail_eng)
}
object
}

# ------------------------------------------------------------------------------

shhhh <- function(x)
Expand Down Expand Up @@ -90,16 +71,16 @@ load_libs <- function(x, quiet, attach = FALSE) {
#' translate(mod, engine = "glmnet")
#' @export
set_engine <- function(object, engine, ...) {
mod_type <- class(object)[1]
if (!inherits(object, "model_spec")) {
rlang::abort("`object` should have class 'model_spec'.")
}

if (rlang::is_missing(engine)) {
avail_eng <- possible_engines(object)
stop_incompatible_engine(avail_eng)
stop_missing_engine(mod_type)
}
object$engine <- engine
object <- check_engine(object)
check_spec_mode_engine_val(mod_type, object$engine, object$mode)

if (object$engine == "liquidSVM") {
lifecycle::deprecate_soft(
Expand All @@ -109,7 +90,7 @@ set_engine <- function(object, engine, ...) {
}

new_model_spec(
cls = class(object)[1],
cls = mod_type,
args = object$args,
eng_args = enquos(...),
mode = object$mode,
Expand Down
5 changes: 1 addition & 4 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ check_empty_ellipse <- function (...) {
terms
}

all_modes <- c("classification", "regression", "censored regression")


deparserizer <- function(x, limit = options()$width - 10) {
x <- deparse(x, width.cutoff = limit)
x <- gsub("^ ", "", x)
Expand Down Expand Up @@ -192,7 +189,7 @@ update_dot_check <- function(...) {
#' @rdname add_on_exports
new_model_spec <- function(cls, args, eng_args, mode, method, engine) {

check_spec_mode_val(cls, mode)
check_spec_mode_engine_val(cls, engine, mode)

out <- list(args = args, eng_args = eng_args,
mode = mode, method = method, engine = engine)
Expand Down
9 changes: 5 additions & 4 deletions R/translate.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,15 @@ translate.default <- function(x, engine = x$engine, ...) {
mod_name <- specific_model(x)

x$engine <- engine
x <- check_engine(x)

if (x$mode == "unknown") {
rlang::abort("Model code depends on the mode; please specify one.")
}

if (is.null(x$method))
check_spec_mode_engine_val(class(x)[1], x$engine, x$mode)

if (is.null(x$method)) {
x$method <- get_model_spec(mod_name, x$mode, engine)
}

arg_key <- get_args(mod_name, engine)

Expand Down Expand Up @@ -174,7 +175,7 @@ deharmonize <- function(args, key) {

add_methods <- function(x, engine) {
x$engine <- engine
x <- check_engine(x)
check_spec_mode_engine_val(class(x)[1], x$engine, x$mode)
x$method <- get_model_spec(specific_model(x), x$mode, x$engine)
x
}
2 changes: 1 addition & 1 deletion man/details_gen_additive_mod_mgcv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/extract-parsnip.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/rmd/boost_tree_C5.0.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ defaults <-
param <-
boost_tree() %>%
set_engine("C5.0") %>%
set_mode("regression") %>%
set_mode("classification") %>%
tunable() %>%
dplyr::select(-source, -component, -component_id, parsnip = name) %>%
dplyr::mutate(
Expand Down
2 changes: 1 addition & 1 deletion man/rmd/decision_tree_C5.0.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ defaults <-
param <-
decision_tree() %>%
set_engine("C5.0") %>%
set_mode("regression") %>%
set_mode("classification") %>%
tunable() %>%
dplyr::select(-source, -component, -component_id, parsnip = name) %>%
dplyr::mutate(
Expand Down
Loading