diff --git a/NEWS.md b/NEWS.md index 9e115794b..41736454f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -13,13 +13,13 @@ * The xgboost engine for boosted trees was translating `mtry` to xgboost's `colsample_bytree`. We now map `mtry` to `colsample_bynode` since that is more consistent with how random forest works. `colsample_bytree` can still be optimized by passing it in as an engine argument. `colsample_bynode` was added to xgboost after the `parsnip` package code was written. (#495) -* For xgboost boosting, `mtry` and `colsample_bytree` can be passed as integer counts or proportions while `subsample` and `validation` should be proportions. `xgb_train()` now has a new option `counts` for state what scale `mtry` and `colsample_bytree` are being used. (#461) +* For xgboost, `mtry` and `colsample_bytree` can be passed as integer counts or proportions, while `subsample` and `validation` should always be proportions. `xgb_train()` now has a new option `counts` (`TRUE` or `FALSE`) that states which scale for `mtry` and `colsample_bytree` is being used. (#461) ## Other Changes * Re-licensed package from GPL-2 to MIT. See [consent from copyright holders here](https://github.com/tidymodels/parsnip/issues/462). -* `set_mode()` now checks if `mode` is compatible with the model class, similar to `new_model_spec()` (@jtlandis, #467). +* `set_mode()` now checks if `mode` is compatible with the model class, similar to `new_model_spec()` (@jtlandis, #467). Both `set_mode()` and `set_engine()` now error for `NULL` or missing arguments (#503). * Re-organized model documentation for `update` methods (#479). diff --git a/R/aaa_models.R b/R/aaa_models.R index 2a40e3d11..f7f1b3824 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -133,21 +133,23 @@ check_mode_val <- function(mode) { invisible(NULL) } + +stop_incompatible_mode <- function(spec_modes) { + msg <- glue::glue( + "Available modes are: ", + glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ") + ) + rlang::abort(msg) +} + # check if class and mode are compatible check_spec_mode_val <- function(cls, mode) { spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes")) - compatible_modes <- - glue::glue( - "`mode` should be one of: ", - glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ") - ) - - if (is.null(mode)) { - rlang::abort(compatible_modes) + if (is.null(mode) || length(mode) > 1) { + stop_incompatible_mode(spec_modes) } else if (!(mode %in% spec_modes)) { - rlang::abort(compatible_modes) + stop_incompatible_mode(spec_modes) } - invisible(NULL) } diff --git a/R/arguments.R b/R/arguments.R index ed2b995a7..8c7cb735b 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -76,11 +76,12 @@ set_args <- function(object, ...) { #' @rdname set_args #' @export set_mode <- function(object, mode) { + cls <- class(object)[1] if (rlang::is_missing(mode)) { - mode <- NULL + spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes")) + stop_incompatible_mode(spec_modes) } - mode <- mode[1] - check_spec_mode_val(class(object)[1], mode) + check_spec_mode_val(cls, mode) object$mode <- mode object } diff --git a/R/engines.R b/R/engines.R index 1babb0186..e6000c184 100644 --- a/R/engines.R +++ b/R/engines.R @@ -10,23 +10,21 @@ possible_engines <- function(object, ...) { unique(engs$engine) } +stop_incompatible_engine <- function(avail_eng) { + msg <- glue::glue( + "Available engines are: ", + glue::glue_collapse(glue::glue("'{avail_eng}'"), sep = ", ") + ) + rlang::abort(msg) +} + check_engine <- function(object) { avail_eng <- possible_engines(object) - if (is.null(object$engine)) { - object$engine <- avail_eng[1] - rlang::warn(glue::glue("`engine` was NULL and updated to be `{object$engine}`")) - } else { - if (!is.character(object$engine) | length(object$engine) != 1) { - rlang::abort("`engine` should be a single character value.") - } - } - if (!(object$engine %in% avail_eng)) { - rlang::abort( - glue::glue( - "Engine '{object$engine}' is not available. Please use one of: ", - glue::glue_collapse(glue::glue("'{avail_eng}'"), sep = ", ") - ) - ) + eng <- object$engine + if (is.null(eng) || length(eng) > 1) { + stop_incompatible_engine(avail_eng) + } else if (!(eng %in% avail_eng)) { + stop_incompatible_engine(avail_eng) } object } @@ -97,7 +95,8 @@ set_engine <- function(object, engine, ...) { } if (rlang::is_missing(engine)) { - engine <- NULL + avail_eng <- possible_engines(object) + stop_incompatible_engine(avail_eng) } object$engine <- engine object <- check_engine(object) diff --git a/tests/testthat/test_args_and_modes.R b/tests/testthat/test_args_and_modes.R index 9abf5807e..be58c7933 100644 --- a/tests/testthat/test_args_and_modes.R +++ b/tests/testthat/test_args_and_modes.R @@ -49,6 +49,6 @@ test_that('pipe engine', { test_that("can't set a mode that isn't allowed by the model spec", { expect_error( set_mode(linear_reg(), "classification"), - "`mode` should be one of" + "Available modes are:" ) }) diff --git a/tests/testthat/test_mars.R b/tests/testthat/test_mars.R index ca5b3e650..43c795c78 100644 --- a/tests/testthat/test_mars.R +++ b/tests/testthat/test_mars.R @@ -111,7 +111,7 @@ test_that('updating', { }) test_that('bad input', { - expect_warning(translate(mars(mode = "regression") %>% set_engine())) + expect_error(translate(mars(mode = "regression") %>% set_engine())) expect_error(translate(mars() %>% set_engine("wat?"))) expect_error(translate(mars(formula = y ~ x))) }) diff --git a/tests/testthat/test_multinom_reg.R b/tests/testthat/test_multinom_reg.R index 3bbcea84a..9a229f496 100644 --- a/tests/testthat/test_multinom_reg.R +++ b/tests/testthat/test_multinom_reg.R @@ -123,6 +123,6 @@ test_that('updating', { test_that('bad input', { expect_error(multinom_reg(mode = "regression")) expect_error(translate(multinom_reg(penalty = 0.1) %>% set_engine("wat?"))) - expect_warning(multinom_reg(penalty = 0.1) %>% set_engine()) + expect_error(multinom_reg(penalty = 0.1) %>% set_engine()) expect_warning(translate(multinom_reg(penalty = 0.1) %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class))) }) diff --git a/tests/testthat/test_nearest_neighbor.R b/tests/testthat/test_nearest_neighbor.R index a550fa4ba..2ae8a671c 100644 --- a/tests/testthat/test_nearest_neighbor.R +++ b/tests/testthat/test_nearest_neighbor.R @@ -122,5 +122,5 @@ test_that('updating', { test_that('bad input', { expect_error(nearest_neighbor(mode = "reallyunknown")) - expect_warning(nearest_neighbor() %>% set_engine( NULL)) + expect_error(nearest_neighbor() %>% set_engine( NULL)) }) diff --git a/tests/testthat/test_nullmodel.R b/tests/testthat/test_nullmodel.R index 9a5af12bf..5590e2d5d 100644 --- a/tests/testthat/test_nullmodel.R +++ b/tests/testthat/test_nullmodel.R @@ -32,7 +32,7 @@ test_that('engine arguments', { }) test_that('bad input', { - expect_warning(translate(null_model(mode = "regression") %>% set_engine())) + expect_error(translate(null_model(mode = "regression") %>% set_engine())) expect_error(translate(null_model() %>% set_engine("wat?"))) expect_error(translate(null_model(formula = y ~ x))) expect_warning( diff --git a/tests/testthat/test_rand_forest.R b/tests/testthat/test_rand_forest.R index b8dbf361f..926a8b899 100644 --- a/tests/testthat/test_rand_forest.R +++ b/tests/testthat/test_rand_forest.R @@ -192,7 +192,7 @@ test_that('updating', { }) test_that('bad input', { - expect_warning(translate(rand_forest(mode = "classification") %>% set_engine(NULL))) + expect_error(translate(rand_forest(mode = "classification") %>% set_engine(NULL))) expect_error(rand_forest(mode = "time series")) expect_error(translate(rand_forest(mode = "classification") %>% set_engine("wat?"))) expect_error(translate(rand_forest(mode = "classification", ytest = 2))) diff --git a/tests/testthat/test_surv_reg.R b/tests/testthat/test_surv_reg.R index 1947bed41..424f48a9b 100644 --- a/tests/testthat/test_surv_reg.R +++ b/tests/testthat/test_surv_reg.R @@ -85,7 +85,7 @@ test_that('bad input', { expect_error(surv_reg(mode = ", classification")) expect_error(translate(surv_reg() %>% set_engine("wat"))) - expect_warning(translate(surv_reg() %>% set_engine(NULL))) + expect_error(translate(surv_reg() %>% set_engine(NULL))) }) test_that("deprecation warning", { diff --git a/tests/testthat/test_svm_linear.R b/tests/testthat/test_svm_linear.R index 82b3e7a42..e53540980 100644 --- a/tests/testthat/test_svm_linear.R +++ b/tests/testthat/test_svm_linear.R @@ -104,7 +104,7 @@ test_that('updating', { }) test_that('bad input', { - expect_warning(translate(svm_linear(mode = "regression") %>% set_engine( NULL))) + expect_error(translate(svm_linear(mode = "regression") %>% set_engine( NULL))) expect_error(svm_linear(mode = "reallyunknown")) expect_error(translate(svm_linear(mode = "regression") %>% set_engine("LiblineaR", type = 3))) expect_error(translate(svm_linear(mode = "classification") %>% set_engine("LiblineaR", type = 11))) diff --git a/tests/testthat/test_svm_liquidsvm.R b/tests/testthat/test_svm_liquidsvm.R index 2a966b8e9..cef99ec9d 100644 --- a/tests/testthat/test_svm_liquidsvm.R +++ b/tests/testthat/test_svm_liquidsvm.R @@ -77,5 +77,5 @@ test_that('updating', { test_that('bad input', { expect_error(svm_rbf(mode = "reallyunknown")) - expect_warning(svm_rbf() %>% set_engine( NULL)) + expect_error(svm_rbf() %>% set_engine( NULL)) }) diff --git a/tests/testthat/test_svm_poly.R b/tests/testthat/test_svm_poly.R index fd403560f..bea7d5093 100644 --- a/tests/testthat/test_svm_poly.R +++ b/tests/testthat/test_svm_poly.R @@ -106,7 +106,7 @@ test_that('updating', { test_that('bad input', { expect_error(svm_poly(mode = "reallyunknown")) - expect_warning(svm_poly() %>% set_engine(NULL)) + expect_error(svm_poly() %>% set_engine(NULL)) }) # ------------------------------------------------------------------------------ diff --git a/tests/testthat/test_svm_rbf.R b/tests/testthat/test_svm_rbf.R index 9541f1792..059754fbf 100644 --- a/tests/testthat/test_svm_rbf.R +++ b/tests/testthat/test_svm_rbf.R @@ -87,7 +87,7 @@ test_that('updating', { test_that('bad input', { expect_error(svm_rbf(mode = "reallyunknown")) - expect_warning(translate(svm_rbf(mode = "regression") %>% set_engine( NULL))) + expect_error(translate(svm_rbf(mode = "regression") %>% set_engine( NULL))) }) # ------------------------------------------------------------------------------