Skip to content

Commit

Permalink
Update R code for new join implementation.
Browse files Browse the repository at this point in the history
Closes #593
  • Loading branch information
hadley committed Sep 12, 2014
1 parent 977c2f2 commit 9cf29da
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 42 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Expand Up @@ -116,7 +116,7 @@
* joins (e.g. `left_join()`, `inner_join()`, `semi_join()`, `anti_join()`)
now allow you to join on different variables in `x` and `y` tables by
supplying a named vector to `by`. For example, `by = c("a" = "b")` joins
`x.a` to `y.b`. (Currently only supported in sql sources)
`x.a` to `y.b`.

* `order_by()` now works in conjunction with window functions in databases
that support them.
Expand Down
32 changes: 9 additions & 23 deletions R/dbi-s3.r
Expand Up @@ -261,37 +261,30 @@ sql_join.DBIConnection <- function(con, x, y, type = "inner", by = NULL, ...) {
stop("Unknown join type:", type, call. = FALSE)
)

by <- by %||% common_by(x, y)
if (!is.null(names(by))) {
by_x <- names(by)
by_y <- unname(by)
} else {
by_x <- by
by_y <- by
}
using <- all(by_x == by_y)
by <- common_by(by, x, y)
using <- all(by$x == by$y)

# Ensure tables have unique names
x_names <- auto_names(x$select)
y_names <- auto_names(y$select)
uniques <- unique_names(x_names, y_names, by_x[by_x == by_y])
uniques <- unique_names(x_names, y_names, by$x[by$x == by$y])

if (is.null(uniques)) {
sel_vars <- c(x_names, y_names)
} else {
x <- update(x, select = setNames(x$select, uniques$x))
y <- update(y, select = setNames(y$select, uniques$y))

by_x <- unname(uniques$x[by_x])
by_y <- unname(uniques$y[by_y])
by$x <- unname(uniques$x[by$x])
by$y <- unname(uniques$y[by$y])

sel_vars <- unique(c(uniques$x, uniques$y))
}

