From 9c6c8f62e3f6ddb870e9cb965a9de358c2abe23b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Romain=20Fran=C3=A7ois?= Date: Wed, 26 Mar 2014 10:32:01 +0100 Subject: [PATCH] more dispatch for join visitor. (logical <-> integer), (factor <-> character). #228 --- inst/include/dplyr/JoinVisitorImpl.h | 112 +++++++++++++++++++++++++++ inst/tests/test-joins.r | 7 +- src/join.cpp | 24 +++++- 3 files changed, 140 insertions(+), 3 deletions(-) diff --git a/inst/include/dplyr/JoinVisitorImpl.h b/inst/include/dplyr/JoinVisitorImpl.h index b32e40116d..8c22f45124 100644 --- a/inst/include/dplyr/JoinVisitorImpl.h +++ b/inst/include/dplyr/JoinVisitorImpl.h @@ -111,6 +111,118 @@ namespace dplyr{ } } ; + /* Join visitor for a factor column on the left (integer codes plus its "levels" attribute) and a character column on the right. A non-negative index i refers to left element i; a negative index i refers to right element -i-1. */ class JoinFactorStringVisitor : public JoinVisitor { + public: + JoinFactorStringVisitor( const IntegerVector& left_, const CharacterVector& right_ ) : + left_ptr(left_.begin()), + left_factor_ptr(Rcpp::internal::r_vector_start(left_.attr("levels")) ), + right_ptr(Rcpp::internal::r_vector_start(right_)) + {} + + inline size_t hash(int i){ + return string_hash( get(i) ) ; + } + + inline bool equal( int i, int j){ + return get(i) == get(j) ; + } + + inline void print(int i){ + Rcpp::Rcout << get(i) << std::endl ; + } + + inline SEXP subset( const std::vector& indices ) { + int n = indices.size() ; + CharacterVector res(n) ; + for( int i=0; i& set ) { + int n = set.size() ; + CharacterVector res(n) ; + VisitorSetIndexSet::const_iterator it=set.begin() ; + for( int i=0; i string_hash ; + + /* Returns the string for index i: left side goes through the factor levels (an NA code yields NA_STRING, otherwise the code minus one indexes the levels); right side is read directly. */ inline SEXP get(int i){ + if( i>=0 ){ + if( left_ptr[i] == NA_INTEGER ) return NA_STRING ; + return left_factor_ptr[ left_ptr[i] - 1 ] ; + } else { + return right_ptr[ -i-1 ] ; + } + } + + } ; + + /* Join visitor for a character column on the left and a factor column on the right (integer codes plus its "levels" attribute). A non-negative index i refers to left element i; a negative index i refers to right element -i-1. */ class JoinStringFactorVisitor : public JoinVisitor { + public: + JoinStringFactorVisitor( const CharacterVector& left_, const IntegerVector& right_ ) : + right_ptr(right_.begin()), + right_factor_ptr(Rcpp::internal::r_vector_start(right_.attr("levels")) ), + left_ptr(Rcpp::internal::r_vector_start(left_)) + {} + 
+ inline size_t hash(int i){ + return string_hash( get(i) ) ; + } + + inline bool equal( int i, int j){ + return get(i) == get(j) ; + } + + inline void print(int i){ + Rcpp::Rcout << get(i) << std::endl ; + } + + inline SEXP subset( const std::vector& indices ) { + int n = indices.size() ; + CharacterVector res(n) ; + for( int i=0; i& set ) { + int n = set.size() ; + CharacterVector res(n) ; + VisitorSetIndexSet::const_iterator it=set.begin() ; + for( int i=0; i string_hash ; + + /* Returns the string for index i: left side is read directly; right side goes through the factor levels (an NA code yields NA_STRING). */ inline SEXP get(int i){ + if( i>=0 ){ + return left_ptr[i] ; + } else { + int index = -i-1 ; + if( right_ptr[index] == NA_INTEGER ) return NA_STRING ; + /* The level is selected by the factor code stored at this position (code minus one), not by the position itself: indexing the levels with index - 1 would read element -1 for the first right-hand row and an unrelated level for the others. */ return right_factor_ptr[ right_ptr[index] - 1 ] ; + } + } + + } ; + + class JoinFactorFactorVisitor : public JoinVisitorImpl { public: typedef JoinVisitorImpl Parent ; diff --git a/inst/tests/test-joins.r b/inst/tests/test-joins.r index df38fbf08e..12ca6a5737 100644 --- a/inst/tests/test-joins.r +++ b/inst/tests/test-joins.r @@ -153,5 +153,10 @@ test_that("join handles type promotions #123", { res <- semi_join(df, match) expect_equal( res$V2, 3:4 ) expect_equal( res$V3, c(103L, 109L) ) - + + df1 <- data.frame( a = c("a", "b" ), b = 1:2, stringsAsFactors = TRUE ) + df2 <- data.frame( a = c("a", "b" ), c = 4:5, stringsAsFactors = FALSE ) + res <- semi_join( df1, df2 ) + res <- semi_join( df2, df1 ) + }) diff --git a/src/join.cpp b/src/join.cpp index f279db1257..4ff08c2420 100644 --- a/src/join.cpp +++ b/src/join.cpp @@ -255,10 +255,16 @@ namespace dplyr{ if( lhs_factor ){ incompatible_join_visitor(left, right, name) ; } else { - // return JoinVisitorImpl( left, right) ; + return new JoinVisitorImpl( left, right) ; } break ; } + case STRSXP: + { + if( lhs_factor ){ + return new JoinFactorStringVisitor( left, right ); + } + } default: break ; } break ; @@ -323,7 +329,21 @@ namespace dplyr{ } case STRSXP: { - return new JoinVisitorImpl ( left, right ) ; + switch( TYPEOF(right) ){ + case INTSXP: + { + if( Rf_inherits(right, "factor" ) ){ + return new 
JoinStringFactorVisitor( left, right ) ; + } + break ; + } + case STRSXP: + { + return new JoinVisitorImpl ( left, right ) ; + } + default: break ; + } + break ; } default: break ; }