Skip to content

Commit

Permalink
ENH: Add SpecificAssets filter.
Browse files Browse the repository at this point in the history
Adds a filter that matches a set of assets.  Mainly useful for testing
and debugging.
  • Loading branch information
Scott Sanderson committed Oct 7, 2016
1 parent 78de318 commit f874e67
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 4 deletions.
45 changes: 42 additions & 3 deletions tests/pipeline/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@
from numpy.random import randn, seed as random_seed

from zipline.errors import BadPercentileBounds
from zipline.pipeline import Filter, Factor
from zipline.pipeline import Filter, Factor, Pipeline
from zipline.pipeline.classifiers import Classifier
from zipline.pipeline.factors import CustomFactor
from zipline.pipeline.filters import All, Any, AtLeastN
from zipline.testing import parameter_space, permute_rows
from zipline.pipeline.filters import All, Any, AtLeastN, SpecificAssets
from zipline.testing import parameter_space, permute_rows, ZiplineTestCase
from zipline.testing.fixtures import WithSeededRandomPipelineEngine
from zipline.testing.predicates import assert_equal
from zipline.utils.numpy_utils import float64_dtype, int64_dtype
from .base import BasePipelineTestCase, with_default_shape

Expand Down Expand Up @@ -820,3 +822,40 @@ def test_top_and_bottom_with_groupby_and_mask(self, dtype, seed):
},
mask=self.build_mask(permute(rot90(self.eye_mask(shape=shape)))),
)


class SpecificAssetsTestCase(WithSeededRandomPipelineEngine,
                             ZiplineTestCase):
    """
    Tests for the SpecificAssets filter, run over ten equities whose sids
    are 0 through 9.
    """

    ASSET_FINDER_EQUITY_SIDS = tuple(range(10))

    def test_specific_assets(self):
        """
        Build SpecificAssets filters from several slices of the asset list
        and check each against a predicate computed from the sids themselves.
        """
        all_assets = self.asset_finder.retrieve_all(
            self.ASSET_FINDER_EQUITY_SIDS,
        )

        class SidFactor(CustomFactor):
            """A factor that just returns each asset's sid."""
            inputs = ()
            window_length = 1

            def compute(self, today, sids, out):
                out[:] = sids

        # The 'sid' column lets the assertions below recover which asset each
        # result cell belongs to.
        columns = {
            'sid': SidFactor(),
            'evens': SpecificAssets(all_assets[::2]),
            'odds': SpecificAssets(all_assets[1::2]),
            'first_five': SpecificAssets(all_assets[:5]),
            'last_three': SpecificAssets(all_assets[-3:]),
        }
        pipe = Pipeline(columns=columns)

        first_day, last_day = self.trading_days[[-10, -1]]
        results = self.run_pipeline(pipe, first_day, last_day).unstack()

        sids = results.sid.astype(int64_dtype)
        is_odd = (sids % 2).astype(bool)

        assert_equal(results.evens, ~is_odd)
        assert_equal(results.odds, is_odd)
        assert_equal(results.first_five, sids < 5)
        assert_equal(results.last_three, sids >= 7)
2 changes: 2 additions & 0 deletions zipline/pipeline/filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
NumExprFilter,
PercentileFilter,
SingleAsset,
SpecificAssets,
)
from .smoothing import All, Any, AtLeastN

Expand All @@ -24,4 +25,5 @@
'NumExprFilter',
'PercentileFilter',
'SingleAsset',
'SpecificAssets',
]
23 changes: 22 additions & 1 deletion zipline/pipeline/filters/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from itertools import chain
from operator import attrgetter

import numpy as np
from numpy import (
float64,
nan,
nanpercentile,
)
import pandas as pd

from zipline.errors import (
BadPercentileBounds,
Expand All @@ -32,7 +34,7 @@
SingleInputMixin,
)
from zipline.pipeline.term import ComputableTerm, Term
from zipline.utils.input_validation import expect_types
from zipline.utils.input_validation import coerce_types, expect_types
from zipline.utils.memoize import classlazyval
from zipline.utils.numpy_utils import bool_dtype, repeat_first_axis

Expand Down Expand Up @@ -494,3 +496,22 @@ def _compute(self, arrays, dates, assets, mask):
asset=self._asset, start_date=dates[0], end_date=dates[-1],
)
return out


class SpecificAssets(Filter):
    """
    A Filter that computes True for a specific set of predetermined assets.

    Parameters
    ----------
    assets : list, tuple, np.ndarray, or pd.Series
        The assets to match. Every element must expose a ``sid`` attribute;
        only the sids are retained, as a frozenset stored in the ``sids``
        param.
    """
    inputs = ()
    window_length = 0
    params = ('sids',)

    # expect_types is the outermost decorator, so it validates the caller's
    # raw argument before coerce_types gets a chance to convert it.  It must
    # therefore admit every type that coerce_types is prepared to convert
    # (including pd.Series), otherwise that coercion can never be reached.
    @expect_types(assets=(list, tuple, np.ndarray, pd.Series))
    @coerce_types(assets=((list, np.ndarray, pd.Series), list))
    def __new__(cls, assets):
        sids = frozenset(asset.sid for asset in assets)
        return super(SpecificAssets, cls).__new__(cls, sids=sids)

    def _compute(self, arrays, dates, sids, mask):
        """
        Return ``mask`` restricted to the columns whose sid was supplied at
        construction time.

        ``arrays`` and ``dates`` are unused.  ``sids`` must provide an
        ``isin`` method (presumably a pandas Index of sids -- confirm against
        the engine that calls this).
        """
        # One boolean per column: is this column's sid one of ours?
        my_columns = sids.isin(self.params['sids'])
        # Expand the per-column flags to len(mask) rows, then drop anything
        # the incoming mask already excludes.
        return repeat_first_axis(my_columns, len(mask)) & mask

0 comments on commit f874e67

Please sign in to comment.