Skip to content

Harmonize errors for set_mode() and set_engine() #503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@

* The xgboost engine for boosted trees was translating `mtry` to xgboost's `colsample_bytree`. We now map `mtry` to `colsample_bynode` since that is more consistent with how random forest works. `colsample_bytree` can still be optimized by passing it in as an engine argument. `colsample_bynode` was added to xgboost after the `parsnip` package code was written. (#495)

* For xgboost boosting, `mtry` and `colsample_bytree` can be passed as integer counts or proportions while `subsample` and `validation` should be proportions. `xgb_train()` now has a new option `counts` for state what scale `mtry` and `colsample_bytree` are being used. (#461)
* For xgboost, `mtry` and `colsample_bytree` can be passed as integer counts or proportions, while `subsample` and `validation` should always be proportions. `xgb_train()` now has a new option `counts` (`TRUE` or `FALSE`) that states which scale for `mtry` and `colsample_bytree` is being used. (#461)

## Other Changes

* Re-licensed package from GPL-2 to MIT. See [consent from copyright holders here](https://github.com/tidymodels/parsnip/issues/462).

* `set_mode()` now checks if `mode` is compatible with the model class, similar to `new_model_spec()` (@jtlandis, #467).
* `set_mode()` now checks if `mode` is compatible with the model class, similar to `new_model_spec()` (@jtlandis, #467). Both `set_mode()` and `set_engine()` now error for `NULL` or missing arguments (#503).

* Re-organized model documentation for `update` methods (#479).

Expand Down
22 changes: 12 additions & 10 deletions R/aaa_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,21 +133,23 @@ check_mode_val <- function(mode) {
invisible(NULL)
}


stop_incompatible_mode <- function(spec_modes) {
msg <- glue::glue(
"Available modes are: ",
glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ")
)
rlang::abort(msg)
}

# check if class and mode are compatible
check_spec_mode_val <- function(cls, mode) {
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
compatible_modes <-
glue::glue(
"`mode` should be one of: ",
glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ")
)

if (is.null(mode)) {
rlang::abort(compatible_modes)
if (is.null(mode) || length(mode) > 1) {
stop_incompatible_mode(spec_modes)
} else if (!(mode %in% spec_modes)) {
rlang::abort(compatible_modes)
stop_incompatible_mode(spec_modes)
}

invisible(NULL)
}

Expand Down
7 changes: 4 additions & 3 deletions R/arguments.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,12 @@ set_args <- function(object, ...) {
#' @rdname set_args
#' @export
set_mode <- function(object, mode) {
cls <- class(object)[1]
if (rlang::is_missing(mode)) {
mode <- NULL
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
stop_incompatible_mode(spec_modes)
}
mode <- mode[1]
check_spec_mode_val(class(object)[1], mode)
check_spec_mode_val(cls, mode)
object$mode <- mode
object
}
Expand Down
31 changes: 15 additions & 16 deletions R/engines.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,21 @@ possible_engines <- function(object, ...) {
unique(engs$engine)
}

stop_incompatible_engine <- function(avail_eng) {
msg <- glue::glue(
"Available engines are: ",
glue::glue_collapse(glue::glue("'{avail_eng}'"), sep = ", ")
)
rlang::abort(msg)
}

check_engine <- function(object) {
avail_eng <- possible_engines(object)
if (is.null(object$engine)) {
object$engine <- avail_eng[1]
rlang::warn(glue::glue("`engine` was NULL and updated to be `{object$engine}`"))
} else {
if (!is.character(object$engine) | length(object$engine) != 1) {
rlang::abort("`engine` should be a single character value.")
}
}
if (!(object$engine %in% avail_eng)) {
rlang::abort(
glue::glue(
"Engine '{object$engine}' is not available. Please use one of: ",
glue::glue_collapse(glue::glue("'{avail_eng}'"), sep = ", ")
)
)
eng <- object$engine
if (is.null(eng) || length(eng) > 1) {
stop_incompatible_engine(avail_eng)
} else if (!(eng %in% avail_eng)) {
stop_incompatible_engine(avail_eng)
}
object
}
Expand Down Expand Up @@ -97,7 +95,8 @@ set_engine <- function(object, engine, ...) {
}

if (rlang::is_missing(engine)) {
engine <- NULL
avail_eng <- possible_engines(object)
stop_incompatible_engine(avail_eng)
}
object$engine <- engine
object <- check_engine(object)
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_args_and_modes.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,6 @@ test_that('pipe engine', {
test_that("can't set a mode that isn't allowed by the model spec", {
expect_error(
set_mode(linear_reg(), "classification"),
"`mode` should be one of"
"Available modes are:"
)
})
2 changes: 1 addition & 1 deletion tests/testthat/test_mars.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ test_that('updating', {
})

test_that('bad input', {
expect_warning(translate(mars(mode = "regression") %>% set_engine()))
expect_error(translate(mars(mode = "regression") %>% set_engine()))
expect_error(translate(mars() %>% set_engine("wat?")))
expect_error(translate(mars(formula = y ~ x)))
})
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_multinom_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,6 @@ test_that('updating', {
test_that('bad input', {
expect_error(multinom_reg(mode = "regression"))
expect_error(translate(multinom_reg(penalty = 0.1) %>% set_engine("wat?")))
expect_warning(multinom_reg(penalty = 0.1) %>% set_engine())
expect_error(multinom_reg(penalty = 0.1) %>% set_engine())
expect_warning(translate(multinom_reg(penalty = 0.1) %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class)))
})
2 changes: 1 addition & 1 deletion tests/testthat/test_nearest_neighbor.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,5 +122,5 @@ test_that('updating', {

test_that('bad input', {
expect_error(nearest_neighbor(mode = "reallyunknown"))
expect_warning(nearest_neighbor() %>% set_engine( NULL))
expect_error(nearest_neighbor() %>% set_engine( NULL))
})
2 changes: 1 addition & 1 deletion tests/testthat/test_nullmodel.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ test_that('engine arguments', {
})

test_that('bad input', {
expect_warning(translate(null_model(mode = "regression") %>% set_engine()))
expect_error(translate(null_model(mode = "regression") %>% set_engine()))
expect_error(translate(null_model() %>% set_engine("wat?")))
expect_error(translate(null_model(formula = y ~ x)))
expect_warning(
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_rand_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ test_that('updating', {
})

test_that('bad input', {
expect_warning(translate(rand_forest(mode = "classification") %>% set_engine(NULL)))
expect_error(translate(rand_forest(mode = "classification") %>% set_engine(NULL)))
expect_error(rand_forest(mode = "time series"))
expect_error(translate(rand_forest(mode = "classification") %>% set_engine("wat?")))
expect_error(translate(rand_forest(mode = "classification", ytest = 2)))
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_surv_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ test_that('bad input', {

expect_error(surv_reg(mode = ", classification"))
expect_error(translate(surv_reg() %>% set_engine("wat")))
expect_warning(translate(surv_reg() %>% set_engine(NULL)))
expect_error(translate(surv_reg() %>% set_engine(NULL)))
})

test_that("deprecation warning", {
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_svm_linear.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ test_that('updating', {
})

test_that('bad input', {
expect_warning(translate(svm_linear(mode = "regression") %>% set_engine( NULL)))
expect_error(translate(svm_linear(mode = "regression") %>% set_engine( NULL)))
expect_error(svm_linear(mode = "reallyunknown"))
expect_error(translate(svm_linear(mode = "regression") %>% set_engine("LiblineaR", type = 3)))
expect_error(translate(svm_linear(mode = "classification") %>% set_engine("LiblineaR", type = 11)))
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_svm_liquidsvm.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,5 @@ test_that('updating', {

test_that('bad input', {
expect_error(svm_rbf(mode = "reallyunknown"))
expect_warning(svm_rbf() %>% set_engine( NULL))
expect_error(svm_rbf() %>% set_engine( NULL))
})
2 changes: 1 addition & 1 deletion tests/testthat/test_svm_poly.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ test_that('updating', {

test_that('bad input', {
expect_error(svm_poly(mode = "reallyunknown"))
expect_warning(svm_poly() %>% set_engine(NULL))
expect_error(svm_poly() %>% set_engine(NULL))
})

# ------------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_svm_rbf.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ test_that('updating', {

test_that('bad input', {
expect_error(svm_rbf(mode = "reallyunknown"))
expect_warning(translate(svm_rbf(mode = "regression") %>% set_engine( NULL)))
expect_error(translate(svm_rbf(mode = "regression") %>% set_engine( NULL)))
})

# ------------------------------------------------------------------------------
Expand Down