-
Notifications
You must be signed in to change notification settings - Fork 88
/
linear_reg.R
117 lines (103 loc) · 3.32 KB
/
linear_reg.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#' Linear regression
#'
#' @description
#'
#' `linear_reg()` defines a model that can predict numeric values from
#' predictors using a linear function. This function can fit regression models.
#'
#' \Sexpr[stage=render,results=rd]{parsnip:::make_engine_list("linear_reg")}
#'
#' More information on how \pkg{parsnip} is used for modeling is at
#' \url{https://www.tidymodels.org/}.
#'
#' @param mode A single character string for the type of model.
#' The only possible value for this model is "regression".
#' @param engine A single character string specifying what computational engine
#' to use for fitting. Possible engines are listed below. The default for this
#' model is `"lm"`.
#' @param penalty A non-negative number representing the total
#' amount of regularization (specific engines only).
#' @param mixture A number between zero and one (inclusive) denoting the
#' proportion of L1 regularization (i.e. lasso) in the model.
#'
#' * `mixture = 1` specifies a pure lasso model,
#' * `mixture = 0` specifies a ridge regression model, and
#' * `0 < mixture < 1` specifies an elastic net model, interpolating between
#'   lasso and ridge.
#'
#' Available for specific engines only.
#'
#' @templateVar modeltype linear_reg
#' @template spec-details
#'
#' @template spec-references
#'
#' @seealso \Sexpr[stage=render,results=rd]{parsnip:::make_seealso_list("linear_reg")}
#'
#' @examplesIf !parsnip:::is_cran_check()
#' show_engines("linear_reg")
#'
#' linear_reg()
#' @export
linear_reg <- function(mode = "regression",
                       engine = "lm",
                       penalty = NULL,
                       mixture = NULL) {
  # The main arguments are captured as quosures rather than evaluated now;
  # other code (e.g. `check_args.linear_reg()`) resolves them later with
  # `rlang::eval_tidy()`.
  main_args <- list(
    penalty = enquo(penalty),
    mixture = enquo(mixture)
  )

  # Record whether the caller supplied `mode` / `engine` explicitly instead of
  # relying on the defaults in the signature.
  mode_was_set <- !missing(mode)
  engine_was_set <- !missing(engine)

  # `new_model_spec()` is defined elsewhere in the package and builds the
  # specification object tagged with "linear_reg".
  new_model_spec(
    "linear_reg",
    args = main_args,
    eng_args = NULL,
    mode = mode,
    user_specified_mode = mode_was_set,
    method = NULL,
    engine = engine,
    user_specified_engine = engine_was_set
  )
}
#' @export
translate.linear_reg <- function(x, engine = x$engine, ...) {
  # Run the generic translation first; engine-specific tweaks follow.
  x <- translate.default(x, engine, ...)

  # Only the glmnet engine needs the penalty handling below.
  if (engine != "glmnet") {
    return(x)
  }

  # See https://parsnip.tidymodels.org/reference/glmnet-details.html
  .check_glmnet_penalty_fit(x)
  x <- set_glmnet_penalty_path(x)

  # After the penalty path is set, the `fit` information for the penalty is
  # gone, so the parameter must hold an evaluated value, not a quosure.
  x$args$penalty <- rlang::eval_tidy(x$args$penalty)

  x
}
# ------------------------------------------------------------------------------
#' @method update linear_reg
#' @rdname parsnip_update
#' @export
update.linear_reg <- function(object,
                              parameters = NULL,
                              penalty = NULL, mixture = NULL,
                              fresh = FALSE, ...) {
  # Capture the replacement values as quosures, mirroring how
  # `linear_reg()` stores its main arguments.
  new_args <- list(
    penalty = enquo(penalty),
    mixture = enquo(mixture)
  )

  # `update_spec()` (defined elsewhere in the package) applies the new values;
  # `parameters`, `fresh` and `...` are forwarded unchanged.
  update_spec(
    object = object,
    parameters = parameters,
    args_enquo_list = new_args,
    fresh = fresh,
    cls = "linear_reg",
    ...
  )
}
# ------------------------------------------------------------------------------
#' @export
check_args.linear_reg <- function(object, call = rlang::caller_env()) {
  # Resolve every stored quosure to its value before validating.
  evaluated <- lapply(object$args, rlang::eval_tidy)

  # `mixture` is a proportion: when supplied it must lie in [0, 1].
  check_number_decimal(
    evaluated$mixture,
    min = 0,
    max = 1,
    allow_null = TRUE,
    call = call,
    arg = "mixture"
  )

  # `penalty` is an amount of regularization and must not be negative.
  check_number_decimal(
    evaluated$penalty,
    min = 0,
    allow_null = TRUE,
    call = call,
    arg = "penalty"
  )

  # Validation only; hand the specification back unchanged.
  invisible(object)
}