}}\preformatted{logistic_reg(penalty = 0.1) \%>\%
set_engine("glmnet") \%>\%
- set_mode("classification") \%>\%
translate()
}\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification)
##
+## Main Arguments:
+## penalty = 0.1
+##
## Computational engine: glmnet
##
## Model fit template:
@@ -112,7 +113,6 @@ results.
\subsection{LiblineaR}{\if{html}{\out{
}}\preformatted{logistic_reg() \%>\%
set_engine("LiblineaR") \%>\%
- set_mode("classification") \%>\%
translate()
}\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification)
##
@@ -135,7 +135,6 @@ parameter estimates.
\subsection{stan}{\if{html}{\out{
}}\preformatted{logistic_reg() \%>\%
set_engine("stan") \%>\%
- set_mode("classification") \%>\%
translate()
}\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification)
##
@@ -158,7 +157,6 @@ predictive distribution as appropriate) is returned.
\subsection{spark}{\if{html}{\out{
}}\preformatted{logistic_reg() \%>\%
set_engine("spark") \%>\%
- set_mode("classification") \%>\%
translate()
}\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification)
##
@@ -172,7 +170,6 @@ predictive distribution as appropriate) is returned.
\subsection{keras}{\if{html}{\out{
}}\preformatted{logistic_reg() \%>\%
set_engine("keras") \%>\%
- set_mode("classification") \%>\%
translate()
}\if{html}{\out{
}}\preformatted{## Logistic Regression Model Specification (classification)
##
diff --git a/man/multinom_reg.Rd b/man/multinom_reg.Rd
index feb930900..8dc8b2e5c 100644
--- a/man/multinom_reg.Rd
+++ b/man/multinom_reg.Rd
@@ -67,12 +67,14 @@ reloaded and reattached to the \code{parsnip} object.
\section{Engine Details}{
Engines may have pre-set default arguments when executing the model fit
call. For this type of model, the template of the fit calls are below.
-\subsection{glmnet}{\if{html}{\out{
}}\preformatted{multinom_reg() \%>\%
+\subsection{glmnet}{\if{html}{\out{
}}\preformatted{multinom_reg(penalty = 0.1) \%>\%
set_engine("glmnet") \%>\%
- set_mode("classification") \%>\%
translate()
}\if{html}{\out{
}}\preformatted{## Multinomial Regression Model Specification (classification)
##
+## Main Arguments:
+## penalty = 0.1
+##
## Computational engine: glmnet
##
## Model fit template:
@@ -93,7 +95,6 @@ results.
\subsection{nnet}{\if{html}{\out{
}}\preformatted{multinom_reg() \%>\%
set_engine("nnet") \%>\%
- set_mode("classification") \%>\%
translate()
}\if{html}{\out{
}}\preformatted{## Multinomial Regression Model Specification (classification)
##
@@ -107,7 +108,6 @@ results.
\subsection{spark}{\if{html}{\out{
}}\preformatted{multinom_reg() \%>\%
set_engine("spark") \%>\%
- set_mode("classification") \%>\%
translate()
}\if{html}{\out{
}}\preformatted{## Multinomial Regression Model Specification (classification)
##
@@ -121,7 +121,6 @@ results.
\subsection{keras}{\if{html}{\out{
}}\preformatted{multinom_reg() \%>\%
set_engine("keras") \%>\%
- set_mode("classification") \%>\%
translate()
}\if{html}{\out{
}}\preformatted{## Multinomial Regression Model Specification (classification)
##
diff --git a/man/rmd/linear-reg.Rmd b/man/rmd/linear-reg.Rmd
index 03f1da44e..378f71db8 100644
--- a/man/rmd/linear-reg.Rmd
+++ b/man/rmd/linear-reg.Rmd
@@ -10,16 +10,14 @@ Engines may have pre-set default arguments when executing the model fit call. Fo
```{r lm-reg}
linear_reg() %>%
set_engine("lm") %>%
- set_mode("regression") %>%
translate()
```
## glmnet
```{r glmnet-csl}
-linear_reg() %>%
+linear_reg(penalty = 0.1) %>%
set_engine("glmnet") %>%
- set_mode("regression") %>%
translate()
```
@@ -37,7 +35,6 @@ penalty results.
```{r stan-reg}
linear_reg() %>%
set_engine("stan") %>%
- set_mode("regression") %>%
translate()
```
@@ -55,7 +52,6 @@ returned.
```{r spark-reg}
linear_reg() %>%
set_engine("spark") %>%
- set_mode("regression") %>%
translate()
```
@@ -64,7 +60,6 @@ linear_reg() %>%
```{r keras-reg}
linear_reg() %>%
set_engine("keras") %>%
- set_mode("regression") %>%
translate()
```
diff --git a/man/rmd/logistic-reg.Rmd b/man/rmd/logistic-reg.Rmd
index 656e24f70..b28b44391 100644
--- a/man/rmd/logistic-reg.Rmd
+++ b/man/rmd/logistic-reg.Rmd
@@ -11,16 +11,14 @@ For this type of model, the template of the fit calls are below.
```{r glm-reg}
logistic_reg() %>%
set_engine("glm") %>%
- set_mode("classification") %>%
translate()
```
## glmnet
```{r glmnet-csl}
-logistic_reg() %>%
+logistic_reg(penalty = 0.1) %>%
set_engine("glmnet") %>%
- set_mode("classification") %>%
translate()
```
@@ -38,7 +36,6 @@ penalty results.
```{r liblinear-reg}
logistic_reg() %>%
set_engine("LiblineaR") %>%
- set_mode("classification") %>%
translate()
```
@@ -54,7 +51,6 @@ regularized regression models do not, which will result in different parameter e
```{r stan-reg}
logistic_reg() %>%
set_engine("stan") %>%
- set_mode("classification") %>%
translate()
```
@@ -72,7 +68,6 @@ returned.
```{r spark-reg}
logistic_reg() %>%
set_engine("spark") %>%
- set_mode("classification") %>%
translate()
```
@@ -81,7 +76,6 @@ logistic_reg() %>%
```{r keras-reg}
logistic_reg() %>%
set_engine("keras") %>%
- set_mode("classification") %>%
translate()
```
diff --git a/man/rmd/multinom-reg.Rmd b/man/rmd/multinom-reg.Rmd
index 5d08847d2..2071db327 100644
--- a/man/rmd/multinom-reg.Rmd
+++ b/man/rmd/multinom-reg.Rmd
@@ -9,9 +9,8 @@ For this type of model, the template of the fit calls are below.
## glmnet
```{r glmnet-cls}
-multinom_reg() %>%
+multinom_reg(penalty = 0.1) %>%
set_engine("glmnet") %>%
- set_mode("classification") %>%
translate()
```
@@ -29,7 +28,6 @@ penalty results.
```{r nnet-cls}
multinom_reg() %>%
set_engine("nnet") %>%
- set_mode("classification") %>%
translate()
```
@@ -38,7 +36,6 @@ multinom_reg() %>%
```{r spark-cls}
multinom_reg() %>%
set_engine("spark") %>%
- set_mode("classification") %>%
translate()
```
@@ -47,7 +44,6 @@ multinom_reg() %>%
```{r keras-cls}
multinom_reg() %>%
set_engine("keras") %>%
- set_mode("classification") %>%
translate()
```
diff --git a/man/set_engine.Rd b/man/set_engine.Rd
index 889451351..d19d67293 100644
--- a/man/set_engine.Rd
+++ b/man/set_engine.Rd
@@ -26,7 +26,7 @@ to fit the model, along with any arguments specific to that software.
\examples{
# First, set general arguments using the standardized names
mod <-
- logistic_reg(mixture = 1/3) \%>\%
+ logistic_reg(penalty = 0.01, mixture = 1/3) \%>\%
# now say how you want to fit the model and another other options
set_engine("glmnet", nlambda = 10)
translate(mod, engine = "glmnet")
diff --git a/man/translate.Rd b/man/translate.Rd
index 11bec9c97..c9a8b0400 100644
--- a/man/translate.Rd
+++ b/man/translate.Rd
@@ -52,6 +52,6 @@ translate(lm_spec, engine = "lm")
translate(lm_spec, engine = "spark")
# with a placeholder for an unknown argument value:
-translate(linear_reg(mixture = varying()), engine = "glmnet")
+translate(linear_reg(penalty = varying(), mixture = varying()), engine = "glmnet")
}
diff --git a/tests/testthat/test_linear_reg.R b/tests/testthat/test_linear_reg.R
index 4c0617cec..b839572e8 100644
--- a/tests/testthat/test_linear_reg.R
+++ b/tests/testthat/test_linear_reg.R
@@ -15,7 +15,10 @@ hpc <- hpc_data[1:150, c(2:5, 8)]
test_that('primary arguments', {
basic <- linear_reg()
basic_lm <- translate(basic %>% set_engine("lm"))
- basic_glmnet <- translate(basic %>% set_engine("glmnet"))
+ expect_error(
+ basic_glmnet <- translate(basic %>% set_engine("glmnet")),
+ "For the glmnet engine, `penalty` must be a single"
+ )
basic_stan <- translate(basic %>% set_engine("stan"))
basic_spark <- translate(basic %>% set_engine("spark"))
expect_equal(basic_lm$method$fit$args,
@@ -25,14 +28,6 @@ test_that('primary arguments', {
weights = expr(missing_arg())
)
)
- expect_equal(basic_glmnet$method$fit$args,
- list(
- x = expr(missing_arg()),
- y = expr(missing_arg()),
- weights = expr(missing_arg()),
- family = "gaussian"
- )
- )
expect_equal(basic_stan$method$fit$args,
list(
formula = expr(missing_arg()),
@@ -51,17 +46,11 @@ test_that('primary arguments', {
)
mixture <- linear_reg(mixture = 0.128)
- mixture_glmnet <- translate(mixture %>% set_engine("glmnet"))
- mixture_spark <- translate(mixture %>% set_engine("spark"))
- expect_equal(mixture_glmnet$method$fit$args,
- list(
- x = expr(missing_arg()),
- y = expr(missing_arg()),
- weights = expr(missing_arg()),
- alpha = new_empty_quosure(0.128),
- family = "gaussian"
- )
+ expect_error(
+ mixture_glmnet <- translate(mixture %>% set_engine("glmnet")),
+ "For the glmnet engine, `penalty` must be a single"
)
+ mixture_spark <- translate(mixture %>% set_engine("spark"))
expect_equal(mixture_spark$method$fit$args,
list(
x = expr(missing_arg()),
@@ -92,17 +81,11 @@ test_that('primary arguments', {
)
mixture_v <- linear_reg(mixture = varying())
- mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet"))
- mixture_v_spark <- translate(mixture_v %>% set_engine("spark"))
- expect_equal(mixture_v_glmnet$method$fit$args,
- list(
- x = expr(missing_arg()),
- y = expr(missing_arg()),
- weights = expr(missing_arg()),
- alpha = new_empty_quosure(varying()),
- family = "gaussian"
- )
+ expect_error(
+ mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")),
+ "For the glmnet engine, `penalty` must be a single"
)
+ mixture_v_spark <- translate(mixture_v %>% set_engine("spark"))
expect_equal(mixture_v_spark$method$fit$args,
list(
x = expr(missing_arg()),
@@ -125,7 +108,7 @@ test_that('engine arguments', {
)
)
- glmnet_nlam <- linear_reg() %>% set_engine("glmnet", nlambda = 10)
+ glmnet_nlam <- linear_reg(penalty = 0.1) %>% set_engine("glmnet", nlambda = 10)
expect_equal(translate(glmnet_nlam)$method$fit$args,
list(
x = expr(missing_arg()),
diff --git a/tests/testthat/test_logistic_reg.R b/tests/testthat/test_logistic_reg.R
index b19806dae..5b2c9df1b 100644
--- a/tests/testthat/test_logistic_reg.R
+++ b/tests/testthat/test_logistic_reg.R
@@ -16,7 +16,10 @@ hpc <- hpc_data[1:150, c(2:5, 8)]
test_that('primary arguments', {
basic <- logistic_reg()
basic_glm <- translate(basic %>% set_engine("glm"))
- basic_glmnet <- translate(basic %>% set_engine("glmnet"))
+ expect_error(
+ basic_glmnet <- translate(basic %>% set_engine("glmnet")),
+ "For the glmnet engine, `penalty` must be a single"
+ )
basic_liblinear <- translate(basic %>% set_engine("LiblineaR"))
basic_stan <- translate(basic %>% set_engine("stan"))
basic_spark <- translate(basic %>% set_engine("spark"))
@@ -28,14 +31,6 @@ test_that('primary arguments', {
family = expr(stats::binomial)
)
)
- expect_equal(basic_glmnet$method$fit$args,
- list(
- x = expr(missing_arg()),
- y = expr(missing_arg()),
- weights = expr(missing_arg()),
- family = "binomial"
- )
- )
expect_equal(basic_liblinear$method$fit$args,
list(
x = expr(missing_arg()),
@@ -63,17 +58,11 @@ test_that('primary arguments', {
)
mixture <- logistic_reg(mixture = 0.128)
- mixture_glmnet <- translate(mixture %>% set_engine("glmnet"))
- mixture_spark <- translate(mixture %>% set_engine("spark"))
- expect_equal(mixture_glmnet$method$fit$args,
- list(
- x = expr(missing_arg()),
- y = expr(missing_arg()),
- weights = expr(missing_arg()),
- alpha = new_empty_quosure(0.128),
- family = "binomial"
- )
+ expect_error(
+ mixture_glmnet <- translate(mixture %>% set_engine("glmnet")),
+ "For the glmnet engine, `penalty` must be a single"
)
+ mixture_spark <- translate(mixture %>% set_engine("spark"))
expect_equal(mixture_spark$method$fit$args,
list(
x = expr(missing_arg()),
@@ -116,18 +105,12 @@ test_that('primary arguments', {
)
mixture_v <- logistic_reg(mixture = varying())
- mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet"))
+ expect_error(
+ mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")),
+ "For the glmnet engine, `penalty` must be a single"
+ )
mixture_v_liblinear <- translate(mixture_v %>% set_engine("LiblineaR"))
mixture_v_spark <- translate(mixture_v %>% set_engine("spark"))
- expect_equal(mixture_v_glmnet$method$fit$args,
- list(
- x = expr(missing_arg()),
- y = expr(missing_arg()),
- weights = expr(missing_arg()),
- alpha = new_empty_quosure(varying()),
- family = "binomial"
- )
- )
expect_equal(mixture_v_liblinear$method$fit$args,
list(
x = expr(missing_arg()),
@@ -194,7 +177,7 @@ test_that('engine arguments', {
)
)
- glmnet_nlam <- logistic_reg()
+ glmnet_nlam <- logistic_reg(penalty = 0.1)
expect_equal(
translate(glmnet_nlam %>% set_engine("glmnet", nlambda = 10))$method$fit$args,
list(
diff --git a/tests/testthat/test_multinom_reg.R b/tests/testthat/test_multinom_reg.R
index a373159b6..6a1b0037d 100644
--- a/tests/testthat/test_multinom_reg.R
+++ b/tests/testthat/test_multinom_reg.R
@@ -13,17 +13,11 @@ hpc <- hpc_data[1:150, c(2:5, 8)]
test_that('primary arguments', {
basic <- multinom_reg()
- basic_glmnet <- translate(basic %>% set_engine("glmnet"))
- expect_equal(basic_glmnet$method$fit$args,
- list(
- x = expr(missing_arg()),
- y = expr(missing_arg()),
- weights = expr(missing_arg()),
- family = "multinomial"
- )
+ expect_error(
+ basic_glmnet <- translate(basic %>% set_engine("glmnet")),
+ "For the glmnet engine, `penalty` must be a single"
)
-
- mixture <- multinom_reg(mixture = 0.128)
+ mixture <- multinom_reg(penalty = 0.1, mixture = 0.128)
mixture_glmnet <- translate(mixture %>% set_engine("glmnet"))
expect_equal(mixture_glmnet$method$fit$args,
list(
@@ -46,7 +40,7 @@ test_that('primary arguments', {
)
)
- mixture_v <- multinom_reg(mixture = varying())
+ mixture_v <- multinom_reg(penalty = 0.01, mixture = varying())
mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet"))
expect_equal(mixture_v_glmnet$method$fit$args,
list(
@@ -61,7 +55,7 @@ test_that('primary arguments', {
})
test_that('engine arguments', {
- glmnet_nlam <- multinom_reg()
+ glmnet_nlam <- multinom_reg(penalty = 0.01)
expect_equal(
translate(glmnet_nlam %>% set_engine("glmnet", nlambda = 10))$method$fit$args,
list(
@@ -117,5 +111,5 @@ test_that('bad input', {
expect_error(multinom_reg(mode = "regression"))
expect_error(translate(multinom_reg() %>% set_engine("wat?")))
expect_error(translate(multinom_reg() %>% set_engine()))
- expect_warning(translate(multinom_reg() %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class)))
+ expect_warning(translate(multinom_reg(penalty = 0.01) %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class)))
})