Skip to content

Commit 53becc5

Browse files
authored
Merge pull request #90 from topepo/quosure-passthrough-davis
Quosure passthrough davis
2 parents 50a6737 + e6078e3 commit 53becc5

File tree

3 files changed

+113
-49
lines changed

3 files changed

+113
-49
lines changed

R/descriptors.R

Lines changed: 90 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -70,33 +70,48 @@ get_descr_df <- function(formula, data) {
7070
tmp_dat <- convert_form_to_xy_fit(formula, data, indicators = FALSE)
7171

7272
if(is.factor(tmp_dat$y)) {
73-
n_levs <- function() {
73+
.n_levs <- function() {
7474
table(tmp_dat$y, dnn = NULL)
7575
}
76-
} else n_levs <- function() { NA }
76+
} else .n_levs <- function() { NA }
7777

78-
n_cols <- function() {
78+
.n_cols <- function() {
7979
ncol(tmp_dat$x)
8080
}
8181

82-
n_preds <- function() {
82+
.n_preds <- function() {
8383
ncol(convert_form_to_xy_fit(formula, data, indicators = TRUE)$x)
8484
}
8585

86-
n_obs <- function() {
86+
.n_obs <- function() {
8787
nrow(data)
8888
}
8989

90-
n_facts <- function() {
90+
.n_facts <- function() {
9191
sum(vapply(tmp_dat$x, is.factor, logical(1)))
9292
}
9393

94+
.dat <- function() {
95+
data
96+
}
97+
98+
.x <- function() {
99+
tmp_dat$x
100+
}
101+
102+
.y <- function() {
103+
tmp_dat$y
104+
}
105+
94106
list(
95-
n_cols = n_cols,
96-
n_preds = n_preds,
97-
n_obs = n_obs,
98-
n_levs = n_levs,
99-
n_facts = n_facts
107+
.n_cols = .n_cols,
108+
.n_preds = .n_preds,
109+
.n_obs = .n_obs,
110+
.n_levs = .n_levs,
111+
.n_facts = .n_facts,
112+
.dat = .dat,
113+
.x = .x,
114+
.y = .y
100115
)
101116
}
102117

@@ -170,34 +185,78 @@ get_descr_spark <- function(formula, data) {
170185
y_vals <- as.table(y_vals)
171186
} else y_vals <- NA
172187

188+
obs <- dplyr::tally(data) %>% dplyr::pull()
189+
190+
.n_cols <- function() length(f_term_labels)
191+
.n_pred <- function() all_preds
192+
.n_obs <- function() obs
193+
.n_levs <- function() y_vals
194+
.n_facts <- function() factor_pred
195+
196+
# still need .x(), .y(), .dat() ?
197+
173198
list(
174-
cols = length(f_term_labels),
175-
preds = all_preds,
176-
obs = dplyr::tally(data) %>% dplyr::pull(),
177-
levs = y_vals,
178-
facts = factor_pred
199+
.n_cols = .n_cols,
200+
.n_preds = .n_preds,
201+
.n_obs = .n_obs,
202+
.n_levs = .n_levs,
203+
.n_facts = .n_facts #,
204+
# .dat = .dat,
205+
# .x = .x,
206+
# .y = .y
179207
)
180208
}
181209

182210
get_descr_xy <- function(x, y) {
211+
183212
if(is.factor(y)) {
184-
n_levs <- table(y, dnn = NULL)
185-
} else n_levs <- NA
186-
187-
n_cols <- ncol(x)
188-
n_preds <- ncol(x)
189-
n_obs <- nrow(x)
190-
n_facts <- if(is.data.frame(x))
191-
sum(vapply(x, is.factor, logical(1)))
192-
else
193-
sum(apply(x, 2, is.factor)) # would this always be zero?
213+
.n_levs <- function() {
214+
table(y, dnn = NULL)
215+
}
216+
} else n_levs <- function() { NA }
217+
218+
.n_cols <- function() {
219+
ncol(x)
220+
}
221+
222+
.n_preds <- function() {
223+
ncol(x)
224+
}
225+
226+
.n_obs <- function() {
227+
nrow(x)
228+
}
229+
230+
.n_facts <- function() {
231+
if(is.data.frame(x))
232+
sum(vapply(x, is.factor, logical(1)))
233+
else
234+
sum(apply(x, 2, is.factor)) # would this always be zero?
235+
}
236+
237+
.dat <- function() {
238+
x <- as.data.frame(x)
239+
x[[".y"]] <- y
240+
x
241+
}
242+
243+
.x <- function() {
244+
x
245+
}
246+
247+
.y <- function() {
248+
y
249+
}
194250

195251
list(
196-
cols = n_cols,
197-
preds = n_preds,
198-
obs = n_obs,
199-
levs = n_levs,
200-
facts = n_facts
252+
.n_cols = .n_cols,
253+
.n_preds = .n_preds,
254+
.n_obs = .n_obs,
255+
.n_levs = .n_levs,
256+
.n_facts = .n_facts,
257+
.dat = .dat,
258+
.x = .x,
259+
.y = .y
201260
)
202261
}
203262

R/fit.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,10 @@ fit_xy.model_spec <-
184184
) {
185185

186186
cl <- match.call(expand.dots = TRUE)
187-
eval_env <- rlang::env()
187+
eval_env <- rlang::new_environment(parent = rlang::base_env())
188188
eval_env$x <- x
189189
eval_env$y <- y
190-
fit_interface <-
191-
check_xy_interface(eval_env$x, eval_env$y, cl, object)
190+
fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object)
192191
object$engine <- engine
193192
object <- check_engine(object)
194193

R/fit_helpers.R

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,8 @@ form_form <-
1515

1616
object <- check_mode(object, y_levels)
1717

18-
# check to see of there are any `expr` in the arguments then
19-
# run a function that evaluates the data and subs in the
20-
# values of the expressions. we would have to evaluate the
21-
# formula (perhaps with and without dummy variables) to get
22-
# the appropraite number of columns. (`..vars..` vs `..cols..`)
23-
# Perhaps use `convert_form_to_xy_fit` here to get the results.
18+
# embed descriptor functions in the quosure environments
19+
# for each of the args provided
2420

2521
if (make_descr(object)) {
2622
data_stats <- get_descr_form(env$formula, env$data)
@@ -83,6 +79,24 @@ xy_xy <- function(object, env, control, target = "none", ...) {
8379

8480
object <- check_mode(object, levels(env$y))
8581

82+
if (make_descr(object)) {
83+
data_stats <- get_descr_xy(env$x, env$y)
84+
85+
object$args <- purrr::map(object$args, ~{
86+
87+
.x_env <- rlang::quo_get_env(.x)
88+
89+
if(identical(.x_env, rlang::empty_env())) {
90+
.x
91+
} else {
92+
.x_new_env <- rlang::env_bury(.x_env, !!! data_stats)
93+
rlang::quo_set_env(.x, .x_new_env)
94+
}
95+
96+
})
97+
98+
}
99+
86100
# sub in arguments to actual syntax for corresponding engine
87101
object <- translate(object, engine = object$engine)
88102

@@ -96,15 +110,6 @@ xy_xy <- function(object, env, control, target = "none", ...) {
96110
stop("Invalid data type target: ", target)
97111
)
98112

99-
if (make_descr(object)) {
100-
data_stats <- get_descr_xy(env$x, env$y)
101-
env$n_obs <- data_stats$obs
102-
env$n_cols <- data_stats$cols
103-
env$n_preds <- data_stats$preds
104-
env$n_levs <- data_stats$levs
105-
env$n_facts <- data_stats$facts
106-
}
107-
108113
fit_call <- make_call(
109114
fun = object$method$fit$func["fun"],
110115
ns = object$method$fit$func["pkg"],
@@ -126,6 +131,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
126131

127132
form_xy <- function(object, control, env,
128133
target = "none", ...) {
134+
129135
data_obj <- convert_form_to_xy_fit(
130136
formula = env$formula,
131137
data = env$data,

0 commit comments

Comments
 (0)