Skip to content

Commit

Permalink
fix: mixing string and numerical comparisons would expose NumpyDispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
maartenbreddels committed Nov 24, 2020
1 parent 0c54dd8 commit a44b18f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 19 deletions.
21 changes: 6 additions & 15 deletions packages/vaex-core/vaex/arrow/numpy_dispatch.py
Expand Up @@ -69,15 +69,6 @@ def arrow_array(self):
self._arrow_array = vaex.array_types.to_arrow(self._numpy_array)
return self._arrow_array

def __eq__(self, rhs):
if vaex.array_types.is_string(self.arrow_array):
# this does not support scalar input
# return pc.equal(self.arrow_array, rhs)
return NumpyDispatch(pa.array(vaex.functions.str_equals(self.arrow_array, rhs)))
else:
if isinstance(rhs, NumpyDispatch):
rhs = rhs.numpy_array
return NumpyDispatch(pa.array(self.numpy_array == rhs))

for op in _binary_ops:
def closure(op=op):
Expand All @@ -88,16 +79,18 @@ def operator(a, b):
a_data = a.numpy_array
if isinstance(b, NumpyDispatch):
b_data = b.numpy_array
result_data = op['op'](a_data, b_data)
if op['name'] == 'eq' and (vaex.array_types.is_string(a_data) or vaex.array_types.is_string(b_data)):
result_data = vaex.functions.str_equals(a_data, b_data)
else:
result_data = op['op'](a_data, b_data)
if isinstance(a, NumpyDispatch):
result_data = a.add_missing(result_data)
if isinstance(b, NumpyDispatch):
result_data = b.add_missing(result_data)
return NumpyDispatch(result_data)
return operator
method_name = '__%s__' % op['name']
if op['name'] != "eq":
setattr(NumpyDispatch, method_name, closure())
setattr(NumpyDispatch, method_name, closure())
# to support e.g. (1 + ...) # to support e.g. (1 + ...)
if op['name'] in reversable:
def closure(op=op):
Expand Down Expand Up @@ -154,7 +147,5 @@ def wrapper(*args, **kwargs):
args = list(map(unwrap, args))
kwargs = {k: unwrap(v) for k, v, in kwargs.items()}
result = f(*args, **kwargs)
if isinstance(result, vaex.array_types.supported_arrow_array_types):
result = NumpyDispatch(result)
return result
return wrap(result)
return wrapper
19 changes: 15 additions & 4 deletions tests/compute_test.py
Expand Up @@ -17,6 +17,11 @@ def y(array_factory2):
return array_factory2([1, 2, None, None])


@pytest.fixture(scope='session')
def s():
return pa.array(['a', 'b', None, 'd'])


def test_add(x, y):
df = vaex.from_arrays(x=x, y=y)
df['z'] = df.x + df.y
Expand Down Expand Up @@ -47,7 +52,13 @@ def test_stay_same_type(x):
assert (df.x.sin()).tolist()[-1] == None
assert isinstance(df.x.sin().values, type(x))

# def test_values2():
# df = vaex.from_arrays(x=, y=y)
# df = vaex.from_scalars(x=0, y=1)
# assert isinstance(df.evaluate('cos(x)+y'), type(x))

def test_mix_string_and_numeric(x, s):
df = vaex.from_arrays(x=x, s=s)
# TODO: Note that this is a seperate bug, it ignored the missing value
assert (df.s == 'a').tolist() == [True, False, False, False]
assert (df.x == 1).tolist() == [False, True, False, None]
assert ((df.s == 'a') | (df.x == 1)).tolist()[0] is True
assert ((df.s == 'a') | (df.x == 1)).tolist() == [True, True, False, None]
assert (('a' == df.s) | (df.x == 1)).tolist() == [True, True, False, None]
assert ((df.x == 1) | (df.s == 'a')).tolist() == [True, True, False, None]

0 comments on commit a44b18f

Please sign in to comment.