-
Notifications
You must be signed in to change notification settings - Fork 88
/
rand_forest.R
173 lines (147 loc) · 4.94 KB
/
rand_forest.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#' Random forest
#'
#' @description
#'
#' `rand_forest()` defines a model that creates a large number of decision
#' trees, each independent of the others. The final prediction uses all
#' predictions from the individual trees and combines them. This function can fit
#' classification, regression, and censored regression models.
#'
#' \Sexpr[stage=render,results=rd]{parsnip:::make_engine_list("rand_forest")}
#'
#' More information on how \pkg{parsnip} is used for modeling is at
#' \url{https://www.tidymodels.org/}.
#'
#' @inheritParams boost_tree
#' @param mtry An integer for the number of predictors that will
#' be randomly sampled at each split when creating the tree models.
#' @param trees An integer for the number of trees contained in
#' the ensemble.
#' @param min_n An integer for the minimum number of data points
#' in a node that are required for the node to be split further.
#'
#' @templateVar modeltype rand_forest
#' @template spec-details
#'
#' @template spec-references
#'
#' @seealso \Sexpr[stage=render,results=rd]{parsnip:::make_seealso_list("rand_forest")}
#'
#' @examplesIf !parsnip:::is_cran_check()
#' show_engines("rand_forest")
#'
#' rand_forest(mode = "classification", trees = 2000)
#' @export
rand_forest <-
  function(mode = "unknown", engine = "ranger", mtry = NULL, trees = NULL, min_n = NULL) {
    # Note whether the caller passed `mode` / `engine` explicitly or is relying
    # on the defaults; both flags are stored on the specification.
    mode_supplied <- !missing(mode)
    engine_supplied <- !missing(engine)

    # The main arguments are captured with `enquo()` rather than evaluated here.
    main_args <- list(
      mtry = enquo(mtry),
      trees = enquo(trees),
      min_n = enquo(min_n)
    )

    new_model_spec(
      "rand_forest",
      args = main_args,
      eng_args = NULL,
      mode = mode,
      user_specified_mode = mode_supplied,
      method = NULL,
      engine = engine,
      user_specified_engine = engine_supplied
    )
  }
# ------------------------------------------------------------------------------
#' @method update rand_forest
#' @rdname parsnip_update
#' @export
update.rand_forest <-
  function(object,
           parameters = NULL,
           mtry = NULL, trees = NULL, min_n = NULL,
           fresh = FALSE, ...) {
    # Capture the replacement values for the main arguments with `enquo()`;
    # they are handed to `update_spec()` unevaluated.
    captured_args <- list(
      mtry = enquo(mtry),
      trees = enquo(trees),
      min_n = enquo(min_n)
    )

    # `update_spec()` does the actual work; `cls` names the specification
    # class being updated and `...` is forwarded untouched.
    update_spec(
      object = object,
      parameters = parameters,
      args_enquo_list = captured_args,
      fresh = fresh,
      cls = "rand_forest",
      ...
    )
  }
# ------------------------------------------------------------------------------
#' @export
translate.rand_forest <- function(x, engine = x$engine, ...) {
  # Translate a `rand_forest` specification into engine-specific fit arguments.
  #
  # Arguments:
  #   x       The model specification to translate.
  #   engine  Engine name. Defaults to `x$engine`; when that is `NULL`,
  #           "ranger" is used and a message is emitted.
  #   ...     Passed on to `translate.default()`.
  #
  # Returns the result of `translate.default()` with `x$method$fit$args`
  # post-processed by the engine-specific steps below.
  if (is.null(engine)) {
    message("Used `engine = 'ranger'` for translation.")
    engine <- "ranger"
  }

  x <- translate.default(x, engine, ...)

  # Edit a local copy of the fit arguments; it is written back at the end.
  arg_vals <- x$method$fit$args

  # This check reads the engine from the translated spec (`x$engine`); the
  # later checks use the `engine` argument.
  if (x$engine == "spark") {
    if (x$mode == "unknown") {
      cli::cli_abort(
        "For spark random forest models, the mode cannot
be {.val unknown} if the specification is to be translated."
      )
    } else {
      arg_vals$type <- x$mode
    }

    # See "Details" in ?ml_random_forest_classifier. `feature_subset_strategy`
    # should be character even if it contains a number.
    if ("feature_subset_strategy" %in% names(arg_vals) &&
      isTRUE(is.numeric(quo_get_expr(arg_vals$feature_subset_strategy)))) {
      arg_vals$feature_subset_strategy <-
        paste(quo_get_expr(arg_vals$feature_subset_strategy))
    }
  }

  # Engine-specific error traps and defaults.
  if (engine == "ranger") {
    if ("importance" %in% names(arg_vals)) {
      if (isTRUE(is.logical(quo_get_expr(arg_vals$importance)))) {
        cli::cli_abort(
          c(
            "{.arg importance} should be a character value.",
            "i" = "See ?ranger::ranger."
          )
        )
      }
    }

    # Unless otherwise specified, classification models are probability forests.
    if (x$mode == "classification" && !("probability" %in% names(arg_vals))) {
      arg_vals$probability <- TRUE
    }
  }

  # Protect some arguments based on data dimensions: the stored values are
  # wrapped in calls to `min_cols()` / `min_rows()` together with the symbol
  # `x`, so they are only resolved when the resulting call is evaluated.
  # The `if` condition is scalar, so the short-circuit `&&` is used.
  if ("mtry" %in% names(arg_vals) && engine != "partykit") {
    arg_vals$mtry <- rlang::call2("min_cols", arg_vals$mtry, expr(x))
  }

  # The same row-count wrapper is applied to whichever of these argument
  # names is present.
  for (row_arg in c("min.node.size", "nodesize", "min_instances_per_node")) {
    if (row_arg %in% names(arg_vals)) {
      arg_vals[[row_arg]] <-
        rlang::call2("min_rows", arg_vals[[row_arg]], expr(x))
    }
  }

  x$method$fit$args <- arg_vals

  x
}
# ------------------------------------------------------------------------------
#' @export
check_args.rand_forest <- function(object, call = rlang::caller_env()) {
  # No rand_forest-specific argument checks are performed here at present
  # (the checks in `translate.rand_forest()` are candidates for moving here).
  # The specification is handed back unchanged and invisibly; `call` is unused.
  invisible(x = object)
}