diff --git a/R/engines.R b/R/engines.R index d96d3e044..913cf8a1e 100644 --- a/R/engines.R +++ b/R/engines.R @@ -82,7 +82,7 @@ load_libs <- function(x, quiet, attach = FALSE) { #' @examples #' # First, set general arguments using the standardized names #' mod <- -#' logistic_reg(mixture = 1/3) %>% +#' logistic_reg(penalty = 0.01, mixture = 1/3) %>% #' # now say how you want to fit the model and another other options #' set_engine("glmnet", nlambda = 10) #' translate(mod, engine = "glmnet") diff --git a/R/linear_reg.R b/R/linear_reg.R index f3137ef19..372c7f503 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -112,6 +112,7 @@ translate.linear_reg <- function(x, engine = x$engine, ...) { # Since the `fit` information is gone for the penalty, we need to have an # evaluated value for the parameter. x$args$penalty <- rlang::eval_tidy(x$args$penalty) + check_glmnet_penalty(x) } x diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 4d7b2d3ab..279e48002 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -115,6 +115,7 @@ translate.logistic_reg <- function(x, engine = x$engine, ...) { # Since the `fit` information is gone for the penalty, we need to have an # evaluated value for the parameter. x$args$penalty <- rlang::eval_tidy(x$args$penalty) + check_glmnet_penalty(x) } if (engine == "LiblineaR") { diff --git a/R/misc.R b/R/misc.R index 5d1fa9876..2332e6afc 100644 --- a/R/misc.R +++ b/R/misc.R @@ -323,4 +323,13 @@ stan_conf_int <- function(object, newdata) { rlang::eval_tidy(fn) } - +check_glmnet_penalty <- function(x) { + if (length(x$args$penalty) != 1) { + rlang::abort(c( + "For the glmnet engine, `penalty` must be a single number (or a value of `tune()`).", + glue::glue("There are {length(x$args$penalty)} values for `penalty`."), + "To try multiple values for total regularization, use the tune package.", + "To predict multiple penalties, use `multi_predict()`" + )) + } +} diff --git a/R/translate.R b/R/translate.R index 2d697179d..4c2064db6 100644 --- a/R/translate.R +++ b/R/translate.R @@ -38,7 +38,7 @@ #' translate(lm_spec, engine = "spark") #' #' # with a placeholder for an unknown argument value: -#' translate(linear_reg(mixture = varying()), engine = "glmnet") +#' translate(linear_reg(penalty = varying(), mixture = varying()), engine = "glmnet") #' #' @export diff --git a/man/contr_one_hot.Rd b/man/contr_one_hot.Rd index df945bebd..a8c21d593 100644 --- a/man/contr_one_hot.Rd +++ b/man/contr_one_hot.Rd @@ -39,14 +39,16 @@ levels(penguins$species) }\if{html}{\out{}}\preformatted{## [1] "Biscoe" "Dream" "Torgersen" }\if{html}{\out{
}}\preformatted{model.matrix(~ species + island, data = penguins) \%>\% colnames() -}\if{html}{\out{
}}\preformatted{## [1] "(Intercept)" "speciesChinstrap" "speciesGentoo" "islandDream" "islandTorgersen" +}\if{html}{\out{}}\preformatted{## [1] "(Intercept)" "speciesChinstrap" "speciesGentoo" "islandDream" +## [5] "islandTorgersen" } For a formula with no intercept, the first factor is expanded to indicators for \emph{all} factor levels but all other factors are expanded to all but one (as above):\if{html}{\out{
}}\preformatted{model.matrix(~ 0 + species + island, data = penguins) \%>\% colnames() -}\if{html}{\out{
}}\preformatted{## [1] "speciesAdelie" "speciesChinstrap" "speciesGentoo" "islandDream" "islandTorgersen" +}\if{html}{\out{}}\preformatted{## [1] "speciesAdelie" "speciesChinstrap" "speciesGentoo" "islandDream" +## [5] "islandTorgersen" } For inference, this hybrid encoding can be problematic. @@ -59,8 +61,8 @@ options(contrasts = new_contr) model.matrix(~ species + island, data = penguins) \%>\% colnames() -}\if{html}{\out{}}\preformatted{## [1] "(Intercept)" "speciesAdelie" "speciesChinstrap" "speciesGentoo" "islandBiscoe" -## [6] "islandDream" "islandTorgersen" +}\if{html}{\out{}}\preformatted{## [1] "(Intercept)" "speciesAdelie" "speciesChinstrap" "speciesGentoo" +## [5] "islandBiscoe" "islandDream" "islandTorgersen" }\if{html}{\out{
}}\preformatted{options(contrasts = old_contr) }\if{html}{\out{
}} diff --git a/man/linear_reg.Rd b/man/linear_reg.Rd index 53057d647..b11544e60 100644 --- a/man/linear_reg.Rd +++ b/man/linear_reg.Rd @@ -75,7 +75,6 @@ Engines may have pre-set default arguments when executing the model fit call. For this type of model, the template of the fit calls are below. \subsection{lm}{\if{html}{\out{
}}\preformatted{linear_reg() \%>\% set_engine("lm") \%>\% - set_mode("regression") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Linear Regression Model Specification (regression) ## @@ -86,12 +85,14 @@ call. For this type of model, the template of the fit calls are below. } } -\subsection{glmnet}{\if{html}{\out{
}}\preformatted{linear_reg() \%>\% +\subsection{glmnet}{\if{html}{\out{
}}\preformatted{linear_reg(penalty = 0.1) \%>\% set_engine("glmnet") \%>\% - set_mode("regression") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Linear Regression Model Specification (regression) ## +## Main Arguments: +## penalty = 0.1 +## ## Computational engine: glmnet ## ## Model fit template: @@ -112,7 +113,6 @@ results. \subsection{stan}{\if{html}{\out{
}}\preformatted{linear_reg() \%>\% set_engine("stan") \%>\% - set_mode("regression") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Linear Regression Model Specification (regression) ## @@ -135,7 +135,6 @@ predictive distribution as appropriate) is returned. \subsection{spark}{\if{html}{\out{
}}\preformatted{linear_reg() \%>\% set_engine("spark") \%>\% - set_mode("regression") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Linear Regression Model Specification (regression) ## @@ -149,7 +148,6 @@ predictive distribution as appropriate) is returned. \subsection{keras}{\if{html}{\out{
}}\preformatted{linear_reg() \%>\% set_engine("keras") \%>\% - set_mode("regression") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Linear Regression Model Specification (regression) ## diff --git a/man/logistic_reg.Rd b/man/logistic_reg.Rd index fc05bb5ff..6dff2100d 100644 --- a/man/logistic_reg.Rd +++ b/man/logistic_reg.Rd @@ -74,7 +74,6 @@ Engines may have pre-set default arguments when executing the model fit call. For this type of model, the template of the fit calls are below. \subsection{glm}{\if{html}{\out{
}}\preformatted{logistic_reg() \%>\% set_engine("glm") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification) ## @@ -86,12 +85,14 @@ call. For this type of model, the template of the fit calls are below. } } -\subsection{glmnet}{\if{html}{\out{
}}\preformatted{logistic_reg() \%>\% +\subsection{glmnet}{\if{html}{\out{
}}\preformatted{logistic_reg(penalty = 0.1) \%>\% set_engine("glmnet") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification) ## +## Main Arguments: +## penalty = 0.1 +## ## Computational engine: glmnet ## ## Model fit template: @@ -112,7 +113,6 @@ results. \subsection{LiblineaR}{\if{html}{\out{
}}\preformatted{logistic_reg() \%>\% set_engine("LiblineaR") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification) ## @@ -135,7 +135,6 @@ parameter estimates. \subsection{stan}{\if{html}{\out{
}}\preformatted{logistic_reg() \%>\% set_engine("stan") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification) ## @@ -158,7 +157,6 @@ predictive distribution as appropriate) is returned. \subsection{spark}{\if{html}{\out{
}}\preformatted{logistic_reg() \%>\% set_engine("spark") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification) ## @@ -172,7 +170,6 @@ predictive distribution as appropriate) is returned. \subsection{keras}{\if{html}{\out{
}}\preformatted{logistic_reg() \%>\% set_engine("keras") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification) ## diff --git a/man/multinom_reg.Rd b/man/multinom_reg.Rd index feb930900..8dc8b2e5c 100644 --- a/man/multinom_reg.Rd +++ b/man/multinom_reg.Rd @@ -67,12 +67,14 @@ reloaded and reattached to the \code{parsnip} object. \section{Engine Details}{ Engines may have pre-set default arguments when executing the model fit call. For this type of model, the template of the fit calls are below. -\subsection{glmnet}{\if{html}{\out{
}}\preformatted{multinom_reg() \%>\% +\subsection{glmnet}{\if{html}{\out{
}}\preformatted{multinom_reg(penalty = 0.1) \%>\% set_engine("glmnet") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Multinomial Regression Model Specification (classification) ## +## Main Arguments: +## penalty = 0.1 +## ## Computational engine: glmnet ## ## Model fit template: @@ -93,7 +95,6 @@ results. \subsection{nnet}{\if{html}{\out{
}}\preformatted{multinom_reg() \%>\% set_engine("nnet") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Multinomial Regression Model Specification (classification) ## @@ -107,7 +108,6 @@ results. \subsection{spark}{\if{html}{\out{
}}\preformatted{multinom_reg() \%>\% set_engine("spark") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Multinomial Regression Model Specification (classification) ## @@ -121,7 +121,6 @@ results. \subsection{keras}{\if{html}{\out{
}}\preformatted{multinom_reg() \%>\% set_engine("keras") \%>\% - set_mode("classification") \%>\% translate() }\if{html}{\out{
}}\preformatted{## Multinomial Regression Model Specification (classification) ## diff --git a/man/rmd/linear-reg.Rmd b/man/rmd/linear-reg.Rmd index 03f1da44e..378f71db8 100644 --- a/man/rmd/linear-reg.Rmd +++ b/man/rmd/linear-reg.Rmd @@ -10,16 +10,14 @@ Engines may have pre-set default arguments when executing the model fit call. Fo ```{r lm-reg} linear_reg() %>% set_engine("lm") %>% - set_mode("regression") %>% translate() ``` ## glmnet ```{r glmnet-csl} -linear_reg() %>% +linear_reg(penalty = 0.1) %>% set_engine("glmnet") %>% - set_mode("regression") %>% translate() ``` @@ -37,7 +35,6 @@ penalty results. ```{r stan-reg} linear_reg() %>% set_engine("stan") %>% - set_mode("regression") %>% translate() ``` @@ -55,7 +52,6 @@ returned. ```{r spark-reg} linear_reg() %>% set_engine("spark") %>% - set_mode("regression") %>% translate() ``` @@ -64,7 +60,6 @@ linear_reg() %>% ```{r keras-reg} linear_reg() %>% set_engine("keras") %>% - set_mode("regression") %>% translate() ``` diff --git a/man/rmd/logistic-reg.Rmd b/man/rmd/logistic-reg.Rmd index 656e24f70..b28b44391 100644 --- a/man/rmd/logistic-reg.Rmd +++ b/man/rmd/logistic-reg.Rmd @@ -11,16 +11,14 @@ For this type of model, the template of the fit calls are below. ```{r glm-reg} logistic_reg() %>% set_engine("glm") %>% - set_mode("classification") %>% translate() ``` ## glmnet ```{r glmnet-csl} -logistic_reg() %>% +logistic_reg(penalty = 0.1) %>% set_engine("glmnet") %>% - set_mode("classification") %>% translate() ``` @@ -38,7 +36,6 @@ penalty results. ```{r liblinear-reg} logistic_reg() %>% set_engine("LiblineaR") %>% - set_mode("classification") %>% translate() ``` @@ -54,7 +51,6 @@ regularized regression models do not, which will result in different parameter e ```{r stan-reg} logistic_reg() %>% set_engine("stan") %>% - set_mode("classification") %>% translate() ``` @@ -72,7 +68,6 @@ returned. ```{r spark-reg} logistic_reg() %>% set_engine("spark") %>% - set_mode("classification") %>% translate() ``` @@ -81,7 +76,6 @@ logistic_reg() %>% ```{r keras-reg} logistic_reg() %>% set_engine("keras") %>% - set_mode("classification") %>% translate() ``` diff --git a/man/rmd/multinom-reg.Rmd b/man/rmd/multinom-reg.Rmd index 5d08847d2..2071db327 100644 --- a/man/rmd/multinom-reg.Rmd +++ b/man/rmd/multinom-reg.Rmd @@ -9,9 +9,8 @@ For this type of model, the template of the fit calls are below. ## glmnet ```{r glmnet-cls} -multinom_reg() %>% +multinom_reg(penalty = 0.1) %>% set_engine("glmnet") %>% - set_mode("classification") %>% translate() ``` @@ -29,7 +28,6 @@ penalty results. ```{r nnet-cls} multinom_reg() %>% set_engine("nnet") %>% - set_mode("classification") %>% translate() ``` @@ -38,7 +36,6 @@ multinom_reg() %>% ```{r spark-cls} multinom_reg() %>% set_engine("spark") %>% - set_mode("classification") %>% translate() ``` @@ -47,7 +44,6 @@ multinom_reg() %>% ```{r keras-cls} multinom_reg() %>% set_engine("keras") %>% - set_mode("classification") %>% translate() ``` diff --git a/man/set_engine.Rd b/man/set_engine.Rd index 889451351..d19d67293 100644 --- a/man/set_engine.Rd +++ b/man/set_engine.Rd @@ -26,7 +26,7 @@ to fit the model, along with any arguments specific to that software. \examples{ # First, set general arguments using the standardized names mod <- - logistic_reg(mixture = 1/3) \%>\% + logistic_reg(penalty = 0.01, mixture = 1/3) \%>\% # now say how you want to fit the model and another other options set_engine("glmnet", nlambda = 10) translate(mod, engine = "glmnet") diff --git a/man/translate.Rd b/man/translate.Rd index 11bec9c97..c9a8b0400 100644 --- a/man/translate.Rd +++ b/man/translate.Rd @@ -52,6 +52,6 @@ translate(lm_spec, engine = "lm") translate(lm_spec, engine = "spark") # with a placeholder for an unknown argument value: -translate(linear_reg(mixture = varying()), engine = "glmnet") +translate(linear_reg(penalty = varying(), mixture = varying()), engine = "glmnet") } diff --git a/tests/testthat/test_linear_reg.R b/tests/testthat/test_linear_reg.R index 4c0617cec..b839572e8 100644 --- a/tests/testthat/test_linear_reg.R +++ b/tests/testthat/test_linear_reg.R @@ -15,7 +15,10 @@ hpc <- hpc_data[1:150, c(2:5, 8)] test_that('primary arguments', { basic <- linear_reg() basic_lm <- translate(basic %>% set_engine("lm")) - basic_glmnet <- translate(basic %>% set_engine("glmnet")) + expect_error( + basic_glmnet <- translate(basic %>% set_engine("glmnet")), + "For the glmnet engine, `penalty` must be a single" + ) basic_stan <- translate(basic %>% set_engine("stan")) basic_spark <- translate(basic %>% set_engine("spark")) expect_equal(basic_lm$method$fit$args, @@ -25,14 +28,6 @@ test_that('primary arguments', { weights = expr(missing_arg()) ) ) - expect_equal(basic_glmnet$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - family = "gaussian" - ) - ) expect_equal(basic_stan$method$fit$args, list( formula = expr(missing_arg()), @@ -51,17 +46,11 @@ test_that('primary arguments', { ) mixture <- linear_reg(mixture = 0.128) - mixture_glmnet <- translate(mixture %>% set_engine("glmnet")) - mixture_spark <- translate(mixture %>% set_engine("spark")) - expect_equal(mixture_glmnet$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - alpha = new_empty_quosure(0.128), - family = "gaussian" - ) + expect_error( + mixture_glmnet <- translate(mixture %>% set_engine("glmnet")), + "For the glmnet engine, `penalty` must be a single" ) + mixture_spark <- translate(mixture %>% set_engine("spark")) expect_equal(mixture_spark$method$fit$args, list( x = expr(missing_arg()), @@ -92,17 +81,11 @@ test_that('primary arguments', { ) mixture_v <- linear_reg(mixture = varying()) - mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")) - mixture_v_spark <- translate(mixture_v %>% set_engine("spark")) - expect_equal(mixture_v_glmnet$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - alpha = new_empty_quosure(varying()), - family = "gaussian" - ) + expect_error( + mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")), + "For the glmnet engine, `penalty` must be a single" ) + mixture_v_spark <- translate(mixture_v %>% set_engine("spark")) expect_equal(mixture_v_spark$method$fit$args, list( x = expr(missing_arg()), @@ -125,7 +108,7 @@ test_that('engine arguments', { ) ) - glmnet_nlam <- linear_reg() %>% set_engine("glmnet", nlambda = 10) + glmnet_nlam <- linear_reg(penalty = 0.1) %>% set_engine("glmnet", nlambda = 10) expect_equal(translate(glmnet_nlam)$method$fit$args, list( x = expr(missing_arg()), diff --git a/tests/testthat/test_logistic_reg.R b/tests/testthat/test_logistic_reg.R index b19806dae..5b2c9df1b 100644 --- a/tests/testthat/test_logistic_reg.R +++ b/tests/testthat/test_logistic_reg.R @@ -16,7 +16,10 @@ hpc <- hpc_data[1:150, c(2:5, 8)] test_that('primary arguments', { basic <- logistic_reg() basic_glm <- translate(basic %>% set_engine("glm")) - basic_glmnet <- translate(basic %>% set_engine("glmnet")) + expect_error( + basic_glmnet <- translate(basic %>% set_engine("glmnet")), + "For the glmnet engine, `penalty` must be a single" + ) basic_liblinear <- translate(basic %>% set_engine("LiblineaR")) basic_stan <- translate(basic %>% set_engine("stan")) basic_spark <- translate(basic %>% set_engine("spark")) @@ -28,14 +31,6 @@ test_that('primary arguments', { family = expr(stats::binomial) ) ) - expect_equal(basic_glmnet$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - family = "binomial" - ) - ) expect_equal(basic_liblinear$method$fit$args, list( x = expr(missing_arg()), @@ -63,17 +58,11 @@ test_that('primary arguments', { ) mixture <- logistic_reg(mixture = 0.128) - mixture_glmnet <- translate(mixture %>% set_engine("glmnet")) - mixture_spark <- translate(mixture %>% set_engine("spark")) - expect_equal(mixture_glmnet$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - alpha = new_empty_quosure(0.128), - family = "binomial" - ) + expect_error( + mixture_glmnet <- translate(mixture %>% set_engine("glmnet")), + "For the glmnet engine, `penalty` must be a single" ) + mixture_spark <- translate(mixture %>% set_engine("spark")) expect_equal(mixture_spark$method$fit$args, list( x = expr(missing_arg()), @@ -116,18 +105,12 @@ test_that('primary arguments', { ) mixture_v <- logistic_reg(mixture = varying()) - mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")) + expect_error( + mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")), + "For the glmnet engine, `penalty` must be a single" + ) mixture_v_liblinear <- translate(mixture_v %>% set_engine("LiblineaR")) mixture_v_spark <- translate(mixture_v %>% set_engine("spark")) - expect_equal(mixture_v_glmnet$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - alpha = new_empty_quosure(varying()), - family = "binomial" - ) - ) expect_equal(mixture_v_liblinear$method$fit$args, list( x = expr(missing_arg()), @@ -194,7 +177,7 @@ test_that('engine arguments', { ) ) - glmnet_nlam <- logistic_reg() + glmnet_nlam <- logistic_reg(penalty = 0.1) expect_equal( translate(glmnet_nlam %>% set_engine("glmnet", nlambda = 10))$method$fit$args, list( diff --git a/tests/testthat/test_multinom_reg.R b/tests/testthat/test_multinom_reg.R index a373159b6..6a1b0037d 100644 --- a/tests/testthat/test_multinom_reg.R +++ b/tests/testthat/test_multinom_reg.R @@ -13,17 +13,11 @@ hpc <- hpc_data[1:150, c(2:5, 8)] test_that('primary arguments', { basic <- multinom_reg() - basic_glmnet <- translate(basic %>% set_engine("glmnet")) - expect_equal(basic_glmnet$method$fit$args, - list( - x = expr(missing_arg()), - y = expr(missing_arg()), - weights = expr(missing_arg()), - family = "multinomial" - ) + expect_error( + basic_glmnet <- translate(basic %>% set_engine("glmnet")), + "For the glmnet engine, `penalty` must be a single" ) - - mixture <- multinom_reg(mixture = 0.128) + mixture <- multinom_reg(penalty = 0.1, mixture = 0.128) mixture_glmnet <- translate(mixture %>% set_engine("glmnet")) expect_equal(mixture_glmnet$method$fit$args, list( @@ -46,7 +40,7 @@ test_that('primary arguments', { ) ) - mixture_v <- multinom_reg(mixture = varying()) + mixture_v <- multinom_reg(penalty = 0.01, mixture = varying()) mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")) expect_equal(mixture_v_glmnet$method$fit$args, list( @@ -61,7 +55,7 @@ test_that('primary arguments', { }) test_that('engine arguments', { - glmnet_nlam <- multinom_reg() + glmnet_nlam <- multinom_reg(penalty = 0.01) expect_equal( translate(glmnet_nlam %>% set_engine("glmnet", nlambda = 10))$method$fit$args, list( @@ -117,5 +111,5 @@ test_that('bad input', { expect_error(multinom_reg(mode = "regression")) expect_error(translate(multinom_reg() %>% set_engine("wat?"))) expect_error(translate(multinom_reg() %>% set_engine())) - expect_warning(translate(multinom_reg() %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class))) + expect_warning(translate(multinom_reg(penalty = 0.01) %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class))) })