Skip to content

Commit

Permalink
Radically improved across() implementation
Browse files Browse the repository at this point in the history
Now handles `across()` using NSE approach that's more in keeping with the keeping with the rest of dbplyr's translation, so that functions without native translation are passed along. It also gains support for `.fns = NULL` and for translating functions.

Fixes #525. Fixes #534. Fixes #554.
  • Loading branch information
hadley committed Jan 21, 2021
1 parent e2c383a commit 37c0cfa
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 31 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
@@ -1,5 +1,9 @@
# dbplyr (development version)

* `across()` implementation has been rewritten to support more inputs:
it now translates formulas (#525), works with SQL functions that don't have
R translations (#534), and work with `NULL` (#554)

* `pull()` no longer `select()`s the result when there's already only
one variable (#562).

Expand Down
65 changes: 51 additions & 14 deletions R/partial-eval.R
Expand Up @@ -176,25 +176,14 @@ partial_eval_across <- function(call, vars, env) {
tbl <- as_tibble(rep_named(vars, list(logical())))
cols <- syms(vars)[tidyselect::eval_select(call$.cols, tbl, allow_rename = FALSE)]

.fns <- eval(call$.fns, env)

if (is.function(.fns)) {
.fns <- find_fun(.fns)
} else if (is.list(.fns)) {
.fns <- purrr::map_chr(.fns, find_fun)
} else if (is.character(.fns)) {
# as is
} else {
abort("Unsupported `.fns` for dbplyr::across()")
}
funs <- set_names(syms(.fns), .fns)
funs <- across_funs(call$.fns, env)

# Generate grid of expressions
out <- vector("list", length(cols) * length(.fns))
out <- vector("list", length(cols) * length(funs))
k <- 1
for (i in seq_along(cols)) {
for (j in seq_along(funs)) {
out[[k]] <- expr((!!funs[[j]])(!!cols[[i]], !!!call$...))
out[[k]] <- exec(funs[[j]], cols[[i]], !!!call$...)
k <- k + 1
}
}
Expand All @@ -204,6 +193,54 @@ partial_eval_across <- function(call, vars, env) {
out
}

across_funs <- function(funs, env = caller_env()) {
if (is.null(funs)) {
list(function(x, ...) x)
} else if (is_symbol(funs)) {
set_names(list(across_fun(funs, env)), as.character(funs))
} else if (is.character(funs)) {
names(funs)[names2(funs) == ""] <- funs
lapply(funs, across_fun, env)
} else if (is_call(funs, "~")) {
set_names(list(across_fun(funs, env)), expr_name(f_rhs(funs)))
} else if (is_call(funs, "list")) {
args <- rlang::exprs_auto_name(funs[-1])
lapply(args, across_fun, env)
} else if (!is.null(env)) {
# Try evaluating once, just in case
funs <- eval(funs, env)
across_funs(funs, NULL)
} else {
abort("`.fns` argument to dbplyr::across() must be a NULL, a function name, formula, or list")
}
}

across_fun <- function(fun, env) {
if (is_symbol(fun) || is_string(fun)) {
function(x, ...) call2(fun, x, ...)
} else if (is_call(fun, "~")) {
fun <- across_formula_fn(f_rhs(fun))
function(x, ...) expr_interp(fun, child_env(emptyenv(), .x = x))
} else {
abort(c(
".fns argument to dbplyr::across() contain a function name or a formula",
x = paste0("Problem with ", expr_deparse(fun))
))
}
}


across_formula_fn <- function(x) {
if (is_symbol(x, ".") || is_symbol(x, ".x")) {
quote(!!.x)
} else if (is_call(x)) {
x[-1] <- lapply(x[-1], across_formula_fn)
x
} else {
x
}
}

across_names <- function(cols, funs, names = NULL, env = parent.frame()) {
if (length(funs) == 1) {
names <- names %||% "{.col}"
Expand Down
54 changes: 45 additions & 9 deletions tests/testthat/_snaps/partial-eval.md
@@ -1,7 +1,7 @@
# across() translated to individual components
# across() translates character vectors

Code
lf %>% summarise(across(everything(), "log"))
lf %>% summarise(across(a:b, "log"))
Output
<SQL>
SELECT LN(`a`) AS `a`, LN(`b`) AS `b`
Expand All @@ -10,16 +10,25 @@
---

Code
lf %>% summarise(across(everything(), log))
lf %>% summarise(across(a:b, "log", base = 2))
Output
<SQL>
SELECT LN(`a`) AS `a`, LN(`b`) AS `b`
SELECT LOG(2.0, `a`) AS `a`, LOG(2.0, `b`) AS `b`
FROM `df`

---

Code
lf %>% summarise(across(everything(), list(log)))
lf %>% summarise(across(a, c("log", "exp")))
Output
<SQL>
SELECT LN(`a`) AS `a_log`, EXP(`a`) AS `a_exp`
FROM `df`

# across() translates functions

Code
lf %>% summarise(across(a:b, log))
Output
<SQL>
SELECT LN(`a`) AS `a`, LN(`b`) AS `b`
Expand All @@ -28,7 +37,7 @@
---

Code
lf %>% summarise(across(everything(), "log", base = 2))
lf %>% summarise(across(a:b, log, base = 2))
Output
<SQL>
SELECT LOG(2.0, `a`) AS `a`, LOG(2.0, `b`) AS `b`
Expand All @@ -37,18 +46,45 @@
---

Code
lf %>% summarise(across(everything(), c("log", "exp")))
lf %>% summarise(across(a:b, list(log, exp)))
Output
<SQL>
SELECT LN(`a`) AS `a_log`, EXP(`a`) AS `a_exp`, LN(`b`) AS `b_log`, EXP(`b`) AS `b_exp`
FROM `df`

# untranslatable functions are preserved

Code
lf %>% summarise(across(a:b, SQL_LOG))
Output
<SQL>
SELECT SQL_LOG(`a`) AS `a`, SQL_LOG(`b`) AS `b`
FROM `df`

# across() translates formulas

Code
lf %>% summarise(across(a:b, ~log(.x, 2)))
Output
<SQL>
SELECT LOG(2.0, `a`) AS `a`, LOG(2.0, `b`) AS `b`
FROM `df`

---

Code
lf %>% summarise(across(everything(), c("log", "exp"), .names = "{.fn}_{.col}"))
lf %>% summarise(across(a:b, list(~log(.x, 2))))
Output
<SQL>
SELECT LOG(2.0, `a`) AS `a`, LOG(2.0, `b`) AS `b`
FROM `df`

# across() translates NULL

Code
lf %>% mutate(across(a:b))
Output
<SQL>
SELECT LN(`a`) AS `log_a`, EXP(`a`) AS `exp_a`, LN(`b`) AS `log_b`, EXP(`b`) AS `exp_b`
SELECT `a`, `b`
FROM `df`

48 changes: 40 additions & 8 deletions tests/testthat/test-partial-eval.R
Expand Up @@ -43,16 +43,48 @@ test_that("fails with multi-classes", {
})

# across() ----------------------------------------------------------------
# test partial_eval_across() indirectly via SQL generation

test_that("across() translated to individual components", {
# test partial_eval_across() indirectly via SQL generation
test_that("across() translates character vectors", {
lf <- lazy_frame(a = 1, b = 2)
expect_snapshot(lf %>% summarise(across(everything(), "log")))
expect_snapshot(lf %>% summarise(across(everything(), log)))
expect_snapshot(lf %>% summarise(across(everything(), list(log))))
expect_snapshot(lf %>% summarise(across(a:b, "log")))
expect_snapshot(lf %>% summarise(across(a:b, "log", base = 2)))

expect_snapshot(lf %>% summarise(across(everything(), "log", base = 2)))
expect_snapshot(lf %>% summarise(across(a, c("log", "exp"))))

expect_snapshot(lf %>% summarise(across(everything(), c("log", "exp"))))
expect_snapshot(lf %>% summarise(across(everything(), c("log", "exp"), .names = "{.fn}_{.col}")))
out <- lf %>% summarise(across(a:b, c(x = "log", y = "exp")))
expect_equal(colnames(out), c("a_x", "a_y", "b_x", "b_y"))
})

test_that("across() translates functions", {
lf <- lazy_frame(a = 1, b = 2)
expect_snapshot(lf %>% summarise(across(a:b, log)))
expect_snapshot(lf %>% summarise(across(a:b, log, base = 2)))

expect_snapshot(lf %>% summarise(across(a:b, list(log, exp))))

out <- lf %>% summarise(across(a:b, list(x = log, y = exp)))
expect_equal(colnames(out), c("a_x", "a_y", "b_x", "b_y"))
})

test_that("untranslatable functions are preserved", {
lf <- lazy_frame(a = 1, b = 2)
expect_snapshot(lf %>% summarise(across(a:b, SQL_LOG)))
})

test_that("across() translates formulas", {
lf <- lazy_frame(a = 1, b = 2)
expect_snapshot(lf %>% summarise(across(a:b, ~ log(.x, 2))))
expect_snapshot(lf %>% summarise(across(a:b, list(~ log(.x, 2)))))
})

test_that("across() translates NULL", {
lf <- lazy_frame(a = 1, b = 2)
expect_snapshot(lf %>% mutate(across(a:b)))
})

test_that("can control names", {
lf <- lazy_frame(a = 1, b = 2)
out <- lf %>% summarise(across(a:b, c("log", "exp"), .names = "{.fn}_{.col}"))
expect_equal(colnames(out), c("log_a", "exp_a", "log_b", "exp_b"))
})

0 comments on commit 37c0cfa

Please sign in to comment.