Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 90 additions & 31 deletions R/descriptors.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,33 +70,48 @@ get_descr_df <- function(formula, data) {
tmp_dat <- convert_form_to_xy_fit(formula, data, indicators = FALSE)

if(is.factor(tmp_dat$y)) {
n_levs <- function() {
.n_levs <- function() {
table(tmp_dat$y, dnn = NULL)
}
} else n_levs <- function() { NA }
} else .n_levs <- function() { NA }

n_cols <- function() {
.n_cols <- function() {
ncol(tmp_dat$x)
}

n_preds <- function() {
.n_preds <- function() {
ncol(convert_form_to_xy_fit(formula, data, indicators = TRUE)$x)
}

n_obs <- function() {
.n_obs <- function() {
nrow(data)
}

n_facts <- function() {
.n_facts <- function() {
sum(vapply(tmp_dat$x, is.factor, logical(1)))
}

.dat <- function() {
data
}

.x <- function() {
tmp_dat$x
}

.y <- function() {
tmp_dat$y
}

list(
n_cols = n_cols,
n_preds = n_preds,
n_obs = n_obs,
n_levs = n_levs,
n_facts = n_facts
.n_cols = .n_cols,
.n_preds = .n_preds,
.n_obs = .n_obs,
.n_levs = .n_levs,
.n_facts = .n_facts,
.dat = .dat,
.x = .x,
.y = .y
)
}

Expand Down Expand Up @@ -170,34 +185,78 @@ get_descr_spark <- function(formula, data) {
y_vals <- as.table(y_vals)
} else y_vals <- NA

obs <- dplyr::tally(data) %>% dplyr::pull()

.n_cols <- function() length(f_term_labels)
# Descriptor for the number of predictors: returns `all_preds`, computed
# earlier in this function. The name must be `.n_preds` (plural) because the
# list returned below is built with `.n_preds = .n_preds`.
.n_preds <- function() all_preds
.n_obs <- function() obs
.n_levs <- function() y_vals
.n_facts <- function() factor_pred

# still need .x(), .y(), .dat() ?

list(
cols = length(f_term_labels),
preds = all_preds,
obs = dplyr::tally(data) %>% dplyr::pull(),
levs = y_vals,
facts = factor_pred
.n_cols = .n_cols,
.n_preds = .n_preds,
.n_obs = .n_obs,
.n_levs = .n_levs,
.n_facts = .n_facts #,
# .dat = .dat,
# .x = .x,
# .y = .y
)
}

# Build the data descriptor functions for a model fit through the x/y
# interface.
#
# Arguments:
#   x: the predictors; a data frame, or another object that `ncol()`,
#      `nrow()` and `as.data.frame()` accept (e.g. a matrix).
#   y: the outcome. Class counts are only reported when it is a factor.
#
# Returns a named list of zero-argument functions:
#   .n_cols, .n_preds : number of columns of `x` (the same value in this
#                       interface)
#   .n_obs            : number of rows of `x`
#   .n_levs           : table of outcome class counts, or NA when `y` is not
#                       a factor
#   .n_facts          : number of factor columns in `x`
#   .dat              : `x` as a data frame with the outcome added as `.y`
#   .x, .y            : the original `x` and `y`
get_descr_xy <- function(x, y) {

  # Both branches must define `.n_levs`; it is referenced unconditionally in
  # the list returned at the end of this function.
  if (is.factor(y)) {
    .n_levs <- function() {
      table(y, dnn = NULL)
    }
  } else {
    .n_levs <- function() {
      NA
    }
  }

  .n_cols <- function() {
    ncol(x)
  }

  .n_preds <- function() {
    ncol(x)
  }

  .n_obs <- function() {
    nrow(x)
  }

  .n_facts <- function() {
    if (is.data.frame(x)) {
      sum(vapply(x, is.factor, logical(1)))
    } else {
      # NOTE(review): for an atomic matrix this is presumably always zero,
      # since matrix columns cannot be factors -- confirm before simplifying.
      sum(apply(x, 2, is.factor))
    }
  }

  .dat <- function() {
    # Work on a local copy so the captured `x` is left untouched.
    dat <- as.data.frame(x)
    dat[[".y"]] <- y
    dat
  }

  .x <- function() {
    x
  }

  .y <- function() {
    y
  }

  list(
    .n_cols = .n_cols,
    .n_preds = .n_preds,
    .n_obs = .n_obs,
    .n_levs = .n_levs,
    .n_facts = .n_facts,
    .dat = .dat,
    .x = .x,
    .y = .y
  )
}

Expand Down
5 changes: 2 additions & 3 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,10 @@ fit_xy.model_spec <-
) {

cl <- match.call(expand.dots = TRUE)
eval_env <- rlang::env()
eval_env <- rlang::new_environment(parent = rlang::base_env())
eval_env$x <- x
eval_env$y <- y
fit_interface <-
check_xy_interface(eval_env$x, eval_env$y, cl, object)
fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object)
object$engine <- engine
object <- check_engine(object)

Expand Down
36 changes: 21 additions & 15 deletions R/fit_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,8 @@ form_form <-

object <- check_mode(object, y_levels)

# check to see if there are any `expr` in the arguments then
# run a function that evaluates the data and subs in the
# values of the expressions. we would have to evaluate the
# formula (perhaps with and without dummy variables) to get
# the appropriate number of columns. (`..vars..` vs `..cols..`)
# Perhaps use `convert_form_to_xy_fit` here to get the results.
# embed descriptor functions in the quosure environments
# for each of the args provided

if (make_descr(object)) {
data_stats <- get_descr_form(env$formula, env$data)
Expand Down Expand Up @@ -83,6 +79,24 @@ xy_xy <- function(object, env, control, target = "none", ...) {

object <- check_mode(object, levels(env$y))

if (make_descr(object)) {
data_stats <- get_descr_xy(env$x, env$y)

object$args <- purrr::map(object$args, ~{

.x_env <- rlang::quo_get_env(.x)

if(identical(.x_env, rlang::empty_env())) {
.x
} else {
.x_new_env <- rlang::env_bury(.x_env, !!! data_stats)
rlang::quo_set_env(.x, .x_new_env)
}

})

}

# sub in arguments to actual syntax for corresponding engine
object <- translate(object, engine = object$engine)

Expand All @@ -96,15 +110,6 @@ xy_xy <- function(object, env, control, target = "none", ...) {
stop("Invalid data type target: ", target)
)

if (make_descr(object)) {
data_stats <- get_descr_xy(env$x, env$y)
env$n_obs <- data_stats$obs
env$n_cols <- data_stats$cols
env$n_preds <- data_stats$preds
env$n_levs <- data_stats$levs
env$n_facts <- data_stats$facts
}

fit_call <- make_call(
fun = object$method$fit$func["fun"],
ns = object$method$fit$func["pkg"],
Expand All @@ -126,6 +131,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {

form_xy <- function(object, control, env,
target = "none", ...) {

data_obj <- convert_form_to_xy_fit(
formula = env$formula,
data = env$data,
Expand Down