From 16760d62e7c8af2fc488a3880913d2e9b848d5a8 Mon Sep 17 00:00:00 2001 From: Lionel Henry Date: Thu, 6 Apr 2017 10:41:43 +0200 Subject: [PATCH] Use flatten_if() with custom predicate to process input --- NAMESPACE | 1 + R/bind.r | 83 +++++--------------------- src/RcppExports.cpp | 4 +- src/bind.cpp | 113 ++++++++++++++++++++++++++++++++++-- tests/testthat/test-binds.R | 10 ++-- 5 files changed, 129 insertions(+), 82 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index c2a1549452..32874fa0d7 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -424,3 +424,4 @@ importFrom(tibble,type_sum) importFrom(utils,head) importFrom(utils,tail) useDynLib(dplyr) +useDynLib(dplyr,bind_spliceable) diff --git a/R/bind.r b/R/bind.r index e1119a7677..c3494fcae5 100644 --- a/R/bind.r +++ b/R/bind.r @@ -105,8 +105,9 @@ NULL #' @export #' @rdname bind +#' @useDynLib dplyr bind_spliceable bind_rows <- function(..., .id = NULL) { - x <- discard(list_or_dots(...), is_null) + x <- flatten_if(dots_values(...), bind_spliceable) if (!length(x)) { # Handle corner cases gracefully, but always return a tibble @@ -117,19 +118,14 @@ bind_rows <- function(..., .id = NULL) { } } - for (elt in x) { - if (!is_valid_df(elt) && !is_rowwise_atomic(elt) && !is_null(elt)) { - abort("`...` must only contain data frames and named atomic vectors") - } - } - if (!is_null(.id)) { if (!(is_string(.id))) { bad_args(".id", "must be a scalar string, ", "not {type_of(.id)} of length {length(.id)}" ) } - if (!is_named(x)) { + if (!all(have_name(x) | map_lgl(x, is_empty))) { + x <- compact(x) names(x) <- seq_along(x) } } @@ -144,22 +140,25 @@ is_df_list <- function(x) { is_list(x) && every(x, inherits, "data.frame") } -#' @export -rbind.tbl_df <- function(..., deparse.level = 1) { - bind_rows(...) -} - #' @export #' @rdname bind bind_cols <- function(...) { - x <- discard(list_or_dots(...), is_null) + x <- flatten_if(dots_values(...), bind_spliceable) out <- cbind_all(x) tibble::repair_names(out) } + +# Can't forward dots directly because rbind() and cbind() evaluate +# them eagerly which prevents them from being captured + +#' @export +rbind.tbl_df <- function(..., deparse.level = 1) { + bind_rows(!!! list(...)) +} #' @export cbind.tbl_df <- function(..., deparse.level = 1) { - bind_cols(...) + bind_cols(!!! list(...)) } #' @export @@ -173,60 +172,6 @@ combine <- function(...) { } } -list_or_dots <- function(...) { - dots <- dots_list(...) - if (!length(dots)) { - return(dots) - } - - # Old versions specified that first argument could be a list of - # dataframeable objects - if (is_list(dots[[1]])) { - dots[[1]] <- map_if(dots[[1]], is_dataframe_like, as_tibble) - } - - # Need to ensure that each component is a data frame or a vector - # wrapped in a list: - dots <- map_if(dots, is_dataframe_like, function(x) list(as_tibble(x))) - dots <- map_if(dots, is_atomic, list) - dots <- map_if(dots, is.data.frame, list) - - unlist(dots, recursive = FALSE) -} - -is_dataframe_like <- function(x) { - if (is_null(x)) - return(FALSE) - - # data frames are not data lists - if (is.data.frame(x)) - return(FALSE) - - # Must be a list - if (!is_list(x)) - return(FALSE) - - # 0 length named list (#1515) - if (!is_null(names(x)) && length(x) == 0) - return(TRUE) - - # With names - if (!is_named(x)) - return(FALSE) - - # Where each element is an 1d vector or list - if (!every(x, is_1d)) - return(FALSE) - - # All of which have the same length - n <- map_int(x, length) - if (any(n != n[1])) - return(FALSE) - - TRUE -} - - # Deprecated functions ---------------------------------------------------- #' @export diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 6c792335c6..03b4910095 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -110,12 +110,12 @@ BEGIN_RCPP END_RCPP } // rbind_list__impl -List rbind_list__impl(Dots dots); +List rbind_list__impl(List dots); RcppExport SEXP dplyr_rbind_list__impl(SEXP dotsSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< Dots >::type dots(dotsSEXP); + Rcpp::traits::input_parameter< List >::type dots(dotsSEXP); rcpp_result_gen = Rcpp::wrap(rbind_list__impl(dots)); return rcpp_result_gen; END_RCPP diff --git a/src/bind.cpp b/src/bind.cpp index 82ef8fd3d6..8dcbf98936 100644 --- a/src/bind.cpp +++ b/src/bind.cpp @@ -41,8 +41,7 @@ class DataFrameAbleVector { std::vector data; }; -template -String get_dot_name(const Dots& dots, int i) { +String get_dot_name(const List& dots, int i) { RObject names = dots.names(); if (Rf_isNull(names)) return ""; return STRING_ELT(names, i); @@ -73,12 +72,113 @@ static int rows_length(SEXP x) { if (Rf_inherits(x, "data.frame")) return df_rows_length(x); + else if (TYPEOF(x) == VECSXP && Rf_length(x) > 0) + return Rf_length(VECTOR_ELT(x, 0)); else return 1; } +static +bool is_vector(SEXP x) { + switch(TYPEOF(x)) { + case LGLSXP: + case INTSXP: + case REALSXP: + case CPLXSXP: + case STRSXP: + case RAWSXP: + case VECSXP: + return true; + default: + return false; + } +} -template -List rbind__impl(Dots dots, SEXP id = R_NilValue) { +static +void outer_vector_check(SEXP x) { + switch(TYPEOF(x)) { + case LGLSXP: + case INTSXP: + case REALSXP: + case CPLXSXP: + case STRSXP: + case RAWSXP: { + if (Rf_getAttrib(x, R_NamesSymbol) != R_NilValue) + break; + stop("`bind_rows()` expects data frames and named atomic vectors"); + } + case VECSXP: { + if (!OBJECT(x) || Rf_inherits(x, "data.frame")) + break; + } + default: + stop("`bind_rows()` expects data frames and named atomic vectors"); + } +} +static +void inner_vector_check(SEXP x, int nrows) { + if (!is_vector(x)) + stop("`bind_rows()` expects data frames and named atomic vectors 2"); + + if (OBJECT(x)) { + if (Rf_inherits(x, "data.frame")) + stop("`bind_rows()` does not support nested data frames"); + if (Rf_inherits(x, "POSIXlt")) + stop("`bind_rows()` does not support POSIXlt columns"); + } + + if (Rf_length(x) != nrows) + stop("incompatible sizes (%d != %s)", nrows, Rf_length(x)); +} + +static +void bind_type_check(SEXP x, int nrows) { + int n = Rf_length(x); + if (n == 0) + return; + + outer_vector_check(x); + + if (TYPEOF(x) == VECSXP) { + for (int i = 0; i < n; i++) + inner_vector_check(VECTOR_ELT(x, i), nrows); + } +} + +bool is_atomic(SEXP x) { + switch(TYPEOF(x)) { + case LGLSXP: + case INTSXP: + case REALSXP: + case CPLXSXP: + case STRSXP: + case RAWSXP: + return true; + default: + return false; + } +} + +extern "C" +bool bind_spliceable(SEXP x) { + if (TYPEOF(x) != VECSXP) + return false; + + if (OBJECT(x)) { + if (Rf_inherits(x, "spliced")) + return true; + else + return false; + } + + for (size_t i = 0; i != Rf_length(x); ++i) { + if (is_atomic(VECTOR_ELT(x, i))) + return false; + } + + return true; +} + +List rbind__impl(List dots, SEXP id = R_NilValue) { int ndata = dots.size(); int n = 0; std::vector chunks; @@ -110,13 +210,14 @@ List rbind__impl(Dots dots, SEXP id = R_NilValue) { SEXP df = chunks[i]; int nrows = df_nrows[i]; + bind_type_check(df, nrows); CharacterVector df_names = enc2native(Rf_getAttrib(df, R_NamesSymbol)); for (int j = 0; j < Rf_length(df); j++) { SEXP source; int offset; - if (Rf_inherits(df, "data.frame")) { + if (TYPEOF(df) == VECSXP) { source = VECTOR_ELT(df, j); offset = 0; } else { @@ -227,7 +328,7 @@ List bind_rows_(List dots, SEXP id = R_NilValue) { } // [[Rcpp::export]] -List rbind_list__impl(Dots dots) { +List rbind_list__impl(List dots) { return rbind__impl(dots); } diff --git a/tests/testthat/test-binds.R b/tests/testthat/test-binds.R index 17ced64d60..76995e5b06 100644 --- a/tests/testthat/test-binds.R +++ b/tests/testthat/test-binds.R @@ -18,8 +18,8 @@ test_that("bind_rows() and bind_cols() err for non-data frames (#2373)", { df1 <- structure(list(x = 1), class = "blah_frame") df2 <- structure(list(x = 1), class = "blah_frame") - expect_error(bind_cols(df1, df2), "cannot coerce") - expect_error(bind_rows(df1, df2), "cannot coerce") + expect_error(bind_cols(df1, df2), "Data-frame-like objects must inherit from class data\\.frame or be plain lists") + expect_error(bind_rows(df1, df2), "expects data frames and named atomic vectors") }) test_that("bind_rows() err for invalid ID", { @@ -118,7 +118,7 @@ test_that("bind_rows ignores NULL", { test_that("bind_rows only accepts data frames or vectors", { ll <- list(1:5, get_env()) - expect_error(bind_rows(ll), "only contain data frames and named atomic vectors") + expect_error(bind_rows(ll), "expects data frames and named atomic vectors") }) test_that("bind_rows handles list columns (#463)", { @@ -470,7 +470,7 @@ test_that("bind_cols infers classes from first result (#1692)", { test_that("bind_rows rejects POSIXlt columns (#1789)", { df <- data_frame(x = Sys.time() + 1:12) df$y <- as.POSIXlt(df$x) - expect_error(bind_rows(df, df), "not supported") + expect_error(bind_rows(df, df), "does not support POSIXlt columns") }) test_that("bind_rows rejects data frame columns (#2015)", { @@ -483,7 +483,7 @@ test_that("bind_rows rejects data frame columns (#2015)", { expect_error( dplyr::bind_rows(df, df), - "Columns of class data.frame not supported", + "`bind_rows()` does not support nested data frames", fixed = TRUE ) })