Skip to content

Commit 50a6737

Browse files
committed
Enough changes to get rand_forest with a formula interface working
1 parent 543094e commit 50a6737

File tree

5 files changed

+94
-55
lines changed

5 files changed

+94
-55
lines changed

R/descriptors.R

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@
5050
#'
5151
#' rand_forest(mode = "classification", mtry = expr(n_cols - 2))
5252
#' }
53-
#'
53+
#'
5454
#' When no instance of `expr` is found in any of the argument
5555
#' values, the descriptor calculation code will not be executed.
56-
#'
56+
#'
5757
NULL
5858

5959
get_descr_form <- function(formula, data) {
@@ -66,24 +66,37 @@ get_descr_form <- function(formula, data) {
6666
}
6767

6868
get_descr_df <- function(formula, data) {
69-
69+
7070
tmp_dat <- convert_form_to_xy_fit(formula, data, indicators = FALSE)
71-
71+
7272
if(is.factor(tmp_dat$y)) {
73-
n_levs <- table(tmp_dat$y, dnn = NULL)
74-
} else n_levs <- NA
75-
76-
n_cols <- ncol(tmp_dat$x)
77-
n_preds <- ncol(convert_form_to_xy_fit(formula, data, indicators = TRUE)$x)
78-
n_obs <- nrow(data)
79-
n_facts <- sum(vapply(tmp_dat$x, is.factor, logical(1)))
80-
73+
n_levs <- function() {
74+
table(tmp_dat$y, dnn = NULL)
75+
}
76+
} else n_levs <- function() { NA }
77+
78+
n_cols <- function() {
79+
ncol(tmp_dat$x)
80+
}
81+
82+
n_preds <- function() {
83+
ncol(convert_form_to_xy_fit(formula, data, indicators = TRUE)$x)
84+
}
85+
86+
n_obs <- function() {
87+
nrow(data)
88+
}
89+
90+
n_facts <- function() {
91+
sum(vapply(tmp_dat$x, is.factor, logical(1)))
92+
}
93+
8194
list(
82-
cols = n_cols,
83-
preds = n_preds,
84-
obs = n_obs,
85-
levs = n_levs,
86-
facts = n_facts
95+
n_cols = n_cols,
96+
n_preds = n_preds,
97+
n_obs = n_obs,
98+
n_levs = n_levs,
99+
n_facts = n_facts
87100
)
88101
}
89102

