Descriptor refactor touching R/descriptors.R, R/fit.R and R/fit_helpers.R.

Summary of the change set:
  * get_descr_df(): the returned closures are renamed with a leading dot
    (.n_cols, .n_preds, .n_obs, .n_levs, .n_facts) and three accessors are
    added (.dat, .x, .y).
  * get_descr_spark() and get_descr_xy(): the returned list now holds
    zero-argument functions under the same dotted names instead of plain
    values (cols, preds, obs, levs, facts). get_descr_xy() also gains
    .dat, .x and .y; for Spark these are still commented out.
  * fit_xy.model_spec(): the evaluation environment is created with
    rlang::new_environment(parent = rlang::base_env()) instead of rlang::env().
  * xy_xy(): the descriptor list is buried into the environment of every
    argument quosure (skipping quosures whose environment is the empty
    environment) before translate() runs; the later env$n_* assignments
    are removed.

Review fixes folded into the added lines (hunk line counts are unchanged):
  * get_descr_spark(): the helper was defined as .n_pred but returned as
    .n_preds, which would fail with "object '.n_preds' not found".
  * get_descr_xy(): the non-factor branch assigned n_levs while the list
    returns .n_levs, so any non-factor y would fail the same way.

NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch. Hunk line counts were checked
against the @@ headers, but leading whitespace and the exact position of
blank context lines are assumed (2-space R indentation, 3 lines of context).
Confirm against the target files, or apply with --ignore-whitespace.

