From 37c0cfab049c6479d728a6bf6a99255fb1a41bfc Mon Sep 17 00:00:00 2001 From: Hadley Wickham Date: Thu, 21 Jan 2021 14:22:33 -0600 Subject: [PATCH] Radically improved across() implementation Now handles `across()` using NSE approach that's more in keeping with the keeping with the rest of dbplyr's translation, so that functions without native translation are passed along. It also gains support for `.fns = NULL` and for translating functions. Fixes #525. Fixes #534. Fixes #554. --- NEWS.md | 4 ++ R/partial-eval.R | 65 +++++++++++++++++++++------ tests/testthat/_snaps/partial-eval.md | 54 ++++++++++++++++++---- tests/testthat/test-partial-eval.R | 48 ++++++++++++++++---- 4 files changed, 140 insertions(+), 31 deletions(-) diff --git a/NEWS.md b/NEWS.md index d13ae79d8..2a33c02f7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,9 @@ # dbplyr (development version) +* `across()` implementation has been rewritten to support more inputs: + it now translates formulas (#525), works with SQL functions that don't have + R translations (#534), and work with `NULL` (#554) + * `pull()` no longer `select()`s the result when there's already only one variable (#562). diff --git a/R/partial-eval.R b/R/partial-eval.R index 4d768c731..b015cd9af 100644 --- a/R/partial-eval.R +++ b/R/partial-eval.R @@ -176,25 +176,14 @@ partial_eval_across <- function(call, vars, env) { tbl <- as_tibble(rep_named(vars, list(logical()))) cols <- syms(vars)[tidyselect::eval_select(call$.cols, tbl, allow_rename = FALSE)] - .fns <- eval(call$.fns, env) - - if (is.function(.fns)) { - .fns <- find_fun(.fns) - } else if (is.list(.fns)) { - .fns <- purrr::map_chr(.fns, find_fun) - } else if (is.character(.fns)) { - # as is - } else { - abort("Unsupported `.fns` for dbplyr::across()") - } - funs <- set_names(syms(.fns), .fns) + funs <- across_funs(call$.fns, env) # Generate grid of expressions - out <- vector("list", length(cols) * length(.fns)) + out <- vector("list", length(cols) * length(funs)) k <- 1 for (i in seq_along(cols)) { for (j in seq_along(funs)) { - out[[k]] <- expr((!!funs[[j]])(!!cols[[i]], !!!call$...)) + out[[k]] <- exec(funs[[j]], cols[[i]], !!!call$...) k <- k + 1 } } @@ -204,6 +193,54 @@ partial_eval_across <- function(call, vars, env) { out } +across_funs <- function(funs, env = caller_env()) { + if (is.null(funs)) { + list(function(x, ...) x) + } else if (is_symbol(funs)) { + set_names(list(across_fun(funs, env)), as.character(funs)) + } else if (is.character(funs)) { + names(funs)[names2(funs) == ""] <- funs + lapply(funs, across_fun, env) + } else if (is_call(funs, "~")) { + set_names(list(across_fun(funs, env)), expr_name(f_rhs(funs))) + } else if (is_call(funs, "list")) { + args <- rlang::exprs_auto_name(funs[-1]) + lapply(args, across_fun, env) + } else if (!is.null(env)) { + # Try evaluating once, just in case + funs <- eval(funs, env) + across_funs(funs, NULL) + } else { + abort("`.fns` argument to dbplyr::across() must be a NULL, a function name, formula, or list") + } +} + +across_fun <- function(fun, env) { + if (is_symbol(fun) || is_string(fun)) { + function(x, ...) call2(fun, x, ...) + } else if (is_call(fun, "~")) { + fun <- across_formula_fn(f_rhs(fun)) + function(x, ...) expr_interp(fun, child_env(emptyenv(), .x = x)) + } else { + abort(c( + ".fns argument to dbplyr::across() contain a function name or a formula", + x = paste0("Problem with ", expr_deparse(fun)) + )) + } +} + + +across_formula_fn <- function(x) { + if (is_symbol(x, ".") || is_symbol(x, ".x")) { + quote(!!.x) + } else if (is_call(x)) { + x[-1] <- lapply(x[-1], across_formula_fn) + x + } else { + x + } +} + across_names <- function(cols, funs, names = NULL, env = parent.frame()) { if (length(funs) == 1) { names <- names %||% "{.col}" diff --git a/tests/testthat/_snaps/partial-eval.md b/tests/testthat/_snaps/partial-eval.md index 5dcf0d5e2..78cee6ec8 100644 --- a/tests/testthat/_snaps/partial-eval.md +++ b/tests/testthat/_snaps/partial-eval.md @@ -1,7 +1,7 @@ -# across() translated to individual components +# across() translates character vectors Code - lf %>% summarise(across(everything(), "log")) + lf %>% summarise(across(a:b, "log")) Output SELECT LN(`a`) AS `a`, LN(`b`) AS `b` @@ -10,16 +10,25 @@ --- Code - lf %>% summarise(across(everything(), log)) + lf %>% summarise(across(a:b, "log", base = 2)) Output - SELECT LN(`a`) AS `a`, LN(`b`) AS `b` + SELECT LOG(2.0, `a`) AS `a`, LOG(2.0, `b`) AS `b` FROM `df` --- Code - lf %>% summarise(across(everything(), list(log))) + lf %>% summarise(across(a, c("log", "exp"))) + Output + + SELECT LN(`a`) AS `a_log`, EXP(`a`) AS `a_exp` + FROM `df` + +# across() translates functions + + Code + lf %>% summarise(across(a:b, log)) Output SELECT LN(`a`) AS `a`, LN(`b`) AS `b` @@ -28,7 +37,7 @@ --- Code - lf %>% summarise(across(everything(), "log", base = 2)) + lf %>% summarise(across(a:b, log, base = 2)) Output SELECT LOG(2.0, `a`) AS `a`, LOG(2.0, `b`) AS `b` @@ -37,18 +46,45 @@ --- Code - lf %>% summarise(across(everything(), c("log", "exp"))) + lf %>% summarise(across(a:b, list(log, exp))) Output SELECT LN(`a`) AS `a_log`, EXP(`a`) AS `a_exp`, LN(`b`) AS `b_log`, EXP(`b`) AS `b_exp` FROM `df` +# untranslatable functions are preserved + + Code + lf %>% summarise(across(a:b, SQL_LOG)) + Output + + SELECT SQL_LOG(`a`) AS `a`, SQL_LOG(`b`) AS `b` + FROM `df` + +# across() translates formulas + + Code + lf %>% summarise(across(a:b, ~log(.x, 2))) + Output + + SELECT LOG(2.0, `a`) AS `a`, LOG(2.0, `b`) AS `b` + FROM `df` + --- Code - lf %>% summarise(across(everything(), c("log", "exp"), .names = "{.fn}_{.col}")) + lf %>% summarise(across(a:b, list(~log(.x, 2)))) + Output + + SELECT LOG(2.0, `a`) AS `a`, LOG(2.0, `b`) AS `b` + FROM `df` + +# across() translates NULL + + Code + lf %>% mutate(across(a:b)) Output - SELECT LN(`a`) AS `log_a`, EXP(`a`) AS `exp_a`, LN(`b`) AS `log_b`, EXP(`b`) AS `exp_b` + SELECT `a`, `b` FROM `df` diff --git a/tests/testthat/test-partial-eval.R b/tests/testthat/test-partial-eval.R index e5958a711..fd856d868 100644 --- a/tests/testthat/test-partial-eval.R +++ b/tests/testthat/test-partial-eval.R @@ -43,16 +43,48 @@ test_that("fails with multi-classes", { }) # across() ---------------------------------------------------------------- +# test partial_eval_across() indirectly via SQL generation -test_that("across() translated to individual components", { - # test partial_eval_across() indirectly via SQL generation +test_that("across() translates character vectors", { lf <- lazy_frame(a = 1, b = 2) - expect_snapshot(lf %>% summarise(across(everything(), "log"))) - expect_snapshot(lf %>% summarise(across(everything(), log))) - expect_snapshot(lf %>% summarise(across(everything(), list(log)))) + expect_snapshot(lf %>% summarise(across(a:b, "log"))) + expect_snapshot(lf %>% summarise(across(a:b, "log", base = 2))) - expect_snapshot(lf %>% summarise(across(everything(), "log", base = 2))) + expect_snapshot(lf %>% summarise(across(a, c("log", "exp")))) - expect_snapshot(lf %>% summarise(across(everything(), c("log", "exp")))) - expect_snapshot(lf %>% summarise(across(everything(), c("log", "exp"), .names = "{.fn}_{.col}"))) + out <- lf %>% summarise(across(a:b, c(x = "log", y = "exp"))) + expect_equal(colnames(out), c("a_x", "a_y", "b_x", "b_y")) +}) + +test_that("across() translates functions", { + lf <- lazy_frame(a = 1, b = 2) + expect_snapshot(lf %>% summarise(across(a:b, log))) + expect_snapshot(lf %>% summarise(across(a:b, log, base = 2))) + + expect_snapshot(lf %>% summarise(across(a:b, list(log, exp)))) + + out <- lf %>% summarise(across(a:b, list(x = log, y = exp))) + expect_equal(colnames(out), c("a_x", "a_y", "b_x", "b_y")) +}) + +test_that("untranslatable functions are preserved", { + lf <- lazy_frame(a = 1, b = 2) + expect_snapshot(lf %>% summarise(across(a:b, SQL_LOG))) +}) + +test_that("across() translates formulas", { + lf <- lazy_frame(a = 1, b = 2) + expect_snapshot(lf %>% summarise(across(a:b, ~ log(.x, 2)))) + expect_snapshot(lf %>% summarise(across(a:b, list(~ log(.x, 2))))) +}) + +test_that("across() translates NULL", { + lf <- lazy_frame(a = 1, b = 2) + expect_snapshot(lf %>% mutate(across(a:b))) +}) + +test_that("can control names", { + lf <- lazy_frame(a = 1, b = 2) + out <- lf %>% summarise(across(a:b, c("log", "exp"), .names = "{.fn}_{.col}")) + expect_equal(colnames(out), c("log_a", "exp_a", "log_b", "exp_b")) })