From 4eba1b7e8b25e86b3a29b6c4fd937fc7ffa54c19 Mon Sep 17 00:00:00 2001 From: jtlandis Date: Wed, 14 Apr 2021 13:47:36 -0400 Subject: [PATCH 1/8] New function to check model_spec and mode compatibility --- R/aaa_models.R | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/R/aaa_models.R b/R/aaa_models.R index 1cbce5bc9..237bcde6b 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -132,6 +132,17 @@ check_mode_val <- function(mode) { invisible(NULL) } +# check if class and mode are compatible +check_spec_mode_val <- function(cls, mode) { + spec_modes <- rlang::env_get(get_model_env(), paste0(cls, + "_modes")) + if (!(mode %in% spec_modes)) + rlang::abort(glue::glue("`mode` should be one of: ", + glue::glue_collapse(glue::glue("'{spec_modes}'"), + sep = ", "))) + invisible(NULL) +} + check_engine_val <- function(eng) { if (rlang::is_missing(eng) || length(eng) != 1 || !is.character(eng)) rlang::abort("Please supply a character string for an engine (e.g. `'lm'`).") From fdb0e19e350fd4988c94f144ac9a252496bb5c95 Mon Sep 17 00:00:00 2001 From: jtlandis Date: Wed, 14 Apr 2021 13:49:14 -0400 Subject: [PATCH 2/8] reformat to style guidelines --- R/aaa_models.R | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/R/aaa_models.R b/R/aaa_models.R index 237bcde6b..112d5b225 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -134,12 +134,14 @@ check_mode_val <- function(mode) { # check if class and mode are compatible check_spec_mode_val <- function(cls, mode) { - spec_modes <- rlang::env_get(get_model_env(), paste0(cls, - "_modes")) + spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes")) if (!(mode %in% spec_modes)) - rlang::abort(glue::glue("`mode` should be one of: ", - glue::glue_collapse(glue::glue("'{spec_modes}'"), - sep = ", "))) + rlang::abort( + glue::glue( + "`mode` should be one of: ", + glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ") + ) + ) invisible(NULL) } From 88350c84a4ba0f6872cda01a2f02403dbfcf6ece Mon Sep 17 00:00:00 2001 From: jtlandis Date: Wed, 14 Apr 2021 13:51:30 -0400 Subject: [PATCH 3/8] refactoring - new function check_spec_mode_val is a drop and replace for this chunk --- R/misc.R | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/R/misc.R b/R/misc.R index ba4e249f8..5d1fa9876 100644 --- a/R/misc.R +++ b/R/misc.R @@ -191,14 +191,8 @@ update_dot_check <- function(...) { #' @keywords internal #' @rdname add_on_exports new_model_spec <- function(cls, args, eng_args, mode, method, engine) { - spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes")) - if (!(mode %in% spec_modes)) - rlang::abort( - glue::glue( - "`mode` should be one of: ", - glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ") - ) - ) + + check_spec_mode_val(cls, mode) out <- list(args = args, eng_args = eng_args, mode = mode, method = method, engine = engine) From 1a27482e671fc4dacf005e32e92d3c1c108a89f0 Mon Sep 17 00:00:00 2001 From: jtlandis Date: Wed, 14 Apr 2021 15:00:54 -0400 Subject: [PATCH 4/8] include new function in `set_mode` --- R/arguments.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/R/arguments.R b/R/arguments.R index 9bbb15999..cc40694f8 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -87,6 +87,10 @@ set_mode <- function(object, mode) { ) ) } + #only check if object is a model_spec + if(inherits(object, "model_spec")) { + check_spec_mode_val(class(object)[1], mode) + } object$mode <- mode object } From 31ece9fb909eec5e914ab724c7b529f67884e396 Mon Sep 17 00:00:00 2001 From: jtlandis Date: Wed, 14 Apr 2021 15:59:29 -0400 Subject: [PATCH 5/8] adding tests to confirm expected `set_modes` works with the base model_spec objects of parsnip. Includes at least one expect_error per model_spec. --- tests/testthat/test_spec_and_modes.R | 168 +++++++++++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 tests/testthat/test_spec_and_modes.R diff --git a/tests/testthat/test_spec_and_modes.R b/tests/testthat/test_spec_and_modes.R new file mode 100644 index 000000000..7173cf197 --- /dev/null +++ b/tests/testthat/test_spec_and_modes.R @@ -0,0 +1,168 @@ +library(testthat) +library(dplyr) +library(parsnip) + +context("setting modes works as intended") + +test_that("correct modes of boost_tree",{ + basic <- boost_tree() + basic_expect <- basic + #default + expect_equal(basic$mode, "unknown") + #set classification + basic_expect$mode <- "classification" + expect_equal(basic %>% set_mode("classification"), basic_expect) + #set regression + basic_expect$mode <- "regression" + expect_equal(basic %>% set_mode("regression"), basic_expect) + #attempt to set incorrect + expect_error(basic %>% set_mode("censored regression")) +}) + +test_that("correct modes of decision_tree",{ + basic <- decision_tree() + basic_expect <- basic + #default + expect_equal(basic$mode, "unknown") + #set classification + basic_expect$mode <- "classification" + expect_equal(basic %>% set_mode("classification"), basic_expect) + #set regression + basic_expect$mode <- "regression" + expect_equal(basic %>% set_mode("regression"), basic_expect) + #attempt to set incorrect + expect_error(basic %>% set_mode("censored regression")) +}) + +test_that("correct modes of linear_reg", { + basic <- linear_reg() + expect_equal(basic$mode, "regression") + expect_error(linear_reg("classification")) + expect_error(basic %>% set_mode("classification")) +}) + +test_that("correct modes of logistic_reg", { + basic <- logistic_reg() + expect_equal(basic$mode, "classification") + expect_error(logistic_reg("regression")) + expect_error(basic %>% set_mode("regression")) +}) + +test_that("correct modes of mars",{ + basic <- mars() + basic_expect <- basic + #default + expect_equal(basic$mode, "unknown") + #set classification + basic_expect$mode <- "classification" + expect_equal(basic %>% set_mode("classification"), basic_expect) + #set regression + basic_expect$mode <- "regression" + expect_equal(basic %>% set_mode("regression"), basic_expect) + #attempt to set incorrect + expect_error(basic %>% set_mode("censored regression")) +}) + +test_that("correct modes of mlp",{ + basic <- mlp() + basic_expect <- basic + #default + expect_equal(basic$mode, "unknown") + #set classification + basic_expect$mode <- "classification" + expect_equal(basic %>% set_mode("classification"), basic_expect) + #set regression + basic_expect$mode <- "regression" + expect_equal(basic %>% set_mode("regression"), basic_expect) + #attempt to set incorrect + expect_error(basic %>% set_mode("censored regression")) +}) + +test_that("correct modes of multinom_reg", { + basic <- multinom_reg() + expect_equal(basic$mode, "classification") + expect_error(multinom_reg("regression")) + expect_error(basic %>% set_mode("regression")) +}) + +test_that("correct modes of nearest_neighbor",{ + basic <- nearest_neighbor() + basic_expect <- basic + #default + expect_equal(basic$mode, "unknown") + #set classification + basic_expect$mode <- "classification" + expect_equal(basic %>% set_mode("classification"), basic_expect) + #set regression + basic_expect$mode <- "regression" + expect_equal(basic %>% set_mode("regression"), basic_expect) + #attempt to set incorrect + expect_error(basic %>% set_mode("censored regression")) +}) + +test_that("correct modes of null_model",{ + basic <- null_model() + basic_expect <- basic + #default + expect_equal(basic$mode, "unknown") + #set classification + basic_expect$mode <- "classification" + expect_equal(basic %>% set_mode("classification"), basic_expect) + #set regression + basic_expect$mode <- "regression" + expect_equal(basic %>% set_mode("regression"), basic_expect) + #attempt to set incorrect + expect_error(basic %>% set_mode("censored regression")) +}) + +test_that("correct modes of rand_forest",{ + basic <- rand_forest() + basic_expect <- basic + #default + expect_equal(basic$mode, "unknown") + #set classification + basic_expect$mode <- "classification" + expect_equal(basic %>% set_mode("classification"), basic_expect) + #set regression + basic_expect$mode <- "regression" + expect_equal(basic %>% set_mode("regression"), basic_expect) + #attempt to set incorrect + expect_error(basic %>% set_mode("censored regression")) +}) + +test_that("correct modes of survival_reg", { + basic <- survival_reg() + expect_equal(basic$mode, "censored regression") + expect_error(surv_reg("classification")) + expect_error(basic %>% set_mode("classification")) +}) + +test_that("correct modes of svm_poly",{ + basic <- svm_poly() + basic_expect <- basic + #default + expect_equal(basic$mode, "unknown") + #set classification + basic_expect$mode <- "classification" + expect_equal(basic %>% set_mode("classification"), basic_expect) + #set regression + basic_expect$mode <- "regression" + expect_equal(basic %>% set_mode("regression"), basic_expect) + #attempt to set incorrect + expect_error(basic %>% set_mode("censored regression")) +}) + +test_that("correct modes of svm_rbf",{ + basic <- svm_rbf() + basic_expect <- basic + #default + expect_equal(basic$mode, "unknown") + #set classification + basic_expect$mode <- "classification" + expect_equal(basic %>% set_mode("classification"), basic_expect) + #set regression + basic_expect$mode <- "regression" + expect_equal(basic %>% set_mode("regression"), basic_expect) + #attempt to set incorrect + expect_error(basic %>% set_mode("censored regression")) +}) From 311996fe646318f0546f0d776bee71b9fecde234 Mon Sep 17 00:00:00 2001 From: jtlandis Date: Wed, 14 Apr 2021 16:06:48 -0400 Subject: [PATCH 6/8] small cleanup on null_model and surv_reg --- tests/testthat/test_spec_and_modes.R | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/testthat/test_spec_and_modes.R b/tests/testthat/test_spec_and_modes.R index 7173cf197..05541b359 100644 --- a/tests/testthat/test_spec_and_modes.R +++ b/tests/testthat/test_spec_and_modes.R @@ -104,10 +104,7 @@ test_that("correct modes of null_model",{ basic <- null_model() basic_expect <- basic #default - expect_equal(basic$mode, "unknown") - #set classification - basic_expect$mode <- "classification" - expect_equal(basic %>% set_mode("classification"), basic_expect) + expect_equal(basic$mode, "classification") #set regression basic_expect$mode <- "regression" expect_equal(basic %>% set_mode("regression"), basic_expect) @@ -133,7 +130,7 @@ test_that("correct modes of rand_forest",{ test_that("correct modes of survival_reg", { basic <- survival_reg() expect_equal(basic$mode, "censored regression") - expect_error(surv_reg("classification")) + expect_error(survival_reg("classification")) expect_error(basic %>% set_mode("classification")) }) From fd765687b730f6d72f9c2e8646a786dd0473bfc7 Mon Sep 17 00:00:00 2001 From: jtlandis Date: Wed, 14 Apr 2021 16:13:08 -0400 Subject: [PATCH 7/8] adding description of changes to NEWS.md --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index d908a1723..37fd4a7d5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -12,6 +12,8 @@ * Re-licensed package from GPL-2 to MIT. See [consent from copyright holders here](https://github.com/tidymodels/parsnip/issues/462). +* `set_mode` now checks if mode is compatible with model class, similar to `new_model_spec` (jtlandis #467) + # parsnip 0.1.5 * An RStudio add-in is available that makes writing multiple `parsnip` model specifications to the source window. It can be accessed via the IDE addin menus or by calling `parsnip_addin()`. From 40b5115d72b0b9bc78be56de8471792c2f275e35 Mon Sep 17 00:00:00 2001 From: jtlandis Date: Fri, 30 Apr 2021 09:47:03 -0400 Subject: [PATCH 8/8] commiting changes suggested by DavisVaughan --- NEWS.md | 2 +- R/arguments.R | 5 +- tests/testthat/test_args_and_modes.R | 7 ++ tests/testthat/test_spec_and_modes.R | 165 --------------------------- 4 files changed, 9 insertions(+), 170 deletions(-) delete mode 100644 tests/testthat/test_spec_and_modes.R diff --git a/NEWS.md b/NEWS.md index 37fd4a7d5..a332220d3 100644 --- a/NEWS.md +++ b/NEWS.md @@ -12,7 +12,7 @@ * Re-licensed package from GPL-2 to MIT. See [consent from copyright holders here](https://github.com/tidymodels/parsnip/issues/462). -* `set_mode` now checks if mode is compatible with model class, similar to `new_model_spec` (jtlandis #467) +* `set_mode()` now checks if `mode` is compatible with the model class, similar to `new_model_spec()` (@jtlandis, #467). # parsnip 0.1.5 diff --git a/R/arguments.R b/R/arguments.R index cc40694f8..3a4d887b3 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -87,10 +87,7 @@ set_mode <- function(object, mode) { ) ) } - #only check if object is a model_spec - if(inherits(object, "model_spec")) { - check_spec_mode_val(class(object)[1], mode) - } + check_spec_mode_val(class(object)[1], mode) object$mode <- mode object } diff --git a/tests/testthat/test_args_and_modes.R b/tests/testthat/test_args_and_modes.R index 2a0b86e63..9abf5807e 100644 --- a/tests/testthat/test_args_and_modes.R +++ b/tests/testthat/test_args_and_modes.R @@ -45,3 +45,10 @@ test_that('pipe engine', { expect_error(rand_forest() %>% set_mode(2)) expect_error(rand_forest() %>% set_mode("haberdashery")) }) + +test_that("can't set a mode that isn't allowed by the model spec", { + expect_error( + set_mode(linear_reg(), "classification"), + "`mode` should be one of" + ) +}) diff --git a/tests/testthat/test_spec_and_modes.R b/tests/testthat/test_spec_and_modes.R deleted file mode 100644 index 05541b359..000000000 --- a/tests/testthat/test_spec_and_modes.R +++ /dev/null @@ -1,165 +0,0 @@ -library(testthat) -library(dplyr) -library(parsnip) - -context("setting modes works as intended") - -test_that("correct modes of boost_tree",{ - basic <- boost_tree() - basic_expect <- basic - #default - expect_equal(basic$mode, "unknown") - #set classification - basic_expect$mode <- "classification" - expect_equal(basic %>% set_mode("classification"), basic_expect) - #set regression - basic_expect$mode <- "regression" - expect_equal(basic %>% set_mode("regression"), basic_expect) - #attempt to set incorrect - expect_error(basic %>% set_mode("censored regression")) -}) - -test_that("correct modes of decision_tree",{ - basic <- decision_tree() - basic_expect <- basic - #default - expect_equal(basic$mode, "unknown") - #set classification - basic_expect$mode <- "classification" - expect_equal(basic %>% set_mode("classification"), basic_expect) - #set regression - basic_expect$mode <- "regression" - expect_equal(basic %>% set_mode("regression"), basic_expect) - #attempt to set incorrect - expect_error(basic %>% set_mode("censored regression")) -}) - -test_that("correct modes of linear_reg", { - basic <- linear_reg() - expect_equal(basic$mode, "regression") - expect_error(linear_reg("classification")) - expect_error(basic %>% set_mode("classification")) -}) - -test_that("correct modes of logistic_reg", { - basic <- logistic_reg() - expect_equal(basic$mode, "classification") - expect_error(logistic_reg("regression")) - expect_error(basic %>% set_mode("regression")) -}) - -test_that("correct modes of mars",{ - basic <- mars() - basic_expect <- basic - #default - expect_equal(basic$mode, "unknown") - #set classification - basic_expect$mode <- "classification" - expect_equal(basic %>% set_mode("classification"), basic_expect) - #set regression - basic_expect$mode <- "regression" - expect_equal(basic %>% set_mode("regression"), basic_expect) - #attempt to set incorrect - expect_error(basic %>% set_mode("censored regression")) -}) - -test_that("correct modes of mlp",{ - basic <- mlp() - basic_expect <- basic - #default - expect_equal(basic$mode, "unknown") - #set classification - basic_expect$mode <- "classification" - expect_equal(basic %>% set_mode("classification"), basic_expect) - #set regression - basic_expect$mode <- "regression" - expect_equal(basic %>% set_mode("regression"), basic_expect) - #attempt to set incorrect - expect_error(basic %>% set_mode("censored regression")) -}) - -test_that("correct modes of multinom_reg", { - basic <- multinom_reg() - expect_equal(basic$mode, "classification") - expect_error(multinom_reg("regression")) - expect_error(basic %>% set_mode("regression")) -}) - -test_that("correct modes of nearest_neighbor",{ - basic <- nearest_neighbor() - basic_expect <- basic - #default - expect_equal(basic$mode, "unknown") - #set classification - basic_expect$mode <- "classification" - expect_equal(basic %>% set_mode("classification"), basic_expect) - #set regression - basic_expect$mode <- "regression" - expect_equal(basic %>% set_mode("regression"), basic_expect) - #attempt to set incorrect - expect_error(basic %>% set_mode("censored regression")) -}) - -test_that("correct modes of null_model",{ - basic <- null_model() - basic_expect <- basic - #default - expect_equal(basic$mode, "classification") - #set regression - basic_expect$mode <- "regression" - expect_equal(basic %>% set_mode("regression"), basic_expect) - #attempt to set incorrect - expect_error(basic %>% set_mode("censored regression")) -}) - -test_that("correct modes of rand_forest",{ - basic <- rand_forest() - basic_expect <- basic - #default - expect_equal(basic$mode, "unknown") - #set classification - basic_expect$mode <- "classification" - expect_equal(basic %>% set_mode("classification"), basic_expect) - #set regression - basic_expect$mode <- "regression" - expect_equal(basic %>% set_mode("regression"), basic_expect) - #attempt to set incorrect - expect_error(basic %>% set_mode("censored regression")) -}) - -test_that("correct modes of survival_reg", { - basic <- survival_reg() - expect_equal(basic$mode, "censored regression") - expect_error(survival_reg("classification")) - expect_error(basic %>% set_mode("classification")) -}) - -test_that("correct modes of svm_poly",{ - basic <- svm_poly() - basic_expect <- basic - #default - expect_equal(basic$mode, "unknown") - #set classification - basic_expect$mode <- "classification" - expect_equal(basic %>% set_mode("classification"), basic_expect) - #set regression - basic_expect$mode <- "regression" - expect_equal(basic %>% set_mode("regression"), basic_expect) - #attempt to set incorrect - expect_error(basic %>% set_mode("censored regression")) -}) - -test_that("correct modes of svm_rbf",{ - basic <- svm_rbf() - basic_expect <- basic - #default - expect_equal(basic$mode, "unknown") - #set classification - basic_expect$mode <- "classification" - expect_equal(basic %>% set_mode("classification"), basic_expect) - #set regression - basic_expect$mode <- "regression" - expect_equal(basic %>% set_mode("regression"), basic_expect) - #attempt to set incorrect - expect_error(basic %>% set_mode("censored regression")) -})