Skip to content

Commit

Permalink
Use R function match() for joining character vectors (#2451)
Browse files Browse the repository at this point in the history
* use r_match() instead of match() for joining characters or factors

* enhance tests to treat factors

* add test

* different values

* deal with warning messages

* fix test

* character_vector_equal()

* same_levels() respects encoding

* fix tests

* add issue number

* also loop over native/UTF-8 combinations
  • Loading branch information
krlmlr committed Feb 20, 2017
1 parent 3e3581c commit 7cc62a9
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 18 deletions.
13 changes: 7 additions & 6 deletions inst/include/dplyr/JoinVisitorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define dplyr_JoinVisitorImpl_H

#include <tools/utils.h>
#include <tools/match.h>

#include <dplyr/comparisons.h>
#include <dplyr/comparisons_different.h>
Expand Down Expand Up @@ -133,8 +134,8 @@ namespace dplyr {
left_levels(get_levels(left)),
right_levels(get_levels(right)),
uniques(get_uniques(left_levels, right_levels)),
left_match(match(left_levels, uniques)),
right_match(match(right_levels, uniques))
left_match(r_match(left_levels, uniques)),
right_match(r_match(right_levels, uniques))
{}

inline size_t hash(int i) {
Expand Down Expand Up @@ -179,8 +180,8 @@ namespace dplyr {
JoinStringStringVisitor(CharacterVector left_, CharacterVector right) :
left(left_),
uniques(get_uniques(left, right)),
i_left(match(left, uniques)),
i_right(match(right, uniques)),
i_left(r_match(left, uniques)),
i_right(r_match(right, uniques)),
int_visitor(i_left, i_right),
p_uniques(internal::r_vector_start<STRSXP>(uniques)),
p_left(i_left.begin()),
Expand Down Expand Up @@ -238,7 +239,7 @@ namespace dplyr {
uniques(get_uniques(get_levels(left), right)),
p_uniques(internal::r_vector_start<STRSXP>(uniques)),

i_right(match(right, uniques)),
i_right(r_match(right, uniques)),
int_visitor(left, i_right)

{}
Expand Down Expand Up @@ -295,7 +296,7 @@ namespace dplyr {
i_right(right_),
uniques(get_uniques(get_levels(i_right), left_)),
p_uniques(internal::r_vector_start<STRSXP>(uniques)),
i_left(match(left_, uniques)),
i_left(r_match(left_, uniques)),

int_visitor(i_left, i_right)
{}
Expand Down
5 changes: 2 additions & 3 deletions inst/include/dplyr/SubsetVectorVisitorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,17 +153,16 @@ namespace dplyr {
}

private:

inline bool same_levels(SubsetFactorVisitor* other, std::stringstream& ss, const std::string& name) const {
CharacterVector levels_other = other->levels;
if (levels.length() != levels_other.length() || !all(levels == levels_other).is_true()) {

if (!character_vector_equal(levels, levels_other)) {
ss << "Factor levels not equal for column '" << name << "'";
return false;
}
return true;
}


inline SEXP promote(IntegerVector x) const {
copy_most_attributes(x, vec);
return x;
Expand Down
1 change: 1 addition & 0 deletions inst/include/tools/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ CharacterVector get_class(SEXP x);
SEXP set_class(SEXP x, const CharacterVector& class_);
CharacterVector get_levels(SEXP x);
SEXP set_levels(SEXP x, const CharacterVector& levels);
bool character_vector_equal(const CharacterVector& x, const CharacterVector& y);
bool same_levels(SEXP left, SEXP right);

// effectively the same as copy_attributes but without names and dims
Expand Down
30 changes: 21 additions & 9 deletions src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,28 @@ SEXP set_levels(SEXP x, const CharacterVector& levels) {
return Rf_setAttrib(x, R_LevelsSymbol, levels);
}

bool same_levels(SEXP left, SEXP right) {
CharacterVector levels_left = get_levels(left);
CharacterVector levels_right = get_levels(right);
if ((SEXP)levels_left == (SEXP)levels_right) return true;
int n = levels_left.size();
if (n != levels_right.size()) return false;

for (int i=0; i<n; i++) {
if (levels_right[i] != levels_left[i]) return false;
bool character_vector_equal(const CharacterVector& x, const CharacterVector& y) {
if ((SEXP)x == (SEXP)y) return true;

if (x.length() != y.length())
return false;

for (R_xlen_t i = 0; i < x.length(); ++i) {
SEXP xi = x[i];
SEXP yi = y[i];

// Ideally we'd use Rf_Seql(), but this is not exported.
if (Rf_NonNullStringMatch(xi, yi)) continue;
if (xi == NA_STRING && yi == NA_STRING) continue;
if (xi == NA_STRING || yi == NA_STRING)
return false;
if (CHAR(xi)[0] == 0 && CHAR(yi)[0] == 0) continue;
return false;
}

return true;
}

bool same_levels(SEXP left, SEXP right) {
return character_vector_equal(get_levels(left), get_levels(right));
}
66 changes: 66 additions & 0 deletions tests/testthat/test-joins.r
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,72 @@ test_that("inner join not crashing (#1559)", {
for (i in 2:100) expect_equal(res[, 1], res[, i])
})

test_that("join handles mix of encodings in data (#1885, #2118, #2271)", {
with_non_utf8_encoding({
special <- get_native_lang_string()

for (factor1 in c(FALSE, TRUE)) {
for (factor2 in c(FALSE, TRUE)) {
for (encoder1 in c(enc2native, enc2utf8)) {
for (encoder2 in c(enc2native, enc2utf8)) {
df1 <- data.frame(x = encoder1(special), y = 1, stringsAsFactors = factor1)
df1 <- tbl_df(df1)
df2 <- data.frame(x = encoder2(special), z = 2, stringsAsFactors = factor2)
df2 <- tbl_df(df2)
df <- data.frame(x = special, y = 1, z = 2, stringsAsFactors = factor1 && factor2)
df <- tbl_df(df)

info <- paste(
factor1,
factor2,
Encoding(as.character(df1$x)),
Encoding(as.character(df2$x))
)

if (factor1 != factor2) warning_msg <- "coercing"
else warning_msg <- NA

expect_warning_msg <- function(code, msg = warning_msg) {
expect_warning(
code, msg,
info = paste(deparse(substitute(code)[[2]][[1]]), info))
}

expect_equal_df <- function(code, df_ = df) {
code <- substitute(code)
eval(bquote(
expect_equal(
.(code), df_,
info = paste(deparse(code[[1]]), info)
)
))
}

expect_warning_msg(expect_equal_df(inner_join(df1, df2, by = "x")))
expect_warning_msg(expect_equal_df(left_join(df1, df2, by = "x")))
expect_warning_msg(expect_equal_df(right_join(df1, df2, by = "x")))
expect_warning_msg(expect_equal_df(full_join(df1, df2, by = "x")))
expect_warning_msg(
expect_equal_df(
semi_join(df1, df2, by = "x"),
data.frame(x = special, y = 1, stringsAsFactors = factor1)
),
msg = NA
)
expect_warning_msg(
expect_equal_df(
anti_join(df1, df2, by = "x"),
data.frame(x = special, y = 1, stringsAsFactors = factor1)[0,]
),
msg = NA
)
}
}
}
}
})
})

test_that("left_join handles mix of encodings in column names (#1571)", {
with_non_utf8_encoding({
special <- get_native_lang_string()
Expand Down

0 comments on commit 7cc62a9

Please sign in to comment.