From 10d1da4472e4c6ff5a905c74fe951e5023ecdb74 Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Sat, 12 May 2018 08:16:57 +0200 Subject: [PATCH] + nest_join(). closes #3570 --- NAMESPACE | 2 ++ NEWS.md | 1 + R/RcppExports.R | 4 +++ R/join.r | 6 +++++ R/tbl-df.r | 25 ++++++++++++++++++ man/join.Rd | 3 +++ man/join.tbl_df.Rd | 4 +++ src/RcppExports.cpp | 19 ++++++++++++++ src/join_exports.cpp | 60 ++++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 124 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index 84c99cb9f9..28b2ab02e1 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -106,6 +106,7 @@ S3method(mutate_,tbl_df) S3method(n_groups,data.frame) S3method(n_groups,grouped_df) S3method(n_groups,rowwise_df) +S3method(nest_join,tbl_df) S3method(print,BoolResult) S3method(print,all_vars) S3method(print,any_vars) @@ -323,6 +324,7 @@ export(n_distinct) export(n_groups) export(na_if) export(near) +export(nest_join) export(nth) export(ntile) export(num_range) diff --git a/NEWS.md b/NEWS.md index 96802b059b..32642d586e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -16,6 +16,7 @@ - new method `rows()` to get a list of row indices for each group (#3489). - new method `group_data()` (#3489). - joins no longer make lazy grouped data (#3566). +- new `nest_join()` function (#3570). # dplyr 0.7.5.9001 diff --git a/R/RcppExports.R b/R/RcppExports.R index 403f4e5ba0..369a9b7cf6 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -126,6 +126,10 @@ inner_join_impl <- function(x, y, by_x, by_y, aux_x, aux_y, na_match) { .Call(`_dplyr_inner_join_impl`, x, y, by_x, by_y, aux_x, aux_y, na_match) } +nest_join_impl <- function(x, y, by_x, by_y, aux_x, aux_y, na_match, yname) { + .Call(`_dplyr_nest_join_impl`, x, y, by_x, by_y, aux_x, aux_y, na_match, yname) +} + left_join_impl <- function(x, y, by_x, by_y, aux_x, aux_y, na_match) { .Call(`_dplyr_left_join_impl`, x, y, by_x, by_y, aux_x, aux_y, na_match) } diff --git a/R/join.r b/R/join.r index 949d984db6..66e4d41b7f 100644 --- a/R/join.r +++ b/R/join.r @@ -120,6 +120,12 @@ semi_join <- function(x, y, by = NULL, copy = FALSE, ...) { UseMethod("semi_join") } +#' @rdname join +#' @export +nest_join <- function(x, y, by = NULL, copy = FALSE, suffix = c(".x", ".y"), ...) { + UseMethod("nest_join") +} + #' @rdname join #' @export anti_join <- function(x, y, by = NULL, copy = FALSE, ...) { diff --git a/R/tbl-df.r b/R/tbl-df.r index 7e61b1df6b..5038f36a4a 100644 --- a/R/tbl-df.r +++ b/R/tbl-df.r @@ -176,6 +176,31 @@ inner_join.tbl_df <- function(x, y, by = NULL, copy = FALSE, reconstruct_join(out, x, vars) } +#' @export +#' @rdname join.tbl_df +nest_join.tbl_df <- function(x, y, by = NULL, copy = FALSE, + suffix = c(".x", ".y"), ..., + na_matches = pkgconfig::get_config("dplyr::na_matches")) { + y_name <- quo_name(enquo(y)) + check_valid_names(tbl_vars(x)) + check_valid_names(tbl_vars(y)) + by <- common_by(by, x, y) + suffix <- check_suffix(suffix) + na_matches <- check_na_matches(na_matches) + + y <- auto_copy(x, y, copy = copy) + + vars <- join_vars(tbl_vars(x), tbl_vars(y), by, suffix) + by_x <- vars$idx$x$by + by_y <- vars$idx$y$by + aux_x <- vars$idx$x$aux + aux_y <- vars$idx$y$aux + + out <- nest_join_impl(x, y, by_x, by_y, aux_x, aux_y, na_matches, y_name) + out +} + + #' @export #' @rdname join.tbl_df left_join.tbl_df <- function(x, y, by = NULL, copy = FALSE, diff --git a/man/join.Rd b/man/join.Rd index c610df5622..725126281c 100644 --- a/man/join.Rd +++ b/man/join.Rd @@ -7,6 +7,7 @@ \alias{right_join} \alias{full_join} \alias{semi_join} +\alias{nest_join} \alias{anti_join} \title{Join two tbls together} \usage{ @@ -20,6 +21,8 @@ full_join(x, y, by = NULL, copy = FALSE, suffix = c(".x", ".y"), ...) semi_join(x, y, by = NULL, copy = FALSE, ...) +nest_join(x, y, by = NULL, copy = FALSE, suffix = c(".x", ".y"), ...) + anti_join(x, y, by = NULL, copy = FALSE, ...) } \arguments{ diff --git a/man/join.tbl_df.Rd b/man/join.tbl_df.Rd index 7c69052d7f..8ffcdade73 100644 --- a/man/join.tbl_df.Rd +++ b/man/join.tbl_df.Rd @@ -3,6 +3,7 @@ \name{join.tbl_df} \alias{join.tbl_df} \alias{inner_join.tbl_df} +\alias{nest_join.tbl_df} \alias{left_join.tbl_df} \alias{right_join.tbl_df} \alias{full_join.tbl_df} @@ -14,6 +15,9 @@ suffix = c(".x", ".y"), ..., na_matches = pkgconfig::get_config("dplyr::na_matches")) +\method{nest_join}{tbl_df}(x, y, by = NULL, copy = FALSE, suffix = c(".x", + ".y"), ..., na_matches = pkgconfig::get_config("dplyr::na_matches")) + \method{left_join}{tbl_df}(x, y, by = NULL, copy = FALSE, suffix = c(".x", ".y"), ..., na_matches = pkgconfig::get_config("dplyr::na_matches")) diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index b9b3e7e0f0..0d6e987ac7 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -377,6 +377,24 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// nest_join_impl +List nest_join_impl(DataFrame x, DataFrame y, IntegerVector by_x, IntegerVector by_y, IntegerVector aux_x, IntegerVector aux_y, bool na_match, String yname); +RcppExport SEXP _dplyr_nest_join_impl(SEXP xSEXP, SEXP ySEXP, SEXP by_xSEXP, SEXP by_ySEXP, SEXP aux_xSEXP, SEXP aux_ySEXP, SEXP na_matchSEXP, SEXP ynameSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< DataFrame >::type x(xSEXP); + Rcpp::traits::input_parameter< DataFrame >::type y(ySEXP); + Rcpp::traits::input_parameter< IntegerVector >::type by_x(by_xSEXP); + Rcpp::traits::input_parameter< IntegerVector >::type by_y(by_ySEXP); + Rcpp::traits::input_parameter< IntegerVector >::type aux_x(aux_xSEXP); + Rcpp::traits::input_parameter< IntegerVector >::type aux_y(aux_ySEXP); + Rcpp::traits::input_parameter< bool >::type na_match(na_matchSEXP); + Rcpp::traits::input_parameter< String >::type yname(ynameSEXP); + rcpp_result_gen = Rcpp::wrap(nest_join_impl(x, y, by_x, by_y, aux_x, aux_y, na_match, yname)); + return rcpp_result_gen; +END_RCPP +} // left_join_impl DataFrame left_join_impl(DataFrame x, DataFrame y, IntegerVector by_x, IntegerVector by_y, IntegerVector aux_x, IntegerVector aux_y, bool na_match); RcppExport SEXP _dplyr_left_join_impl(SEXP xSEXP, SEXP ySEXP, SEXP by_xSEXP, SEXP by_ySEXP, SEXP aux_xSEXP, SEXP aux_ySEXP, SEXP na_matchSEXP) { @@ -729,6 +747,7 @@ static const R_CallMethodDef CallEntries[] = { {"_dplyr_semi_join_impl", (DL_FUNC) &_dplyr_semi_join_impl, 5}, {"_dplyr_anti_join_impl", (DL_FUNC) &_dplyr_anti_join_impl, 5}, {"_dplyr_inner_join_impl", (DL_FUNC) &_dplyr_inner_join_impl, 7}, + {"_dplyr_nest_join_impl", (DL_FUNC) &_dplyr_nest_join_impl, 8}, {"_dplyr_left_join_impl", (DL_FUNC) &_dplyr_left_join_impl, 7}, {"_dplyr_right_join_impl", (DL_FUNC) &_dplyr_right_join_impl, 7}, {"_dplyr_full_join_impl", (DL_FUNC) &_dplyr_full_join_impl, 7}, diff --git a/src/join_exports.cpp b/src/join_exports.cpp index b92eb4cf8f..bb325012a4 100644 --- a/src/join_exports.cpp +++ b/src/join_exports.cpp @@ -186,6 +186,66 @@ DataFrame inner_join_impl(DataFrame x, DataFrame y, ); } +inline int reverse_index(int i){ + return -i-1; +} + +// [[Rcpp::export]] +List nest_join_impl(DataFrame x, DataFrame y, + IntegerVector by_x, IntegerVector by_y, + IntegerVector aux_x, IntegerVector aux_y, + bool na_match, + String yname +) { + + check_by(by_x); + + typedef VisitorSetIndexMap > Map; + DataFrameJoinVisitors visitors(x, y, by_x, by_y, false, na_match); + Map map(visitors); + + int n_x = x.nrows(), n_y = y.nrows(); + + std::vector indices_x; + std::vector indices_y; + + train_push_back_right(map, n_y); + + List list_col(n_x) ; + + DataFrameSubsetVisitors y_subset_visitors(y, aux_y); + + for (int i = 0; i < n_x; i++) { + Map::iterator it = map.find(i); + if (it != map.end()) { + std::transform(it->second.begin(), it->second.end(), it->second.begin(), reverse_index ); + list_col[i] = y_subset_visitors.subset(it->second, Rf_getAttrib(y, R_ClassSymbol)); + } else { + list_col[i] = y_subset_visitors.subset(EmptySubset(), Rf_getAttrib(y, R_ClassSymbol)); + } + } + + int ncol_x = x.size(); + List out( ncol_x + 1); + CharacterVector names_x = x.names(); + for (int i=0; i(out)) out.attr("groups") = x.attr("groups") ; + + return out; + +} + + + // [[Rcpp::export]] DataFrame left_join_impl(DataFrame x, DataFrame y, IntegerVector by_x, IntegerVector by_y,