diff --git a/tests/pipeline/test_engine.py b/tests/pipeline/test_engine.py index 339f2f9b34..b4a9e78fc6 100644 --- a/tests/pipeline/test_engine.py +++ b/tests/pipeline/test_engine.py @@ -62,6 +62,7 @@ MaxDrawdown, SimpleMovingAverage, ) +from zipline.pipeline.term import NotSpecified from zipline.testing import ( make_rotating_equity_info, make_simple_equity_info, @@ -95,6 +96,14 @@ def compute(self, today, assets, out, close): out[:] = assets +class OpenPrice(CustomFactor): + window_length = 1 + inputs = [USEquityPricing.open] + + def compute(self, today, assets, out, open): + out[:] = open + + def assert_multi_index_is_product(testcase, index, *levels): """Assert that a MultiIndex contains the product of `*levels`.""" testcase.assertIsInstance( @@ -354,6 +363,62 @@ def test_numeric_factor(self): DataFrame(expected_avg, index=dates, columns=self.assets), ) + def test_masked_factor(self): + """ + Test that a Custom Factor computes the correct values when passed a + mask. The mask/filter should be applied prior to computing any values, + as opposed to computing the factor across the entire universe of + assets. Any assets that are filtered out should be filled with missing + values. + """ + loader = self.loader + dates = self.dates[5:10] + assets = self.assets + asset_ids = self.asset_ids + num_dates = len(dates) + constants = self.constants + open = USEquityPricing.open + engine = SimplePipelineEngine( + lambda column: loader, self.dates, self.asset_finder, + ) + + # These are the expected values for the OpenPrice factor. If we pass + # OpenPrice a mask, any assets that are filtered out should have all + # NaN values. Otherwise, we expect its computed values to be the + # asset's open price. + values = array([constants[open]] * num_dates, dtype=float) + missing_values = array([nan] * num_dates) + + for asset_id in asset_ids: + mask = AssetID() <= asset_id + factor1 = OpenPrice(mask=mask) + + # Test running our pipeline both with and without a second factor. + # We do not explicitly test the resulting values of the second + # factor; we just want to implicitly ensure that the addition of + # another factor to the pipeline term graph does not cause any + # unexpected exceptions when calling `run_pipeline`. + for factor2 in (None, + RollingSumDifference(mask=NotSpecified), + RollingSumDifference(mask=mask)): + if factor2 is None: + columns = {'factor1': factor1} + else: + columns = {'factor1': factor1, 'factor2': factor2} + pipeline = Pipeline(columns=columns) + results = engine.run_pipeline(pipeline, dates[0], dates[-1]) + factor1_results = results['factor1'].unstack() + + expected = { + asset: values if asset.sid <= asset_id else missing_values + for asset in assets + } + + assert_frame_equal( + factor1_results, + DataFrame(expected, index=dates, columns=assets), + ) + def test_rolling_and_nonrolling(self): open_ = USEquityPricing.open close = USEquityPricing.close diff --git a/zipline/pipeline/engine.py b/zipline/pipeline/engine.py index c871530db8..39a4776c16 100644 --- a/zipline/pipeline/engine.py +++ b/zipline/pipeline/engine.py @@ -242,8 +242,11 @@ def _mask_and_dates_for_term(self, term, workspace, graph, dates): Load mask and mask row labels for term. """ mask = term.mask - offset = graph.extra_rows[mask] - graph.extra_rows[term] - return workspace[mask][offset:], dates[offset:] + mask_offset = graph.extra_rows[mask] - graph.extra_rows[term] + dates_offset = ( + graph.extra_rows[self._root_mask_term] - graph.extra_rows[term] + ) + return workspace[mask][mask_offset:], dates[dates_offset:] @staticmethod def _inputs_for_term(term, workspace, graph): diff --git a/zipline/pipeline/graph.py b/zipline/pipeline/graph.py index 6bcbe8c045..3cf4a62e5c 100644 --- a/zipline/pipeline/graph.py +++ b/zipline/pipeline/graph.py @@ -104,7 +104,8 @@ def offset(self): zipline.pipeline.engine.SimplePipelineEngine._inputs_for_term zipline.pipeline.engine.SimplePipelineEngine._mask_and_dates_for_term """ - return {(term, dep): self.extra_rows[dep] - term.extra_input_rows + return {(term, dep): self.extra_rows[dep] - max(term.extra_input_rows, + self.extra_rows[term]) for term in self for dep in term.dependencies} diff --git a/zipline/pipeline/mixins.py b/zipline/pipeline/mixins.py index 4e8a27b5e3..da9297b88d 100644 --- a/zipline/pipeline/mixins.py +++ b/zipline/pipeline/mixins.py @@ -70,6 +70,7 @@ class CustomTermMixin(object): def __new__(cls, inputs=NotSpecified, window_length=NotSpecified, + mask=NotSpecified, dtype=NotSpecified, missing_value=NotSpecified, **kwargs): @@ -88,6 +89,7 @@ def __new__(cls, cls, inputs=inputs, window_length=window_length, + mask=mask, dtype=dtype, missing_value=missing_value, **kwargs @@ -104,7 +106,6 @@ def _compute(self, windows, dates, assets, mask): Call the user's `compute` function on each window with a pre-built output array. """ - # TODO: Make mask available to user's `compute`. compute = self.compute missing_value = self.missing_value params = self.params @@ -113,14 +114,18 @@ def _compute(self, windows, dates, assets, mask): # TODO: Consider pre-filtering columns that are all-nan at each # time-step? for idx, date in enumerate(dates): + col_mask = mask[idx] + masked_out = out[idx][col_mask] + masked_assets = assets[col_mask] + compute( date, - assets, - out[idx], - *(next(w) for w in windows), + masked_assets, + masked_out, + *(next(w)[:, col_mask] for w in windows), **params ) - out[~mask] = missing_value + out[idx][col_mask] = masked_out return out def short_repr(self):