From 83c5900d8b2ce8ecd3cbe137df19c804c8e8b856 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Thu, 6 May 2021 14:06:56 -0600 Subject: [PATCH 1/7] Error for glmnet models is there is not exactly one value for penalty --- R/linear_reg.R | 7 +++++++ R/logistic_reg.R | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/R/linear_reg.R b/R/linear_reg.R index f3137ef19..63e7c7872 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -112,6 +112,13 @@ translate.linear_reg <- function(x, engine = x$engine, ...) { # Since the `fit` information is gone for the penalty, we need to have an # evaluated value for the parameter. x$args$penalty <- rlang::eval_tidy(x$args$penalty) + if (length(x$args$penalty) != 1) { + rlang::abort(c( + "For the glmnet engine, `penalty` must be a single number.", + glue::glue("There are {length(x$args$penalty)} values for `penalty`."), + "To try multiple values for total regularization, use the tune package." + )) + } } x diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 4d7b2d3ab..d3b0611d3 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -115,6 +115,13 @@ translate.logistic_reg <- function(x, engine = x$engine, ...) { # Since the `fit` information is gone for the penalty, we need to have an # evaluated value for the parameter. x$args$penalty <- rlang::eval_tidy(x$args$penalty) + if (length(x$args$penalty) != 1) { + rlang::abort(c( + "For the glmnet engine, `penalty` must be a single number.", + glue::glue("There are {length(x$args$penalty)} values for `penalty`."), + "To try multiple values for total regularization, use the tune package." + )) + } } if (engine == "LiblineaR") { From b8f02a3665a78946246e9385ad4eacb082998e9c Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Thu, 6 May 2021 14:10:37 -0600 Subject: [PATCH 2/7] Updates tests for new glmnet error --- tests/testthat/test_linear_reg.R | 43 +++++++++--------------------- tests/testthat/test_logistic_reg.R | 43 +++++++++--------------------- tests/testthat/test_multinom_reg.R | 20 +++++--------- 3 files changed, 33 insertions(+), 73 deletions(-) diff --git a/tests/testthat/test_linear_reg.R b/tests/testthat/test_linear_reg.R index 4c0617cec..b839572e8 100644 --- a/tests/testthat/test_linear_reg.R +++ b/tests/testthat/test_linear_reg.R @@ -15,7 +15,10 @@ hpc <- hpc_data[1:150, c(2:5, 8)] test_that('primary arguments', { basic <- linear_reg() basic_lm <- translate(basic %>% set_engine("lm")) - basic_glmnet <- translate(basic %>% set_engine("glmnet")) + expect_error( + basic_glmnet <- translate(basic %>% set_engine("glmnet")), + "For the glmnet engine, `penalty` must be a single" + ) basic_stan <- translate(basic %>% set_engine("stan")) basic_spark <- translate(basic %>% set_engine("spark")) expect_equal(basic_lm$method$fit$args, @@ -25,14 +28,6 @@ test_that('primary arguments', { weights = expr(missing_arg()) ) ) - expect_equal(basic_glmnet$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - family = "gaussian" - ) - ) expect_equal(basic_stan$method$fit$args, list( formula = expr(missing_arg()), @@ -51,17 +46,11 @@ test_that('primary arguments', { ) mixture <- linear_reg(mixture = 0.128) - mixture_glmnet <- translate(mixture %>% set_engine("glmnet")) - mixture_spark <- translate(mixture %>% set_engine("spark")) - expect_equal(mixture_glmnet$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - alpha = new_empty_quosure(0.128), - family = "gaussian" - ) + expect_error( + 
mixture_glmnet <- translate(mixture %>% set_engine("glmnet")), + "For the glmnet engine, `penalty` must be a single" ) + mixture_spark <- translate(mixture %>% set_engine("spark")) expect_equal(mixture_spark$method$fit$args, list( x = expr(missing_arg()), @@ -92,17 +81,11 @@ test_that('primary arguments', { ) mixture_v <- linear_reg(mixture = varying()) - mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")) - mixture_v_spark <- translate(mixture_v %>% set_engine("spark")) - expect_equal(mixture_v_glmnet$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - alpha = new_empty_quosure(varying()), - family = "gaussian" - ) + expect_error( + mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")), + "For the glmnet engine, `penalty` must be a single" ) + mixture_v_spark <- translate(mixture_v %>% set_engine("spark")) expect_equal(mixture_v_spark$method$fit$args, list( x = expr(missing_arg()), @@ -125,7 +108,7 @@ test_that('engine arguments', { ) ) - glmnet_nlam <- linear_reg() %>% set_engine("glmnet", nlambda = 10) + glmnet_nlam <- linear_reg(penalty = 0.1) %>% set_engine("glmnet", nlambda = 10) expect_equal(translate(glmnet_nlam)$method$fit$args, list( x = expr(missing_arg()), diff --git a/tests/testthat/test_logistic_reg.R b/tests/testthat/test_logistic_reg.R index b19806dae..5b2c9df1b 100644 --- a/tests/testthat/test_logistic_reg.R +++ b/tests/testthat/test_logistic_reg.R @@ -16,7 +16,10 @@ hpc <- hpc_data[1:150, c(2:5, 8)] test_that('primary arguments', { basic <- logistic_reg() basic_glm <- translate(basic %>% set_engine("glm")) - basic_glmnet <- translate(basic %>% set_engine("glmnet")) + expect_error( + basic_glmnet <- translate(basic %>% set_engine("glmnet")), + "For the glmnet engine, `penalty` must be a single" + ) basic_liblinear <- translate(basic %>% set_engine("LiblineaR")) basic_stan <- translate(basic %>% set_engine("stan")) basic_spark <- translate(basic %>% set_engine("spark")) @@ -28,14 +31,6 @@ test_that('primary arguments', { family = expr(stats::binomial) ) ) - expect_equal(basic_glmnet$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - family = "binomial" - ) - ) expect_equal(basic_liblinear$method$fit$args, list( x = expr(missing_arg()), @@ -63,17 +58,11 @@ test_that('primary arguments', { ) mixture <- logistic_reg(mixture = 0.128) - mixture_glmnet <- translate(mixture %>% set_engine("glmnet")) - mixture_spark <- translate(mixture %>% set_engine("spark")) - expect_equal(mixture_glmnet$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - alpha = new_empty_quosure(0.128), - family = "binomial" - ) + expect_error( + mixture_glmnet <- translate(mixture %>% set_engine("glmnet")), + "For the glmnet engine, `penalty` must be a single" ) + mixture_spark <- translate(mixture %>% set_engine("spark")) expect_equal(mixture_spark$method$fit$args, list( x = expr(missing_arg()), @@ -116,18 +105,12 @@ test_that('primary arguments', { ) mixture_v <- logistic_reg(mixture = varying()) - mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")) + expect_error( + mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")), + "For the glmnet engine, `penalty` must be a single" + ) mixture_v_liblinear <- translate(mixture_v %>% set_engine("LiblineaR")) mixture_v_spark <- translate(mixture_v %>% set_engine("spark")) - expect_equal(mixture_v_glmnet$method$fit$args, - list( - x = 
expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - alpha = new_empty_quosure(varying()), - family = "binomial" - ) - ) expect_equal(mixture_v_liblinear$method$fit$args, list( x = expr(missing_arg()), @@ -194,7 +177,7 @@ test_that('engine arguments', { ) ) - glmnet_nlam <- logistic_reg() + glmnet_nlam <- logistic_reg(penalty = 0.1) expect_equal( translate(glmnet_nlam %>% set_engine("glmnet", nlambda = 10))$method$fit$args, list( diff --git a/tests/testthat/test_multinom_reg.R b/tests/testthat/test_multinom_reg.R index a373159b6..6a1b0037d 100644 --- a/tests/testthat/test_multinom_reg.R +++ b/tests/testthat/test_multinom_reg.R @@ -13,17 +13,11 @@ hpc <- hpc_data[1:150, c(2:5, 8)] test_that('primary arguments', { basic <- multinom_reg() - basic_glmnet <- translate(basic %>% set_engine("glmnet")) - expect_equal(basic_glmnet$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - family = "multinomial" - ) + expect_error( + basic_glmnet <- translate(basic %>% set_engine("glmnet")), + "For the glmnet engine, `penalty` must be a single" ) - - mixture <- multinom_reg(mixture = 0.128) + mixture <- multinom_reg(penalty = 0.1, mixture = 0.128) mixture_glmnet <- translate(mixture %>% set_engine("glmnet")) expect_equal(mixture_glmnet$method$fit$args, list( @@ -46,7 +40,7 @@ test_that('primary arguments', { ) ) - mixture_v <- multinom_reg(mixture = varying()) + mixture_v <- multinom_reg(penalty = 0.01, mixture = varying()) mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")) expect_equal(mixture_v_glmnet$method$fit$args, list( @@ -61,7 +55,7 @@ test_that('primary arguments', { }) test_that('engine arguments', { - glmnet_nlam <- multinom_reg() + glmnet_nlam <- multinom_reg(penalty = 0.01) expect_equal( translate(glmnet_nlam %>% set_engine("glmnet", nlambda = 10))$method$fit$args, list( @@ -117,5 +111,5 @@ test_that('bad input', { expect_error(multinom_reg(mode = "regression")) expect_error(translate(multinom_reg() %>% set_engine("wat?"))) expect_error(translate(multinom_reg() %>% set_engine())) - expect_warning(translate(multinom_reg() %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class))) + expect_warning(translate(multinom_reg(penalty = 0.01) %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class))) }) From ec16cc932dcbae08159b25b26326a5bc5194f05d Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Thu, 6 May 2021 19:51:21 -0600 Subject: [PATCH 3/7] Update examples + more docs to accomodate new glmnet error --- R/engines.R | 2 +- man/linear_reg.Rd | 5 ++++- man/logistic_reg.Rd | 5 ++++- man/multinom_reg.Rd | 5 ++++- man/rmd/linear-reg.Rmd | 2 +- man/rmd/logistic-reg.Rmd | 2 +- man/rmd/multinom-reg.Rmd | 2 +- man/set_engine.Rd | 2 +- 8 files changed, 17 insertions(+), 8 deletions(-) diff --git a/R/engines.R b/R/engines.R index d96d3e044..913cf8a1e 100644 --- a/R/engines.R +++ b/R/engines.R @@ -82,7 +82,7 @@ load_libs <- function(x, quiet, attach = FALSE) { #' @examples #' # First, set general arguments using the standardized names #' mod <- -#' logistic_reg(mixture = 1/3) %>% +#' logistic_reg(penalty = 0.01, mixture = 1/3) %>% #' # now say how you want to fit the model and another other options #' set_engine("glmnet", nlambda = 10) #' translate(mod, engine = "glmnet") diff --git a/man/linear_reg.Rd b/man/linear_reg.Rd index 53057d647..388a57325 100644 --- a/man/linear_reg.Rd +++ b/man/linear_reg.Rd @@ -86,12 +86,15 @@ call. For this type of model, the template of the fit calls are below. 
} } -\subsection{glmnet}{\if{html}{\out{
}}\preformatted{linear_reg() \%>\% +\subsection{glmnet}{\if{html}{\out{
}}\preformatted{linear_reg(penalty = varying()) \%>\% set_engine("glmnet") \%>\% set_mode("regression") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Linear Regression Model Specification (regression) ## +## Main Arguments: +## penalty = varying() +## ## Computational engine: glmnet ## ## Model fit template: diff --git a/man/logistic_reg.Rd b/man/logistic_reg.Rd index fc05bb5ff..9ce67be95 100644 --- a/man/logistic_reg.Rd +++ b/man/logistic_reg.Rd @@ -86,12 +86,15 @@ call. For this type of model, the template of the fit calls are below. } } -\subsection{glmnet}{\if{html}{\out{
}}\preformatted{logistic_reg() \%>\% +\subsection{glmnet}{\if{html}{\out{
}}\preformatted{logistic_reg(penalty = varying()) \%>\% set_engine("glmnet") \%>\% set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification) ## +## Main Arguments: +## penalty = varying() +## ## Computational engine: glmnet ## ## Model fit template: diff --git a/man/multinom_reg.Rd b/man/multinom_reg.Rd index feb930900..301c3dc72 100644 --- a/man/multinom_reg.Rd +++ b/man/multinom_reg.Rd @@ -67,12 +67,15 @@ reloaded and reattached to the \code{parsnip} object. \section{Engine Details}{ Engines may have pre-set default arguments when executing the model fit call. For this type of model, the template of the fit calls are below. -\subsection{glmnet}{\if{html}{\out{
}}\preformatted{multinom_reg() \%>\% +\subsection{glmnet}{\if{html}{\out{
}}\preformatted{multinom_reg(penalty = varying()) \%>\% set_engine("glmnet") \%>\% set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Multinomial Regression Model Specification (classification) ## +## Main Arguments: +## penalty = varying() +## ## Computational engine: glmnet ## ## Model fit template: diff --git a/man/rmd/linear-reg.Rmd b/man/rmd/linear-reg.Rmd index 03f1da44e..bf92a477a 100644 --- a/man/rmd/linear-reg.Rmd +++ b/man/rmd/linear-reg.Rmd @@ -17,7 +17,7 @@ linear_reg() %>% ## glmnet ```{r glmnet-csl} -linear_reg() %>% +linear_reg(penalty = varying()) %>% set_engine("glmnet") %>% set_mode("regression") %>% translate() diff --git a/man/rmd/logistic-reg.Rmd b/man/rmd/logistic-reg.Rmd index 656e24f70..e7e343ddf 100644 --- a/man/rmd/logistic-reg.Rmd +++ b/man/rmd/logistic-reg.Rmd @@ -18,7 +18,7 @@ logistic_reg() %>% ## glmnet ```{r glmnet-csl} -logistic_reg() %>% +logistic_reg(penalty = varying()) %>% set_engine("glmnet") %>% set_mode("classification") %>% translate() diff --git a/man/rmd/multinom-reg.Rmd b/man/rmd/multinom-reg.Rmd index 5d08847d2..12f2a1b59 100644 --- a/man/rmd/multinom-reg.Rmd +++ b/man/rmd/multinom-reg.Rmd @@ -9,7 +9,7 @@ For this type of model, the template of the fit calls are below. ## glmnet ```{r glmnet-cls} -multinom_reg() %>% +multinom_reg(penalty = varying()) %>% set_engine("glmnet") %>% set_mode("classification") %>% translate() diff --git a/man/set_engine.Rd b/man/set_engine.Rd index 889451351..d19d67293 100644 --- a/man/set_engine.Rd +++ b/man/set_engine.Rd @@ -26,7 +26,7 @@ to fit the model, along with any arguments specific to that software. \examples{ # First, set general arguments using the standardized names mod <- - logistic_reg(mixture = 1/3) \%>\% + logistic_reg(penalty = 0.01, mixture = 1/3) \%>\% # now say how you want to fit the model and another other options set_engine("glmnet", nlambda = 10) translate(mod, engine = "glmnet") From 0d03d3303b002f2e88c1e850d29a67381b394911 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Thu, 6 May 2021 20:01:42 -0600 Subject: [PATCH 4/7] One more example to fix for new glmnet error --- R/translate.R | 2 +- man/translate.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/translate.R b/R/translate.R index 2d697179d..4c2064db6 100644 --- a/R/translate.R +++ b/R/translate.R @@ -38,7 +38,7 @@ #' translate(lm_spec, engine = "spark") #' #' # with a placeholder for an unknown argument value: -#' translate(linear_reg(mixture = varying()), engine = "glmnet") +#' translate(linear_reg(penalty = varying(), mixture = varying()), engine = "glmnet") #' #' @export diff --git a/man/translate.Rd b/man/translate.Rd index 11bec9c97..c9a8b0400 100644 --- a/man/translate.Rd +++ b/man/translate.Rd @@ -52,6 +52,6 @@ translate(lm_spec, engine = "lm") translate(lm_spec, engine = "spark") # with a placeholder for an unknown argument value: -translate(linear_reg(mixture = varying()), engine = "glmnet") +translate(linear_reg(penalty = varying(), mixture = varying()), engine = "glmnet") } From cd4166ad6ee0219a98f68c2cd88c5f1b1a2dbd34 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Mon, 10 May 2021 17:22:21 -0600 Subject: [PATCH 5/7] Switch out for a single value of penalty in glmnet doc subsections --- man/linear_reg.Rd | 4 ++-- man/logistic_reg.Rd | 4 ++-- man/multinom_reg.Rd | 4 ++-- man/rmd/linear-reg.Rmd | 2 +- man/rmd/logistic-reg.Rmd | 2 +- man/rmd/multinom-reg.Rmd | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/man/linear_reg.Rd b/man/linear_reg.Rd index 388a57325..cd031b12c 100644 --- a/man/linear_reg.Rd +++ b/man/linear_reg.Rd @@ -86,14 +86,14 @@ call. 
For this type of model, the template of the fit calls are below. } } -\subsection{glmnet}{\if{html}{\out{
}}\preformatted{linear_reg(penalty = varying()) \%>\% +\subsection{glmnet}{\if{html}{\out{
}}\preformatted{linear_reg(penalty = 0.1) \%>\% set_engine("glmnet") \%>\% set_mode("regression") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Linear Regression Model Specification (regression) ## ## Main Arguments: -## penalty = varying() +## penalty = 0.1 ## ## Computational engine: glmnet ## diff --git a/man/logistic_reg.Rd b/man/logistic_reg.Rd index 9ce67be95..cdb6b73c8 100644 --- a/man/logistic_reg.Rd +++ b/man/logistic_reg.Rd @@ -86,14 +86,14 @@ call. For this type of model, the template of the fit calls are below. } } -\subsection{glmnet}{\if{html}{\out{
}}\preformatted{logistic_reg(penalty = varying()) \%>\% +\subsection{glmnet}{\if{html}{\out{
}}\preformatted{logistic_reg(penalty = 0.1) \%>\% set_engine("glmnet") \%>\% set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification) ## ## Main Arguments: -## penalty = varying() +## penalty = 0.1 ## ## Computational engine: glmnet ## diff --git a/man/multinom_reg.Rd b/man/multinom_reg.Rd index 301c3dc72..a1ad4138e 100644 --- a/man/multinom_reg.Rd +++ b/man/multinom_reg.Rd @@ -67,14 +67,14 @@ reloaded and reattached to the \code{parsnip} object. \section{Engine Details}{ Engines may have pre-set default arguments when executing the model fit call. For this type of model, the template of the fit calls are below. -\subsection{glmnet}{\if{html}{\out{
}}\preformatted{multinom_reg(penalty = varying()) \%>\% +\subsection{glmnet}{\if{html}{\out{
}}\preformatted{multinom_reg(penalty = 0.1) \%>\% set_engine("glmnet") \%>\% set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Multinomial Regression Model Specification (classification) ## ## Main Arguments: -## penalty = varying() +## penalty = 0.1 ## ## Computational engine: glmnet ## diff --git a/man/rmd/linear-reg.Rmd b/man/rmd/linear-reg.Rmd index bf92a477a..18a8c66c4 100644 --- a/man/rmd/linear-reg.Rmd +++ b/man/rmd/linear-reg.Rmd @@ -17,7 +17,7 @@ linear_reg() %>% ## glmnet ```{r glmnet-csl} -linear_reg(penalty = varying()) %>% +linear_reg(penalty = 0.1) %>% set_engine("glmnet") %>% set_mode("regression") %>% translate() diff --git a/man/rmd/logistic-reg.Rmd b/man/rmd/logistic-reg.Rmd index e7e343ddf..74f483f07 100644 --- a/man/rmd/logistic-reg.Rmd +++ b/man/rmd/logistic-reg.Rmd @@ -18,7 +18,7 @@ logistic_reg() %>% ## glmnet ```{r glmnet-csl} -logistic_reg(penalty = varying()) %>% +logistic_reg(penalty = 0.1) %>% set_engine("glmnet") %>% set_mode("classification") %>% translate() diff --git a/man/rmd/multinom-reg.Rmd b/man/rmd/multinom-reg.Rmd index 12f2a1b59..3a48cd372 100644 --- a/man/rmd/multinom-reg.Rmd +++ b/man/rmd/multinom-reg.Rmd @@ -9,7 +9,7 @@ For this type of model, the template of the fit calls are below. ## glmnet ```{r glmnet-cls} -multinom_reg(penalty = varying()) %>% +multinom_reg(penalty = 0.1) %>% set_engine("glmnet") %>% set_mode("classification") %>% translate() From 5acfdae3436aeaf0f1d9ab25e9b3063d40d16a18 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 11 May 2021 13:10:09 -0400 Subject: [PATCH 6/7] remove unnecessary set_mode() calls --- man/contr_one_hot.Rd | 10 ++++++---- man/linear_reg.Rd | 5 ----- man/logistic_reg.Rd | 6 ------ man/multinom_reg.Rd | 4 ---- man/rmd/linear-reg.Rmd | 5 ----- man/rmd/logistic-reg.Rmd | 6 ------ man/rmd/multinom-reg.Rmd | 4 ---- 7 files changed, 6 insertions(+), 34 deletions(-) diff --git a/man/contr_one_hot.Rd b/man/contr_one_hot.Rd index df945bebd..a8c21d593 100644 --- a/man/contr_one_hot.Rd +++ b/man/contr_one_hot.Rd @@ -39,14 +39,16 @@ levels(penguins$species) }\if{html}{\out{
}}\preformatted{## [1] "Biscoe" "Dream" "Torgersen" }\if{html}{\out{
}}\preformatted{model.matrix(~ species + island, data = penguins) \%>\% colnames() -}\if{html}{\out{
}}\preformatted{## [1] "(Intercept)" "speciesChinstrap" "speciesGentoo" "islandDream" "islandTorgersen" +}\if{html}{\out{
}}\preformatted{## [1] "(Intercept)" "speciesChinstrap" "speciesGentoo" "islandDream" +## [5] "islandTorgersen" } For a formula with no intercept, the first factor is expanded to indicators for \emph{all} factor levels but all other factors are expanded to all but one (as above):\if{html}{\out{
}}\preformatted{model.matrix(~ 0 + species + island, data = penguins) \%>\% colnames() -}\if{html}{\out{
}}\preformatted{## [1] "speciesAdelie" "speciesChinstrap" "speciesGentoo" "islandDream" "islandTorgersen" +}\if{html}{\out{
}}\preformatted{## [1] "speciesAdelie" "speciesChinstrap" "speciesGentoo" "islandDream" +## [5] "islandTorgersen" } For inference, this hybrid encoding can be problematic. @@ -59,8 +61,8 @@ options(contrasts = new_contr) model.matrix(~ species + island, data = penguins) \%>\% colnames() -}\if{html}{\out{
}}\preformatted{## [1] "(Intercept)" "speciesAdelie" "speciesChinstrap" "speciesGentoo" "islandBiscoe" -## [6] "islandDream" "islandTorgersen" +}\if{html}{\out{
}}\preformatted{## [1] "(Intercept)" "speciesAdelie" "speciesChinstrap" "speciesGentoo" +## [5] "islandBiscoe" "islandDream" "islandTorgersen" }\if{html}{\out{
}}\preformatted{options(contrasts = old_contr) }\if{html}{\out{
}} diff --git a/man/linear_reg.Rd b/man/linear_reg.Rd index cd031b12c..b11544e60 100644 --- a/man/linear_reg.Rd +++ b/man/linear_reg.Rd @@ -75,7 +75,6 @@ Engines may have pre-set default arguments when executing the model fit call. For this type of model, the template of the fit calls are below. \subsection{lm}{\if{html}{\out{
}}\preformatted{linear_reg() \%>\% set_engine("lm") \%>\% - set_mode("regression") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Linear Regression Model Specification (regression) ## @@ -88,7 +87,6 @@ call. For this type of model, the template of the fit calls are below. \subsection{glmnet}{\if{html}{\out{
}}\preformatted{linear_reg(penalty = 0.1) \%>\% set_engine("glmnet") \%>\% - set_mode("regression") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Linear Regression Model Specification (regression) ## @@ -115,7 +113,6 @@ results. \subsection{stan}{\if{html}{\out{
}}\preformatted{linear_reg() \%>\% set_engine("stan") \%>\% - set_mode("regression") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Linear Regression Model Specification (regression) ## @@ -138,7 +135,6 @@ predictive distribution as appropriate) is returned. \subsection{spark}{\if{html}{\out{
}}\preformatted{linear_reg() \%>\% set_engine("spark") \%>\% - set_mode("regression") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Linear Regression Model Specification (regression) ## @@ -152,7 +148,6 @@ predictive distribution as appropriate) is returned. \subsection{keras}{\if{html}{\out{
}}\preformatted{linear_reg() \%>\% set_engine("keras") \%>\% - set_mode("regression") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Linear Regression Model Specification (regression) ## diff --git a/man/logistic_reg.Rd b/man/logistic_reg.Rd index cdb6b73c8..6dff2100d 100644 --- a/man/logistic_reg.Rd +++ b/man/logistic_reg.Rd @@ -74,7 +74,6 @@ Engines may have pre-set default arguments when executing the model fit call. For this type of model, the template of the fit calls are below. \subsection{glm}{\if{html}{\out{
}}\preformatted{logistic_reg() \%>\% set_engine("glm") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification) ## @@ -88,7 +87,6 @@ call. For this type of model, the template of the fit calls are below. \subsection{glmnet}{\if{html}{\out{
}}\preformatted{logistic_reg(penalty = 0.1) \%>\% set_engine("glmnet") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification) ## @@ -115,7 +113,6 @@ results. \subsection{LiblineaR}{\if{html}{\out{
}}\preformatted{logistic_reg() \%>\% set_engine("LiblineaR") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification) ## @@ -138,7 +135,6 @@ parameter estimates. \subsection{stan}{\if{html}{\out{
}}\preformatted{logistic_reg() \%>\% set_engine("stan") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification) ## @@ -161,7 +157,6 @@ predictive distribution as appropriate) is returned. \subsection{spark}{\if{html}{\out{
}}\preformatted{logistic_reg() \%>\% set_engine("spark") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification) ## @@ -175,7 +170,6 @@ predictive distribution as appropriate) is returned. \subsection{keras}{\if{html}{\out{
}}\preformatted{logistic_reg() \%>\% set_engine("keras") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification) ## diff --git a/man/multinom_reg.Rd b/man/multinom_reg.Rd index a1ad4138e..8dc8b2e5c 100644 --- a/man/multinom_reg.Rd +++ b/man/multinom_reg.Rd @@ -69,7 +69,6 @@ Engines may have pre-set default arguments when executing the model fit call. For this type of model, the template of the fit calls are below. \subsection{glmnet}{\if{html}{\out{
}}\preformatted{multinom_reg(penalty = 0.1) \%>\% set_engine("glmnet") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Multinomial Regression Model Specification (classification) ## @@ -96,7 +95,6 @@ results. \subsection{nnet}{\if{html}{\out{
}}\preformatted{multinom_reg() \%>\% set_engine("nnet") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Multinomial Regression Model Specification (classification) ## @@ -110,7 +108,6 @@ results. \subsection{spark}{\if{html}{\out{
}}\preformatted{multinom_reg() \%>\% set_engine("spark") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Multinomial Regression Model Specification (classification) ## @@ -124,7 +121,6 @@ results. \subsection{keras}{\if{html}{\out{
}}\preformatted{multinom_reg() \%>\% set_engine("keras") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Multinomial Regression Model Specification (classification) ## diff --git a/man/rmd/linear-reg.Rmd b/man/rmd/linear-reg.Rmd index 18a8c66c4..378f71db8 100644 --- a/man/rmd/linear-reg.Rmd +++ b/man/rmd/linear-reg.Rmd @@ -10,7 +10,6 @@ Engines may have pre-set default arguments when executing the model fit call. Fo ```{r lm-reg} linear_reg() %>% set_engine("lm") %>% - set_mode("regression") %>% translate() ``` @@ -19,7 +18,6 @@ linear_reg() %>% ```{r glmnet-csl} linear_reg(penalty = 0.1) %>% set_engine("glmnet") %>% - set_mode("regression") %>% translate() ``` @@ -37,7 +35,6 @@ penalty results. ```{r stan-reg} linear_reg() %>% set_engine("stan") %>% - set_mode("regression") %>% translate() ``` @@ -55,7 +52,6 @@ returned. ```{r spark-reg} linear_reg() %>% set_engine("spark") %>% - set_mode("regression") %>% translate() ``` @@ -64,7 +60,6 @@ linear_reg() %>% ```{r keras-reg} linear_reg() %>% set_engine("keras") %>% - set_mode("regression") %>% translate() ``` diff --git a/man/rmd/logistic-reg.Rmd b/man/rmd/logistic-reg.Rmd index 74f483f07..b28b44391 100644 --- a/man/rmd/logistic-reg.Rmd +++ b/man/rmd/logistic-reg.Rmd @@ -11,7 +11,6 @@ For this type of model, the template of the fit calls are below. ```{r glm-reg} logistic_reg() %>% set_engine("glm") %>% - set_mode("classification") %>% translate() ``` @@ -20,7 +19,6 @@ logistic_reg() %>% ```{r glmnet-csl} logistic_reg(penalty = 0.1) %>% set_engine("glmnet") %>% - set_mode("classification") %>% translate() ``` @@ -38,7 +36,6 @@ penalty results. ```{r liblinear-reg} logistic_reg() %>% set_engine("LiblineaR") %>% - set_mode("classification") %>% translate() ``` @@ -54,7 +51,6 @@ regularized regression models do not, which will result in different parameter e ```{r stan-reg} logistic_reg() %>% set_engine("stan") %>% - set_mode("classification") %>% translate() ``` @@ -72,7 +68,6 @@ returned. ```{r spark-reg} logistic_reg() %>% set_engine("spark") %>% - set_mode("classification") %>% translate() ``` @@ -81,7 +76,6 @@ logistic_reg() %>% ```{r keras-reg} logistic_reg() %>% set_engine("keras") %>% - set_mode("classification") %>% translate() ``` diff --git a/man/rmd/multinom-reg.Rmd b/man/rmd/multinom-reg.Rmd index 3a48cd372..2071db327 100644 --- a/man/rmd/multinom-reg.Rmd +++ b/man/rmd/multinom-reg.Rmd @@ -11,7 +11,6 @@ For this type of model, the template of the fit calls are below. ```{r glmnet-cls} multinom_reg(penalty = 0.1) %>% set_engine("glmnet") %>% - set_mode("classification") %>% translate() ``` @@ -29,7 +28,6 @@ penalty results. ```{r nnet-cls} multinom_reg() %>% set_engine("nnet") %>% - set_mode("classification") %>% translate() ``` @@ -38,7 +36,6 @@ multinom_reg() %>% ```{r spark-cls} multinom_reg() %>% set_engine("spark") %>% - set_mode("classification") %>% translate() ``` @@ -47,7 +44,6 @@ multinom_reg() %>% ```{r keras-cls} multinom_reg() %>% set_engine("keras") %>% - set_mode("classification") %>% translate() ``` From 2d9c1d317349f55298b17031eaefe5e5497e966b Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 11 May 2021 13:41:34 -0400 Subject: [PATCH 7/7] moved check to function; small mesage edits --- R/linear_reg.R | 8 +------- R/logistic_reg.R | 8 +------- R/misc.R | 11 ++++++++++- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/R/linear_reg.R b/R/linear_reg.R index 63e7c7872..372c7f503 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -112,13 +112,7 @@ translate.linear_reg <- function(x, engine = x$engine, ...) 
{ # Since the `fit` information is gone for the penalty, we need to have an # evaluated value for the parameter. x$args$penalty <- rlang::eval_tidy(x$args$penalty) - if (length(x$args$penalty) != 1) { - rlang::abort(c( - "For the glmnet engine, `penalty` must be a single number.", - glue::glue("There are {length(x$args$penalty)} values for `penalty`."), - "To try multiple values for total regularization, use the tune package." - )) - } + check_glmnet_penalty(x) } x
diff --git a/R/logistic_reg.R b/R/logistic_reg.R index d3b0611d3..279e48002 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -115,13 +115,7 @@ translate.logistic_reg <- function(x, engine = x$engine, ...) { # Since the `fit` information is gone for the penalty, we need to have an # evaluated value for the parameter. x$args$penalty <- rlang::eval_tidy(x$args$penalty) - if (length(x$args$penalty) != 1) { - rlang::abort(c( - "For the glmnet engine, `penalty` must be a single number.", - glue::glue("There are {length(x$args$penalty)} values for `penalty`."), - "To try multiple values for total regularization, use the tune package." - )) - } + check_glmnet_penalty(x) } if (engine == "LiblineaR") {
diff --git a/R/misc.R b/R/misc.R index 5d1fa9876..2332e6afc 100644 --- a/R/misc.R +++ b/R/misc.R @@ -323,4 +323,13 @@ stan_conf_int <- function(object, newdata) { rlang::eval_tidy(fn) } - +check_glmnet_penalty <- function(x) { + if (length(x$args$penalty) != 1) { + rlang::abort(c( + "For the glmnet engine, `penalty` must be a single number (or a value of `tune()`).", + glue::glue("There are {length(x$args$penalty)} values for `penalty`."), + "To try multiple values for total regularization, use the tune package.", + "To predict multiple penalties, use `multi_predict()`" + )) + } +}
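
Taken together, the series makes `translate()` (and therefore `fit()`) fail early whenever a glmnet specification does not carry exactly one `penalty` value; the same check applies to `linear_reg()`, `logistic_reg()`, and `multinom_reg()`. The sketch below is not part of the patches themselves: it is a minimal illustration of the intended behavior, assuming a parsnip build that already contains these commits, with arbitrary `penalty`/`mixture` values chosen only for the example.

```r
library(parsnip)  # assumes a version that includes check_glmnet_penalty()

# With no penalty at all, the new check aborts at translate() time;
# try() keeps the script runnable and prints the abort message.
no_penalty <- linear_reg(mixture = 0.5) %>% set_engine("glmnet")
try(translate(no_penalty))
#> Error: For the glmnet engine, `penalty` must be a single number ...

# Exactly one penalty value satisfies the check and translates as before.
linear_reg(penalty = 0.1, mixture = 0.5) %>%
  set_engine("glmnet") %>%
  translate()

# To evaluate many penalty values, mark the argument for tuning (tune package)
# rather than passing a vector; predictions at several penalties go through
# multi_predict() on a fitted model.
linear_reg(penalty = tune::tune(), mixture = 0.5) %>%
  set_engine("glmnet")
```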