if (using) {
cond <- build_sql("USING ", lapply(by_x, ident), con = con)
cond <- build_sql("USING ", lapply(by$x, ident), con = con)
} else {
on <- sql_vector(paste0(sql_escape_ident(con, by_x), " = ", sql_escape_ident(con, by_y)),
on <- sql_vector(paste0(sql_escape_ident(con, by$x), " = ", sql_escape_ident(con, by$y)),
collapse = " AND ", parens = TRUE)
cond <- build_sql("ON ", on, con = con)
}
Expand All @@ -314,19 +307,12 @@ sql_semi_join <- function(con, x, y, anti = FALSE, by = NULL, ...) {
}
#' @export
sql_semi_join.DBIConnection <- function(con, x, y, anti = FALSE, by = NULL, ...) {
by <- by %||% common_by(x, y)
if (!is.null(names(by))) {
by_x <- names(by)
by_y <- unname(by)
} else {
by_x <- by
by_y <- by
}
by <- common_by(by, x, y)

left <- escape(ident("_LEFT"), con = con)
right <- escape(ident("_RIGHT"), con = con)
on <- sql_vector(paste0(
left, ".", sql_escape_ident(con, by_x), " = ", right, ".", sql_escape_ident(con, by_y)),
left, ".", sql_escape_ident(con, by$x), " = ", right, ".", sql_escape_ident(con, by$y)),
collapse = " AND ", parens = TRUE)

from <- build_sql(
Expand Down
17 changes: 9 additions & 8 deletions R/join-df.r
Expand Up @@ -37,31 +37,32 @@ NULL
#' @export
#' @rdname join.tbl_df
inner_join.tbl_df <- function(x, y, by = NULL, copy = FALSE, ...) {
by <- by %||% common_by(x, y)
by <- common_by(by, x, y)
y <- auto_copy(x, y, copy = copy)
inner_join_impl(x, y, by)

inner_join_impl(x, y, by$x, by$y)
}

#' @export
#' @rdname join.tbl_df
left_join.tbl_df <- function(x, y, by = NULL, copy = FALSE, ...) {
by <- by %||% common_by(x, y)
by <- common_by(by, x, y)
y <- auto_copy(x, y, copy = copy)
left_join_impl(x, y, by)
left_join_impl(x, y, by$x, by$y)
}

#' @export
#' @rdname join.tbl_df
semi_join.tbl_df <- function(x, y, by = NULL, copy = FALSE, ...) {
by <- by %||% common_by(x, y)
by <- common_by(by, x, y)
y <- auto_copy(x, y, copy = copy)
semi_join_impl(x, y, by)
semi_join_impl(x, y, by$x, by$y)
}

#' @export
#' @rdname join.tbl_df
anti_join.tbl_df <- function(x, y, by = NULL, copy = FALSE, ...) {
by <- by %||% common_by(x, y)
by <- common_by(by, x, y)
y <- auto_copy(x, y, copy = copy)
anti_join_impl(x, y, by)
anti_join_impl(x, y, by$x, by$y)
}
13 changes: 8 additions & 5 deletions R/join-dt.r
Expand Up @@ -31,13 +31,16 @@ NULL

join_dt <- function(op) {
template <- substitute(function(x, y, by = NULL, copy = FALSE, ...) {
by <- by %||% common_by(x, y)
by <- common_by(by, x, y)
if (!identical(by$x, by$y)) {
stop("Data table joins must be on same key", call. = FALSE)
}
y <- auto_copy(x, y, copy = copy)

x <- copy(x)
y <- copy(y)
setkeyv(x, by)
setkeyv(y, by)
setkeyv(x, by$x)
setkeyv(y, by$x)
out <- op
grouped_dt(out, groups(x))
})
Expand All @@ -49,11 +52,11 @@ join_dt <- function(op) {

#' @export
#' @rdname join.tbl_dt
inner_join.data.table <- join_dt(merge(x, y, by = by, allow.cartesian = TRUE))
inner_join.data.table <- join_dt(merge(x, y, by = by$x, allow.cartesian = TRUE))

#' @export
#' @rdname join.tbl_dt
left_join.data.table <- join_dt(merge(x, y, by = by, all.x = TRUE, allow.cartesian = TRUE))
left_join.data.table <- join_dt(merge(x, y, by = by$x, all.x = TRUE, allow.cartesian = TRUE))

#' @export
#' @rdname join.tbl_dt
Expand Down
15 changes: 13 additions & 2 deletions R/join.r
Expand Up @@ -77,13 +77,24 @@ anti_join <- function(x, y, by = NULL, copy = FALSE, ...) {
UseMethod("anti_join")
}

common_by <- function(x, y) {
common_by <- function(by = NULL, x, y) {
if (!is.null(by)) {
return(list(
x = names(by) %||% by,
y = unname(by)
))
}

by <- intersect(tbl_vars(x), tbl_vars(y))
if (length(by) == 0) {
stop("No common variables. Please specify `by` param.", call. = FALSE)
}
message("Joining by: ", capture.output(dput(by)))
by

list(
x = by,
y = by
)
}

unique_names <- function(x_names, y_names, by, x_suffix = ".x", y_suffix = ".y") {
Expand Down
2 changes: 1 addition & 1 deletion R/src-sqlite.r
Expand Up @@ -100,7 +100,7 @@ src_sqlite <- function(path, create = FALSE) {
stop("Path does not exist and create = FALSE", call. = FALSE)
}

con <- dbConnect(RSQLite::SQLite(), dbname = path)
con <- dbConnect(RSQLite::SQLite(), path)
RSQLite.extfuns::init_extensions(con)

info <- dbGetInfo(con)
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-joins.r
Expand Up @@ -173,11 +173,11 @@ test_that("indices don't get mixed up when nrow(x) > nrow(y). #365",{
test_that("join functions error on column not found #371", {
expect_error(
left_join(data.frame(x=1:5), data.frame(y=1:5), by="x"),
"cannot join on column 'x'"
"cannot join on columns 'x'"
)
expect_error(
left_join(data.frame(x=1:5), data.frame(y=1:5), by="y"),
"cannot join on column 'y'"
"cannot join on columns 'y'"
)
expect_error(
left_join(data.frame(x=1:5), data.frame(y=1:5)),
Expand Down

0 comments on commit 9cf29da

Please sign in to comment.