Skip to content

Commit

Permalink
Merge 7bcdbbf into 2fd0f2a
Browse files Browse the repository at this point in the history
  • Loading branch information
dmichalowicz committed Apr 4, 2016
2 parents 2fd0f2a + 7bcdbbf commit 850109a
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 8 deletions.
65 changes: 65 additions & 0 deletions tests/pipeline/test_engine.py
Expand Up @@ -62,6 +62,7 @@
MaxDrawdown,
SimpleMovingAverage,
)
from zipline.pipeline.term import NotSpecified
from zipline.testing import (
make_rotating_equity_info,
make_simple_equity_info,
Expand Down Expand Up @@ -95,6 +96,14 @@ def compute(self, today, assets, out, close):
out[:] = assets


class OpenPrice(CustomFactor):
window_length = 1
inputs = [USEquityPricing.open]

def compute(self, today, assets, out, open):
out[:] = open


def assert_multi_index_is_product(testcase, index, *levels):
"""Assert that a MultiIndex contains the product of `*levels`."""
testcase.assertIsInstance(
Expand Down Expand Up @@ -354,6 +363,62 @@ def test_numeric_factor(self):
DataFrame(expected_avg, index=dates, columns=self.assets),
)

def test_masked_factor(self):
"""
Test that a Custom Factor computes the correct values when passed a
mask. The mask/filter should be applied prior to computing any values,
as opposed to computing the factor across the entire universe of
assets. Any assets that are filtered out should be filled with missing
values.
"""
loader = self.loader
dates = self.dates[5:10]
assets = self.assets
asset_ids = self.asset_ids
num_dates = len(dates)
constants = self.constants
open = USEquityPricing.open
engine = SimplePipelineEngine(
lambda column: loader, self.dates, self.asset_finder,
)

# These are the expected values for the OpenPrice factor. If we pass
# OpenPrice a mask, any assets that are filtered out should have all
# NaN values. Otherwise, we expect its computed values to be the
# asset's open price.
values = array([constants[open]] * num_dates, dtype=float)
missing_values = array([nan] * num_dates)

for asset_id in asset_ids:
mask = AssetID() <= asset_id
factor1 = OpenPrice(mask=mask)

# Test running our pipeline both with and without a second factor.
# We do not explicitly test the resulting values of the second
# factor; we just want to implicitly ensure that the addition of
# another factor to the pipeline term graph does not cause any
# unexpected exceptions when calling `run_pipeline`.
for factor2 in (None,
RollingSumDifference(mask=NotSpecified),
RollingSumDifference(mask=mask)):
if factor2 is None:
columns = {'factor1': factor1}
else:
columns = {'factor1': factor1, 'factor2': factor2}
pipeline = Pipeline(columns=columns)
results = engine.run_pipeline(pipeline, dates[0], dates[-1])
factor1_results = results['factor1'].unstack()

expected = {
asset: values if asset.sid <= asset_id else missing_values
for asset in assets
}

assert_frame_equal(
factor1_results,
DataFrame(expected, index=dates, columns=assets),
)

def test_rolling_and_nonrolling(self):
open_ = USEquityPricing.open
close = USEquityPricing.close
Expand Down
7 changes: 5 additions & 2 deletions zipline/pipeline/engine.py
Expand Up @@ -242,8 +242,11 @@ def _mask_and_dates_for_term(self, term, workspace, graph, dates):
Load mask and mask row labels for term.
"""
mask = term.mask
offset = graph.extra_rows[mask] - graph.extra_rows[term]
return workspace[mask][offset:], dates[offset:]
mask_offset = graph.extra_rows[mask] - graph.extra_rows[term]
dates_offset = (
graph.extra_rows[self._root_mask_term] - graph.extra_rows[term]
)
return workspace[mask][mask_offset:], dates[dates_offset:]

@staticmethod
def _inputs_for_term(term, workspace, graph):
Expand Down
3 changes: 2 additions & 1 deletion zipline/pipeline/graph.py
Expand Up @@ -104,7 +104,8 @@ def offset(self):
zipline.pipeline.engine.SimplePipelineEngine._inputs_for_term
zipline.pipeline.engine.SimplePipelineEngine._mask_and_dates_for_term
"""
return {(term, dep): self.extra_rows[dep] - term.extra_input_rows
return {(term, dep): self.extra_rows[dep] - max(term.extra_input_rows,
self.extra_rows[term])
for term in self
for dep in term.dependencies}

Expand Down
15 changes: 10 additions & 5 deletions zipline/pipeline/mixins.py
Expand Up @@ -70,6 +70,7 @@ class CustomTermMixin(object):
def __new__(cls,
inputs=NotSpecified,
window_length=NotSpecified,
mask=NotSpecified,
dtype=NotSpecified,
missing_value=NotSpecified,
**kwargs):
Expand All @@ -88,6 +89,7 @@ def __new__(cls,
cls,
inputs=inputs,
window_length=window_length,
mask=mask,
dtype=dtype,
missing_value=missing_value,
**kwargs
Expand All @@ -104,7 +106,6 @@ def _compute(self, windows, dates, assets, mask):
Call the user's `compute` function on each window with a pre-built
output array.
"""
# TODO: Make mask available to user's `compute`.
compute = self.compute
missing_value = self.missing_value
params = self.params
Expand All @@ -113,14 +114,18 @@ def _compute(self, windows, dates, assets, mask):
# TODO: Consider pre-filtering columns that are all-nan at each
# time-step?
for idx, date in enumerate(dates):
col_mask = mask[idx]
masked_out = out[idx][col_mask]
masked_assets = assets[col_mask]

compute(
date,
assets,
out[idx],
*(next(w) for w in windows),
masked_assets,
masked_out,
*(next(w)[:, col_mask] for w in windows),
**params
)
out[~mask] = missing_value
out[idx][col_mask] = masked_out
return out

def short_repr(self):
Expand Down

0 comments on commit 850109a

Please sign in to comment.