Skip to content

Commit

Permalink
more dispatch for join visitor. (logical <-> integer), (factor <-> ch…
Browse files Browse the repository at this point in the history
…aracter). #228
  • Loading branch information
romainfrancois committed Mar 26, 2014
1 parent e1b3357 commit 9c6c8f6
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 3 deletions.
112 changes: 112 additions & 0 deletions inst/include/dplyr/JoinVisitorImpl.h
Expand Up @@ -111,6 +111,118 @@ namespace dplyr{
}
} ;

class JoinFactorStringVisitor : public JoinVisitor {
public:
JoinFactorStringVisitor( const IntegerVector& left_, const CharacterVector& right_ ) :
left_ptr(left_.begin()),
left_factor_ptr(Rcpp::internal::r_vector_start<STRSXP>(left_.attr("levels")) ),
right_ptr(Rcpp::internal::r_vector_start<STRSXP>(right_))
{}

inline size_t hash(int i){
return string_hash( get(i) ) ;
}

inline bool equal( int i, int j){
return get(i) == get(j) ;
}

inline void print(int i){
Rcpp::Rcout << get(i) << std::endl ;
}

inline SEXP subset( const std::vector<int>& indices ) {
int n = indices.size() ;
CharacterVector res(n) ;
for( int i=0; i<n; i++) {
res[i] = get(indices[i]) ;
}
return res ;
}
inline SEXP subset( const VisitorSetIndexSet<DataFrameJoinVisitors>& set ) {
int n = set.size() ;
CharacterVector res(n) ;
VisitorSetIndexSet<DataFrameJoinVisitors>::const_iterator it=set.begin() ;
for( int i=0; i<n; i++, ++it) {
res[i] = get(*it) ;
}
return res ;
}

private:
int* left_ptr ;
SEXP* left_factor_ptr ;
SEXP* right_ptr ;
boost::hash<SEXP> string_hash ;

inline SEXP get(int i){
if( i>=0 ){
if( left_ptr[i] == NA_INTEGER ) return NA_STRING ;
return left_factor_ptr[ left_ptr[i] - 1 ] ;
} else {
return right_ptr[ -i-1 ] ;
}
}

} ;

/**
 * Join visitor for a key column that is a character vector on the left-hand
 * side and a factor on the right-hand side.
 *
 * Both sides are viewed as strings: a left-hand element is used as is, a
 * right-hand factor code is translated to its entry in the "levels" attribute.
 *
 * Index convention used by every method taking an index:
 *   - i >= 0 : element i of the left-hand (character) column
 *   - i <  0 : element (-i-1) of the right-hand (factor) column
 *
 * The visitor only keeps raw pointers into the vectors handed to the
 * constructor; it does not protect them itself.
 */
class JoinStringFactorVisitor : public JoinVisitor {
public:
    /**
     * @param left_  character column
     * @param right_ factor column (integer codes carrying a "levels" attribute)
     */
    JoinStringFactorVisitor( const CharacterVector& left_, const IntegerVector& right_ ) :
        right_ptr(right_.begin()),
        right_factor_ptr(Rcpp::internal::r_vector_start<STRSXP>(right_.attr("levels")) ),
        left_ptr(Rcpp::internal::r_vector_start<STRSXP>(left_))
    {}

    /** Hash of the string element designated by i (hashes the SEXP handle). */
    inline size_t hash(int i){
        return string_hash( get(i) ) ;
    }

    /**
     * True when i and j designate the same string element.
     * Compares SEXP handles, not characters.
     * NOTE(review): presumably relies on R sharing one CHARSXP per distinct
     * string -- confirm.
     */
    inline bool equal( int i, int j){
        return get(i) == get(j) ;
    }

    /**
     * Debug helper: streams the element designated by i to Rcpp::Rcout.
     * NOTE(review): the SEXP returned by get() is streamed directly; confirm
     * this prints what is intended.
     */
    inline void print(int i){
        Rcpp::Rcout << get(i) << std::endl ;
    }

    /** Character vector made of the elements designated by `indices`, in order. */
    inline SEXP subset( const std::vector<int>& indices ) {
        int n = indices.size() ;
        CharacterVector res(n) ;
        for( int i=0; i<n; i++) {
            res[i] = get(indices[i]) ;
        }
        return res ;
    }

    /** Character vector made of the elements designated by `set`, in iteration order. */
    inline SEXP subset( const VisitorSetIndexSet<DataFrameJoinVisitors>& set ) {
        int n = set.size() ;
        CharacterVector res(n) ;
        VisitorSetIndexSet<DataFrameJoinVisitors>::const_iterator it=set.begin() ;
        for( int i=0; i<n; i++, ++it) {
            res[i] = get(*it) ;
        }
        return res ;
    }

private:
    int* right_ptr ;           // factor codes of the right column
    SEXP* right_factor_ptr ;   // levels of the right factor
    SEXP* left_ptr ;           // elements of the left character column
    boost::hash<SEXP> string_hash ;

    /** Resolves an index (see the class comment) to a string element. */
    inline SEXP get(int i){
        if( i>=0 ){
            return left_ptr[i] ;
        } else {
            int index = -i-1 ;
            int code = right_ptr[index] ;
            if( code == NA_INTEGER ) return NA_STRING ;
            // The level is selected by the 1-based factor code stored at
            // `index`, not by the row position itself: indexing the levels
            // with `index - 1` returned the wrong level and read before the
            // start of the levels for the first row.
            return right_factor_ptr[ code - 1 ] ;
        }
    }

} ;


class JoinFactorFactorVisitor : public JoinVisitorImpl<INTSXP, INTSXP> {
public:
typedef JoinVisitorImpl<INTSXP,INTSXP> Parent ;
Expand Down
7 changes: 6 additions & 1 deletion inst/tests/test-joins.r
Expand Up @@ -153,5 +153,10 @@ test_that("join handles type promotions #123", {
res <- semi_join(df, match)
expect_equal( res$V2, 3:4 )
expect_equal( res$V3, c(103L, 109L) )


df1 <- data.frame( a = c("a", "b" ), b = 1:2, stringsAsFactors = TRUE )
df2 <- data.frame( a = c("a", "b" ), c = 4:5, stringsAsFactors = FALSE )
res <- semi_join( df1, df2 )
res <- semi_join( df2, df1 )

})
24 changes: 22 additions & 2 deletions src/join.cpp
Expand Up @@ -255,10 +255,16 @@ namespace dplyr{
if( lhs_factor ){
incompatible_join_visitor(left, right, name) ;
} else {
// return JoinVisitorImpl<INTSXP, LGLSXP>( left, right) ;
return new JoinVisitorImpl<INTSXP, LGLSXP>( left, right) ;
}
break ;
}
case STRSXP:
{
if( lhs_factor ){
return new JoinFactorStringVisitor( left, right );
}
}
default: break ;
}
break ;
Expand Down Expand Up @@ -323,7 +329,21 @@ namespace dplyr{
}
case STRSXP:
{
return new JoinVisitorImpl<STRSXP,STRSXP> ( left, right ) ;
switch( TYPEOF(right) ){
case INTSXP:
{
if( Rf_inherits(right, "factor" ) ){
return new JoinStringFactorVisitor( left, right ) ;
}
break ;
}
case STRSXP:
{
return new JoinVisitorImpl<STRSXP,STRSXP> ( left, right ) ;
}
default: break ;
}
break ;
}
default: break ;
}
Expand Down

0 comments on commit 9c6c8f6

Please sign in to comment.