@@ -93,9 +106,9 @@ get_descr_df <- function(formula, data) {
93106
#' @importFrom rlang syms sym
94107
#' @importFrom utils head
95108
get_descr_spark <- function(formula, data) {
96-
109+
97110
all_vars <- all.vars(formula)
98-
111+
99112
if("." %in% all_vars){
100113
tmpdata <- dplyr::collect(head(data, 1000))
101114
f_terms <- stats::terms(formula, data = tmpdata)
@@ -106,11 +119,11 @@ get_descr_spark <- function(formula, data) {
106119
term_data <- dplyr::select(data, !!! rlang::syms(f_cols))
107120
tmpdata <- dplyr::collect(head(term_data, 1000))
108121
}
109-
122+
110123
f_term_labels <- attr(f_terms, "term.labels")
111124
y_ind <- attr(f_terms, "response")
112125
y_col <- f_cols[y_ind]
113-
126+
114127
classes <- purrr::map(tmpdata, class)
115128
icats <- purrr::map_lgl(classes, ~.x == "character")
116129
cats <- classes[icats]
@@ -119,14 +132,14 @@ get_descr_spark <- function(formula, data) {
119132
cat_levels <- imap(
120133
cats,
121134
~{
122-
p <- dplyr::group_by(data, !! rlang::sym(.y))
135+
p <- dplyr::group_by(data, !! rlang::sym(.y))
123136
p <- dplyr::summarise(p)
124137
dplyr::pull(p)
125138
}
126-
)
139+
)
127140
numeric_pred <- length(f_term_labels) - length(cat_levels)
128-
129-
141+
142+
130143
if(length(cat_levels) > 0){
131144
n_dummies <- purrr::map_dbl(cat_levels, ~length(.x) - 1)
132145
n_dummies <- sum(n_dummies)
@@ -136,27 +149,27 @@ get_descr_spark <- function(formula, data) {
136149
factor_pred <- 0
137150
all_preds <- numeric_pred
138151
}
139-
152+
140153
out_cats <- classes[icats]
141154
out_cats <- out_cats[names(out_cats) == y_col]
142-
155+
143156
outs <- purrr::imap(
144157
out_cats,
145158
~{
146-
p <- dplyr::group_by(data, !! sym(.y))
147-
p <- dplyr::tally(p)
159+
p <- dplyr::group_by(data, !! sym(.y))
160+
p <- dplyr::tally(p)
148161
dplyr::collect(p)
149162
}
150-
)
151-
163+
)
164+
152165
if(length(outs) > 0){
153166
outs <- outs[[1]]
154167
y_vals <- purrr::as_vector(outs[,2])
155168
names(y_vals) <- purrr::as_vector(outs[,1])
156169
y_vals <- y_vals[order(names(y_vals))]
157170
y_vals <- as.table(y_vals)
158171
} else y_vals <- NA
159-
172+
160173
list(
161174
cols = length(f_term_labels),
162175
preds = all_preds,
@@ -170,15 +183,15 @@ get_descr_xy <- function(x, y) {
170183
if(is.factor(y)) {
171184
n_levs <- table(y, dnn = NULL)
172185
} else n_levs <- NA
173-
186+
174187
n_cols <- ncol(x)
175188
n_preds <- ncol(x)
176189
n_obs <- nrow(x)
177190
n_facts <- if(is.data.frame(x))
178191
sum(vapply(x, is.factor, logical(1)))
179192
else
180193
sum(apply(x, 2, is.factor)) # would this always be zero?
181-
194+
182195
list(
183196
cols = n_cols,
184197
preds = n_preds,

R/fit.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ fit.model_spec <-
103103
cl <- match.call(expand.dots = TRUE)
104104
# Create an environment with the evaluated argument objects. This will be
105105
# used when a model call is made later.
106-
eval_env <- rlang::env()
106+
107+
eval_env <- rlang::new_environment(parent = rlang::base_env())
107108
eval_env$data <- data
108109
eval_env$formula <- formula
109110
fit_interface <-
@@ -181,6 +182,7 @@ fit_xy.model_spec <-
181182
control = fit_control(),
182183
...
183184
) {
185+
184186
cl <- match.call(expand.dots = TRUE)
185187
eval_env <- rlang::env()
186188
eval_env$x <- x

R/fit_helpers.R

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,31 @@ form_form <-
1515

1616
object <- check_mode(object, y_levels)
1717

18+
# check to see if there are any `expr` in the arguments then
19+
# run a function that evaluates the data and subs in the
20+
# values of the expressions. we would have to evaluate the
21+
# formula (perhaps with and without dummy variables) to get
22+
# the appropriate number of columns. (`..vars..` vs `..cols..`)
23+
# Perhaps use `convert_form_to_xy_fit` here to get the results.
24+
25+
if (make_descr(object)) {
26+
data_stats <- get_descr_form(env$formula, env$data)
27+
28+
object$args <- purrr::map(object$args, ~{
29+
30+
.x_env <- rlang::quo_get_env(.x)
31+
32+
if(identical(.x_env, rlang::empty_env())) {
33+
.x
34+
} else {
35+
.x_new_env <- rlang::env_bury(.x_env, !!! data_stats)
36+
rlang::quo_set_env(.x, .x_new_env)
37+
}
38+
39+
})
40+
41+
}
42+
1843
# sub in arguments to actual syntax for corresponding engine
1944
object <- translate(object, engine = object$engine)
2045

@@ -28,22 +53,6 @@ form_form <-
2853
}
2954
fit_args$formula <- quote(formula)
3055

31-
# check to see if there are any `expr` in the arguments then
32-
# run a function that evaluates the data and subs in the
33-
# values of the expressions. we would have to evaluate the
34-
# formula (perhaps with and without dummy variables) to get
35-
# the appropriate number of columns. (`..vars..` vs `..cols..`)
36-
# Perhaps use `convert_form_to_xy_fit` here to get the results.
37-
38-
if (make_descr(object)) {
39-
data_stats <- get_descr_form(env$formula, env$data)
40-
env$n_obs <- data_stats$obs
41-
env$n_cols <- data_stats$cols
42-
env$n_preds <- data_stats$preds
43-
env$n_levs <- data_stats$levs
44-
env$n_facts <- data_stats$facts
45-
}
46-
4756
fit_call <- make_call(
4857
fun = object$method$fit$func["fun"],
4958
ns = object$method$fit$func["pkg"],

R/misc.R

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,12 @@ model_printer <- function(x, ...) {
5656
non_null_args <- x$args[!vapply(x$args, null_value, lgl(1))]
5757
if (length(non_null_args) > 0) {
5858
cat("Main Arguments:\n")
59+
non_null_args <- map(non_null_args, convert_arg)
5960
cat(print_arg_list(non_null_args), "\n", sep = "")
6061
}
6162
if (length(x$others) > 0) {
6263
cat("Engine-Specific Arguments:\n")
64+
x$others <- map(x$others, convert_arg)
6365
cat(print_arg_list(x$others), "\n", sep = "")
6466
}
6567
if (!is.null(x$engine)) {
@@ -95,6 +97,8 @@ is_missing_arg <- function(x)
9597
#' @keywords internal
9698
#' @export
9799
show_call <- function(object) {
100+
object$method$fit$args <-
101+
map(object$method$fit$args, convert_arg)
98102
if (
99103
is.null(object$method$fit$func["pkg"]) ||
100104
is.na(object$method$fit$func["pkg"])
@@ -109,8 +113,17 @@ show_call <- function(object) {
109113
res
110114
}
111115

116+
convert_arg <- function(x) {
117+
if (is_quosure(x))
118+
quo_get_expr(x)
119+
else
120+
x
121+
}
122+
112123
make_call <- function(fun, ns, args, ...) {
113124

125+
#args <- map(args, convert_arg)
126+
114127
# remove any null or placeholders (`missing_args`) that remain
115128
discard <-
116129
vapply(args, function(x)

R/rand_forest.R

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,12 @@
103103

104104
rand_forest <-
105105
function(mode = "unknown",
106-
...,
107-
mtry = NULL, trees = NULL, min_n = NULL,
108-
others = list()) {
109-
check_empty_ellipse(...)
106+
mtry = NULL, trees = NULL, min_n = NULL, ...) {
107+
108+
others <- enquos(...)
109+
mtry <- enquo(mtry)
110+
trees <- enquo(trees)
111+
min_n <- enquo(min_n)
110112

111113
## TODO: make a utility function here
112114
if (!(mode %in% rand_forest_modes))

0 commit comments

Comments
 (0)