diff --git a/R/descriptors.R b/R/descriptors.R
index 81c70623e..ebd495a37 100644
--- a/R/descriptors.R
+++ b/R/descriptors.R
@@ -70,33 +70,48 @@ get_descr_df <- function(formula, data) {
   tmp_dat <- convert_form_to_xy_fit(formula, data, indicators = FALSE)
 
   if(is.factor(tmp_dat$y)) {
-    n_levs <- function() {
+    .n_levs <- function() {
       table(tmp_dat$y, dnn = NULL)
     }
-  } else n_levs <- function() { NA }
+  } else .n_levs <- function() { NA }
 
-  n_cols <- function() {
+  .n_cols <- function() {
     ncol(tmp_dat$x)
   }
 
-  n_preds <- function() {
+  .n_preds <- function() {
     ncol(convert_form_to_xy_fit(formula, data, indicators = TRUE)$x)
   }
 
-  n_obs <- function() {
+  .n_obs <- function() {
     nrow(data)
   }
 
-  n_facts <- function() {
+  .n_facts <- function() {
     sum(vapply(tmp_dat$x, is.factor, logical(1)))
   }
 
+  .dat <- function() {
+    data
+  }
+
+  .x <- function() {
+    tmp_dat$x
+  }
+
+  .y <- function() {
+    tmp_dat$y
+  }
+
   list(
-    n_cols = n_cols,
-    n_preds = n_preds,
-    n_obs = n_obs,
-    n_levs = n_levs,
-    n_facts = n_facts
+    .n_cols = .n_cols,
+    .n_preds = .n_preds,
+    .n_obs = .n_obs,
+    .n_levs = .n_levs,
+    .n_facts = .n_facts,
+    .dat = .dat,
+    .x = .x,
+    .y = .y
   )
 }
 
@@ -170,34 +185,78 @@ get_descr_spark <- function(formula, data) {
     y_vals <- as.table(y_vals)
   } else y_vals <- NA
 
+  obs <- dplyr::tally(data) %>% dplyr::pull()
+
+  .n_cols <- function() length(f_term_labels)
+  .n_preds <- function() all_preds
+  .n_obs <- function() obs
+  .n_levs <- function() y_vals
+  .n_facts <- function() factor_pred
+
+  # still need .x(), .y(), .dat() ?
+
   list(
-    cols = length(f_term_labels),
-    preds = all_preds,
-    obs = dplyr::tally(data) %>% dplyr::pull(),
-    levs = y_vals,
-    facts = factor_pred
+    .n_cols = .n_cols,
+    .n_preds = .n_preds,
+    .n_obs = .n_obs,
+    .n_levs = .n_levs,
+    .n_facts = .n_facts #,
+    # .dat = .dat,
+    # .x = .x,
+    # .y = .y
   )
 }
 
 get_descr_xy <- function(x, y) {
+
   if(is.factor(y)) {
-    n_levs <- table(y, dnn = NULL)
-  } else n_levs <- NA
-
-  n_cols <- ncol(x)
-  n_preds <- ncol(x)
-  n_obs <- nrow(x)
-  n_facts <- if(is.data.frame(x))
-    sum(vapply(x, is.factor, logical(1)))
-  else
-    sum(apply(x, 2, is.factor)) # would this always be zero?
+    .n_levs <- function() {
+      table(y, dnn = NULL)
+    }
+  } else .n_levs <- function() { NA }
+
+  .n_cols <- function() {
+    ncol(x)
+  }
+
+  .n_preds <- function() {
+    ncol(x)
+  }
+
+  .n_obs <- function() {
+    nrow(x)
+  }
+
+  .n_facts <- function() {
+    if(is.data.frame(x))
+      sum(vapply(x, is.factor, logical(1)))
+    else
+      sum(apply(x, 2, is.factor)) # would this always be zero?
+  }
+
+  .dat <- function() {
+    x <- as.data.frame(x)
+    x[[".y"]] <- y
+    x
+  }
+
+  .x <- function() {
+    x
+  }
+
+  .y <- function() {
+    y
+  }
 
   list(
-    cols = n_cols,
-    preds = n_preds,
-    obs = n_obs,
-    levs = n_levs,
-    facts = n_facts
+    .n_cols = .n_cols,
+    .n_preds = .n_preds,
+    .n_obs = .n_obs,
+    .n_levs = .n_levs,
+    .n_facts = .n_facts,
+    .dat = .dat,
+    .x = .x,
+    .y = .y
   )
 }
 
diff --git a/R/fit.R b/R/fit.R
index 3a5349b19..a1351e0ed 100644
--- a/R/fit.R
+++ b/R/fit.R
@@ -184,11 +184,10 @@ fit_xy.model_spec <-
   ) {
 
     cl <- match.call(expand.dots = TRUE)
-    eval_env <- rlang::env()
+    eval_env <- rlang::new_environment(parent = rlang::base_env())
     eval_env$x <- x
     eval_env$y <- y
-    fit_interface <-
-      check_xy_interface(eval_env$x, eval_env$y, cl, object)
+    fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object)
     object$engine <- engine
     object <- check_engine(object)
 
diff --git a/R/fit_helpers.R b/R/fit_helpers.R
index 2f3d140d5..947bbbf69 100644
--- a/R/fit_helpers.R
+++ b/R/fit_helpers.R
@@ -15,12 +15,8 @@ form_form <-
 
     object <- check_mode(object, y_levels)
 
-    # check to see of there are any `expr` in the arguments then
-    # run a function that evaluates the data and subs in the
-    # values of the expressions. we would have to evaluate the
-    # formula (perhaps with and without dummy variables) to get
-    # the appropraite number of columns. (`..vars..` vs `..cols..`)
-    # Perhaps use `convert_form_to_xy_fit` here to get the results.
+    # embed descriptor functions in the quosure environments
+    # for each of the args provided
 
     if (make_descr(object)) {
       data_stats <- get_descr_form(env$formula, env$data)
@@ -83,6 +79,24 @@ xy_xy <- function(object, env, control, target = "none", ...) {
 
   object <- check_mode(object, levels(env$y))
 
+  if (make_descr(object)) {
+    data_stats <- get_descr_xy(env$x, env$y)
+
+    object$args <- purrr::map(object$args, ~{
+
+      .x_env <- rlang::quo_get_env(.x)
+
+      if(identical(.x_env, rlang::empty_env())) {
+        .x
+      } else {
+        .x_new_env <- rlang::env_bury(.x_env, !!! data_stats)
+        rlang::quo_set_env(.x, .x_new_env)
+      }
+
+    })
+
+  }
+
   # sub in arguments to actual syntax for corresponding engine
   object <- translate(object, engine = object$engine)
 
@@ -96,15 +110,6 @@ xy_xy <- function(object, env, control, target = "none", ...) {
     stop("Invalid data type target: ", target)
   )
 
-  if (make_descr(object)) {
-    data_stats <- get_descr_xy(env$x, env$y)
-    env$n_obs <- data_stats$obs
-    env$n_cols <- data_stats$cols
-    env$n_preds <- data_stats$preds
-    env$n_levs <- data_stats$levs
-    env$n_facts <- data_stats$facts
-  }
-
   fit_call <- make_call(
     fun = object$method$fit$func["fun"],
     ns = object$method$fit$func["pkg"],
@@ -126,6 +131,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
 
 
 form_xy <- function(object, control, env, target = "none", ...) {
+
   data_obj <- convert_form_to_xy_fit(
     formula = env$formula,
     data = env$data,