Skip to content

Commit

Permalink
refactor(core): now that we support masked arrays for strings, simpli…
Browse files Browse the repository at this point in the history
…fy str_equals
  • Loading branch information
maartenbreddels committed Oct 8, 2019
1 parent c0e8da8 commit fd3402e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 25 deletions.
60 changes: 45 additions & 15 deletions packages/vaex-core/src/strings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,15 +307,32 @@ class StringSequenceBase : public StringSequence {
auto m = matches.mutable_unchecked<1>();
{
py::gil_scoped_release release;
for(size_t i = 0; i < length; i++) {
#if defined(_MSC_VER)
auto str = get(i);
bool match = str == other;
#else
auto str = view(i);
bool match = str == other;
#endif
m(i) = match;
if(has_null()){
for(size_t i = 0; i < length; i++) {
if(is_null(i)) {
m(i) = false;
} else {
#if defined(_MSC_VER)
auto str = get(i);
bool match = str == other;
#else
auto str = view(i);
bool match = str == other;
#endif
m(i) = match;
}
}
} else {
for(size_t i = 0; i < length; i++) {
#if defined(_MSC_VER)
auto str = get(i);
bool match = str == other;
#else
auto str = view(i);
bool match = str == other;
#endif
m(i) = match;
}
}
}
return std::move(matches);
Expand All @@ -328,11 +345,24 @@ class StringSequenceBase : public StringSequence {
auto m = matches.mutable_unchecked<1>();
{
py::gil_scoped_release release;
for(size_t i = 0; i < length; i++) {
auto str = view(i);
auto other = others->view(i);
bool match = str == other;
m(i) = match;
if(has_null() || others->has_null()) {
for(size_t i = 0; i < length; i++) {
if(is_null(i) || others->is_null(i)) {
m(i) = false;
} else {
auto str = view(i);
auto other = others->view(i);
bool match = str == other;
m(i) = match;
}
}
} else {
for(size_t i = 0; i < length; i++) {
auto str = view(i);
auto other = others->view(i);
bool match = str == other;
m(i) = match;
}
}
}
return std::move(matches);
Expand Down Expand Up @@ -1611,7 +1641,7 @@ class StringArray : public StringSequenceBase {
utf8_objects[i] = PyUnicode_AsUTF8String(object_array[i]);
sizes[i] = PyString_Size(utf8_objects[i]);
strings[i] = PyString_AsString(utf8_objects[i]);
} else if(PyString_CheckExact(object_array[i])) {
} else if(PyString_CheckExact(object_array[i]) && ((byte_mask == nullptr) || (byte_mask[i] == 0))) {
// otherwise directly use
utf8_objects[i] = 0;
sizes[i] = PyString_Size(object_array[i]);
Expand Down
10 changes: 0 additions & 10 deletions packages/vaex-core/vaex/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,21 +796,11 @@ def str_equals(x, y):
"""
xmask = None
ymask = None
if np.ma.isMaskedArray(x):
x, xmask = x.data, np.ma.getmaskarray(x)
if np.ma.isMaskedArray(y):
y, ymask = x.data, np.ma.getmaskarray(y)

if not isinstance(x, six.string_types):
x = _to_string_sequence(x)
if not isinstance(y, six.string_types):
y = _to_string_sequence(y)
equals_mask = x.equals(y)
# take out masked values
if xmask is not None:
equals_mask = equals_mask & ~xmask
if ymask is not None:
equals_mask = equals_mask & ~ymask
return equals_mask


Expand Down
2 changes: 2 additions & 0 deletions tests/internal/strings_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def test_masked_array():
mask = np.array([False, False, True, False, True], dtype=bool)
sa = vaex.strings.StringArray(ar, mask)
assert sa.tolist() == ['dog', 'dog', None, 'cat', None]
assert sa.equals('cat').tolist() == [False, False, False, True, False]
assert sa.equals(sa).tolist() == [True, True, False, True, False]

def test_string_array():
ar = np.array(["aap", "noot", None, "mies"], dtype='object')
Expand Down

0 comments on commit fd3402e

Please sign in to comment.