Skip to content

Commit

Permalink
Merge pull request #287 from benjello/fix-gh-218
Browse files Browse the repository at this point in the history
Fix #218
  • Loading branch information
robbmcleod committed Sep 29, 2017
2 parents 6af5555 + 08cd904 commit 022d4a4
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
3 changes: 1 addition & 2 deletions numexpr/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,7 @@ def function(*args):
@ophelper
def where_func(a, b, c):
if isinstance(a, ConstantNode):
#FIXME: This prevents where(True, a, b)
raise ValueError("too many dimensions")
return b if a.value else c
if allConstantNodes([a, b, c]):
return ConstantNode(numpy.where(a, b, c))
return FuncNode('where', [a, b, c])
Expand Down
10 changes: 10 additions & 0 deletions numexpr/tests/test_numexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,16 @@ def test_str_contains_long_needle(self):
b = b'a' * 40
res = evaluate('contains(a, b)')
assert_equal(res, True)

def test_where_scalar_bool(self):
a = True
b = array([1, 2])
c = array([3, 4])
res = evaluate('where(a, b, c)')
assert_array_equal(res, b)
a = False
res = evaluate('where(a, b, c)')
assert_array_equal(res, c)


class test_numexpr2(test_numexpr):
Expand Down

0 comments on commit 022d4a4

Please sign in to comment.