@@ -191,34 +191,42 @@ update.rand_forest <-
191191translate.rand_forest <- function (x , engine , ... ) {
192192 x <- translate.default(x , engine , ... )
193193
194+ # slightly cleaner code using
195+ arg_vals <- x $ method $ fit $ args
196+
194197 if (x $ engine == " spark" ) {
195- if (x $ mode == " unknown" )
198+ if (x $ mode == " unknown" ) {
196199 stop(
197200 " For spark random forests models, the mode cannot be 'unknown' " ,
198201 " if the specification is to be translated." ,
199202 call. = FALSE
200203 )
201- else
202- x $ method $ fit $ args $ type <- x $ mode
203-
204- # See "Details" in ?ml_random_forest_classifier
205- if (is.numeric(x $ method $ fit $ args $ feature_subset_strategy ))
206- x $ method $ fit $ args $ feature_subset_strategy <-
207- paste(x $ method $ fit $ args $ feature_subset_strategy )
204+ } else {
205+ arg_vals $ type <- x $ mode
206+ }
208207
208+ # See "Details" in ?ml_random_forest_classifier. `feature_subset_strategy`
209+ # should be character even if it contains a number.
210+ if (any(names(arg_vals ) == " feature_subset_strategy" ) &&
211+ isTRUE(is.numeric(quo_get_expr(arg_vals $ feature_subset_strategy )))) {
212+ arg_vals $ feature_subset_strategy <-
213+ paste(quo_get_expr(arg_vals $ feature_subset_strategy ))
214+ }
209215 }
210216
211217 # add checks to error trap or change things for this method
212218 if (x $ engine == " ranger" ) {
213- if (any(names(x $ method $ fit $ args ) == " importance" ))
214- if (is.logical(x $ method $ fit $ args $ importance ))
219+ if (any(names(arg_vals ) == " importance" ))
220+ if (isTRUE( is.logical(quo_get_expr( arg_vals $ importance )) ))
215221 stop(" `importance` should be a character value. See ?ranger::ranger." ,
216222 call. = FALSE )
217223 # unless otherwise specified, classification models are probability forests
218- if (x $ mode == " classification" && ! any(names(x $ method $ fit $ args ) == " probability" ))
219- x $ method $ fit $ args $ probability <- TRUE
224+ if (x $ mode == " classification" && ! any(names(arg_vals ) == " probability" ))
225+ arg_vals $ probability <- TRUE
220226
221227 }
228+ x $ method $ fit $ args <- arg_vals
229+
222230 x
223231}
224232
0 commit comments