Skip to content

Commit 20d6364

Browse files
committed
adapted model-specific translate code to quosures
1 parent a90ee98 commit 20d6364

File tree

2 files changed

+22
-13
lines changed

2 files changed

+22
-13
lines changed

R/mlp.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,9 @@ update.mlp <-
213213
translate.mlp <- function(x, engine, ...) {
214214

215215
if (engine == "nnet") {
216-
if(is.null(x$args$hidden_units))
216+
if(isTRUE(is.null(quo_get_expr(x$args$hidden_units)))) {
217217
x$args$hidden_units <- 5
218+
}
218219
}
219220

220221
x <- translate.default(x, engine, ...)

R/rand_forest.R

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -191,34 +191,42 @@ update.rand_forest <-
191191
translate.rand_forest <- function(x, engine, ...) {
192192
x <- translate.default(x, engine, ...)
193193

194+
# slightly cleaner code using
195+
arg_vals <- x$method$fit$args
196+
194197
if (x$engine == "spark") {
195-
if (x$mode == "unknown")
198+
if (x$mode == "unknown") {
196199
stop(
197200
"For spark random forests models, the mode cannot be 'unknown' ",
198201
"if the specification is to be translated.",
199202
call. = FALSE
200203
)
201-
else
202-
x$method$fit$args$type <- x$mode
203-
204-
# See "Details" in ?ml_random_forest_classifier
205-
if (is.numeric(x$method$fit$args$feature_subset_strategy))
206-
x$method$fit$args$feature_subset_strategy <-
207-
paste(x$method$fit$args$feature_subset_strategy)
204+
} else {
205+
arg_vals$type <- x$mode
206+
}
208207

208+
# See "Details" in ?ml_random_forest_classifier. `feature_subset_strategy`
209+
# should be character even if it contains a number.
210+
if (any(names(arg_vals) == "feature_subset_strategy") &&
211+
isTRUE(is.numeric(quo_get_expr(arg_vals$feature_subset_strategy)))) {
212+
arg_vals$feature_subset_strategy <-
213+
paste(quo_get_expr(arg_vals$feature_subset_strategy))
214+
}
209215
}
210216

211217
# add checks to error trap or change things for this method
212218
if (x$engine == "ranger") {
213-
if (any(names(x$method$fit$args) == "importance"))
214-
if (is.logical(x$method$fit$args$importance))
219+
if (any(names(arg_vals) == "importance"))
220+
if (isTRUE(is.logical(quo_get_expr(arg_vals$importance))))
215221
stop("`importance` should be a character value. See ?ranger::ranger.",
216222
call. = FALSE)
217223
# unless otherwise specified, classification models are probability forests
218-
if (x$mode == "classification" && !any(names(x$method$fit$args) == "probability"))
219-
x$method$fit$args$probability <- TRUE
224+
if (x$mode == "classification" && !any(names(arg_vals) == "probability"))
225+
arg_vals$probability <- TRUE
220226

221227
}
228+
x$method$fit$args <- arg_vals
229+
222230
x
223231
}
224232

0 commit comments

Comments
 (0)