From b77cfd41d80b52daefe0af4cbaff3fc9c531bdfb Mon Sep 17 00:00:00 2001 From: Scott Sanderson Date: Mon, 27 Apr 2020 11:20:38 -0400 Subject: [PATCH] TEST: Add explicit tests for Filter.if_else. --- tests/pipeline/base.py | 36 +++++- tests/pipeline/test_filter.py | 229 ++++++++++++++++++++++++++++++++-- 2 files changed, 251 insertions(+), 14 deletions(-) diff --git a/tests/pipeline/base.py b/tests/pipeline/base.py index 19c05a127d..94acd2dea0 100644 --- a/tests/pipeline/base.py +++ b/tests/pipeline/base.py @@ -6,6 +6,7 @@ from pandas import DataFrame, Timestamp from six import iteritems +from zipline.lib.labelarray import LabelArray from zipline.utils.compat import wraps from zipline.pipeline import ExecutionPlan from zipline.pipeline.domain import US_EQUITIES @@ -171,10 +172,43 @@ def arange_data(self, shape, dtype=np.float64): @with_default_shape def randn_data(self, seed, shape): """ - Build a block of testing data from a seeded RandomState. + Build a block of random numerical data. """ return np.random.RandomState(seed).randn(*shape) + @with_default_shape + def rand_ints(self, seed, shape, low=0, high=10): + """ + Build a block of random numerical data. + """ + rand = np.random.RandomState(seed) + return rand.randint(low, high, shape, dtype='i8') + + @with_default_shape + def rand_datetimes(self, seed, shape): + ints = self.rand_ints(seed=seed, shape=shape, low=0, high=10000) + return ints.astype('datetime64[D]').astype('datetime64[ns]') + + @with_default_shape + def rand_categoricals(self, categories, seed, shape, missing_value=None): + """Build a block of random categorical data. + + Categories should not include ``missing_value``. + """ + categories = list(categories) + [missing_value] + data = np.random.RandomState(seed).choice(categories, shape) + return LabelArray( + data, + missing_value=missing_value, + categories=categories, + ) + + @with_default_shape + def rand_mask(self, seed, shape): + """Build a block of random boolean data. + """ + return np.random.RandomState(seed).randint(0, 2, shape).astype(bool) + @with_default_shape def eye_mask(self, shape): """ diff --git a/tests/pipeline/test_filter.py b/tests/pipeline/test_filter.py index 1e333ca406..bf77ad2ec2 100644 --- a/tests/pipeline/test_filter.py +++ b/tests/pipeline/test_filter.py @@ -21,12 +21,14 @@ ones_like, putmask, rot90, - sum as np_sum + sum as np_sum, + where, ) -from numpy.random import choice, randn, seed as random_seed +from numpy.random import RandomState import pandas as pd from zipline.errors import BadPercentileBounds +from zipline.lib.labelarray import labelarray_where from zipline.pipeline import Filter, Factor, Pipeline from zipline.pipeline.classifiers import Classifier from zipline.pipeline.domain import US_EQUITIES @@ -48,7 +50,7 @@ int64_dtype, object_dtype, ) -from .base import BaseUSEquityPipelineTestCase, with_default_shape +from .base import BaseUSEquityPipelineTestCase def rowwise_rank(array, mask=None): @@ -129,14 +131,6 @@ def init_instance_fixtures(self): 'datetime64[ns]': self.datetime_f, } - @with_default_shape - def randn_data(self, seed, shape): - """ - Build a block of testing data from numpy.random.randn. - """ - random_seed(seed) - return randn(*shape) - def test_bad_percentiles(self): f = self.f @@ -460,7 +454,7 @@ class SomeWindowSafeIntFactor(Factor): input_factor = SomeWindowSafeIntFactor() shape = (10, 6) - data = choice(range(1, 5), size=shape, replace=True) + data = RandomState(5).choice(range(1, 5), size=shape, replace=True) data[eye(*shape, dtype=bool)] = input_factor.missing_value expected_3 = array([[1, 0, 0, 0, 1, 1], @@ -504,7 +498,7 @@ class SomeWindowSafeStringClassifier(Classifier): input_factor = SomeWindowSafeStringClassifier() shape = (10, 6) - data = choice( + data = RandomState(6).choice( array(['a', 'e', 'i', 'o', 'u'], dtype=object_dtype), size=shape, replace=True @@ -1129,3 +1123,212 @@ def test_maximum_repr(self): assert_equal(short_rep, "Maximum:\\l " "groupby: SomeClassifier(...)\\l " "mask: SomeFilter(...)\\l") + + +class IfElseTestCase(BaseUSEquityPipelineTestCase, ZiplineTestCase): + + @classmethod + def init_class_fixtures(cls): + super(IfElseTestCase, cls).init_class_fixtures() + cls.assets = cls.asset_finder.retrieve_all( + cls.asset_finder.equities_sids, + ) + + @parameter_space(seed=[1, 2, 3]) + def test_if_then_else_factor(self, seed): + f = SomeFactor() + g = SomeOtherFactor() + cond = SomeFilter() + + f_data = self.randn_data(seed=seed) + g_data = self.randn_data(seed=seed + 1) + cond_data = self.rand_mask(seed=seed + 2) + + workspace = { + f: f_data, + g: g_data, + cond: cond_data, + } + terms = { + 'result': cond.if_else(f, g), + 'result_1d': cond.if_else(f, g[self.assets[0]]), + } + expected = { + 'result': where(cond_data, f_data, g_data), + 'result_1d': where(cond_data, f_data, g_data[:, [0]]), + } + + self.check_terms( + terms=terms, + expected=expected, + initial_workspace=workspace, + mask=self.build_mask(self.ones_mask()), + ) + + @parameter_space(seed=[1000, 2000, 3000]) + def test_if_then_else_datetime_factor(self, seed): + class SomeOtherDatetimeFactor(Factor): + dtype = datetime64ns_dtype + inputs = () + window_length = 0 + + f = SomeDatetimeFactor() + g = SomeOtherDatetimeFactor() + cond = SomeFilter() + + f_data = self.randn_data(seed=seed) + g_data = self.randn_data(seed=seed + 1) + cond_data = self.rand_mask(seed=seed + 2) + + workspace = { + f: f_data, + g: g_data, + cond: cond_data, + } + terms = { + 'result': cond.if_else(f, g), + 'result_1d': cond.if_else(f, g[self.assets[5]]), + } + expected = { + 'result': where(cond_data, f_data, g_data), + 'result_1d': where(cond_data, f_data, g_data[:, [5]]), + } + + self.check_terms( + terms=terms, + expected=expected, + initial_workspace=workspace, + mask=self.build_mask(self.ones_mask()), + ) + + @parameter_space(seed=[10, 11, 12]) + def test_if_then_else_filter(self, seed): + class Filter1(Filter): + inputs = () + window_length = 0 + + class Filter2(Filter): + inputs = () + window_length = 0 + + f = Filter1() + g = Filter2() + cond = SomeFilter() + + f_data = self.rand_mask(seed=seed) + g_data = self.rand_mask(seed=seed + 1) + cond_data = self.rand_mask(seed=seed + 2) + + workspace = { + f: f_data, + g: g_data, + cond: cond_data, + } + terms = { + 'result': cond.if_else(f, g), + 'result_1d': cond.if_else(f, g[self.assets[1]]), + } + expected = { + 'result': where(cond_data, f_data, g_data), + 'result_1d': where(cond_data, f_data, g_data[:, [1]]), + } + + self.check_terms( + terms=terms, + expected=expected, + initial_workspace=workspace, + mask=self.build_mask(self.ones_mask()), + ) + + @parameter_space(seed=[100, 101, 102]) + def test_if_then_else_string_classifier(self, seed): + class Classifier1(Classifier): + inputs = () + window_length = 0 + dtype = object + + class Classifier2(Classifier): + inputs = () + window_length = 0 + dtype = object + + f = Classifier1() + g = Classifier2() + cond = SomeFilter() + + f_data = self.rand_categoricals( + seed=seed, + categories=['a', 'b', 'c'] + ) + g_data = self.rand_categoricals( + seed=seed + 1, + categories=['d', 'e', 'f'], + ) + cond_data = self.rand_mask(seed=seed + 2) + + workspace = { + f: f_data, + g: g_data, + cond: cond_data, + } + + terms = { + 'result': cond.if_else(f, g), + 'result_1d': cond.if_else(f, g[self.assets[2]]), + } + expected = { + 'result': labelarray_where(cond_data, f_data, g_data), + 'result_1d': labelarray_where(cond_data, f_data, g_data[:, [2]]), + } + + self.check_terms( + terms=terms, + expected=expected, + initial_workspace=workspace, + mask=self.build_mask(self.ones_mask()), + ) + + @parameter_space(seed=[200, 300, 400]) + def test_if_then_else_int_classifier(self, seed): + + class Classifier1(Classifier): + inputs = () + window_length = 0 + dtype = int64_dtype + missing_value = -1 + + class Classifier2(Classifier): + inputs = () + window_length = 0 + dtype = int64_dtype + missing_value = -1 + + f = Classifier1() + g = Classifier2() + cond = SomeFilter() + + f_data = self.rand_ints(seed=seed) + g_data = self.rand_ints(seed=seed + 1) + cond_data = self.rand_mask(seed=seed + 2) + + workspace = { + f: f_data, + g: g_data, + cond: cond_data, + } + + terms = { + 'result': cond.if_else(f, g), + 'result_1d': cond.if_else(f, g[self.assets[4]]), + } + expected = { + 'result': where(cond_data, f_data, g_data), + 'result_1d': where(cond_data, f_data, g_data[:, [4]]), + } + + self.check_terms( + terms=terms, + expected=expected, + initial_workspace=workspace, + mask=self.build_mask(self.ones_mask()), + )