diff --git a/docs/source/whatsnew/0.8.4.txt b/docs/source/whatsnew/0.8.4.txt index fe0df9816e..2b86d8f176 100644 --- a/docs/source/whatsnew/0.8.4.txt +++ b/docs/source/whatsnew/0.8.4.txt @@ -130,6 +130,9 @@ Bug Fixes * Fixed issues around KeyErrors coming from history and BarData on 32-bit python, where Assets did not compare properly with int64s (:issue:`959`). +* Fixed a bug where boolean operators were not properly implemented on + :class:~zipline.pipeline.Filter` (:issue:`991`). + Performance ~~~~~~~~~~~ diff --git a/tests/pipeline/test_numerical_expression.py b/tests/pipeline/test_numerical_expression.py index 8dcb73215e..1026a762e6 100644 --- a/tests/pipeline/test_numerical_expression.py +++ b/tests/pipeline/test_numerical_expression.py @@ -1,6 +1,6 @@ +from itertools import permutations from operator import ( add, - and_, ge, gt, le, @@ -8,13 +8,13 @@ methodcaller, mul, ne, - or_, ) from unittest import TestCase import numpy from numpy import ( arange, + array, eye, float64, full, @@ -27,7 +27,7 @@ Int64Index, ) -from zipline.pipeline import Factor +from zipline.pipeline import Factor, Filter from zipline.pipeline.expression import ( NumericalExpression, NUMEXPR_MATH_FUNCS, @@ -55,6 +55,11 @@ class H(Factor): window_length = 0 +class NonExprFilter(Filter): + inputs = () + window_length = 0 + + class DateFactor(Factor): dtype = datetime64ns_dtype inputs = () @@ -460,26 +465,57 @@ def test_comparisons(self): def test_boolean_binops(self): f, g, h = self.f, self.g, self.h + + # Add a non-numexpr filter to ensure that we correctly handle + # delegation to NumericalExpression. + custom_filter = NonExprFilter() + custom_filter_mask = array( + [[0, 1, 0, 1, 0], + [0, 0, 1, 0, 0], + [1, 0, 0, 0, 0], + [0, 0, 1, 1, 0], + [0, 0, 0, 1, 0]], + dtype=bool, + ) + self.fake_raw_data = { f: arange(25).reshape(5, 5), g: arange(25).reshape(5, 5) - eye(5), h: full((5, 5), 5), + custom_filter: custom_filter_mask, } # Should be True on the diagonal. - eye_filter = f > g + eye_filter = (f > g) + # Should be True in the first row only. first_row_filter = f < h eye_mask = eye(5, dtype=bool) + first_row_mask = zeros((5, 5), dtype=bool) first_row_mask[0] = 1 self.check_output(eye_filter, eye_mask) self.check_output(first_row_filter, first_row_mask) - for op in (and_, or_): # NumExpr doesn't support xor. - self.check_output( - op(eye_filter, first_row_filter), - op(eye_mask, first_row_mask), - ) + def gen_boolops(x, y, z): + """ + Generate all possible interleavings of & and | between all possible + orderings of x, y, and z. + """ + for a, b, c in permutations([x, y, z]): + yield (a & b) & c + yield (a & b) | c + yield (a | b) & c + yield (a | b) | c + yield a & (b & c) + yield a & (b | c) + yield a | (b & c) + yield a | (b | c) + + exprs = gen_boolops(eye_filter, custom_filter, first_row_filter) + arrays = gen_boolops(eye_mask, custom_filter_mask, first_row_mask) + + for expr, expected in zip(exprs, arrays): + self.check_output(expr, expected) diff --git a/zipline/pipeline/filters/filter.py b/zipline/pipeline/filters/filter.py index fa450254db..f3c02924d0 100644 --- a/zipline/pipeline/filters/filter.py +++ b/zipline/pipeline/filters/filter.py @@ -124,6 +124,13 @@ class Filter(CompositeTerm): for op in FILTER_BINOPS } ) + clsdict.update( + { + method_name_for_op(op, commute=True): binary_operator(op) + for op in FILTER_BINOPS + } + ) + __invert__ = unary_operator('~') def _validate(self):