Merge f4bad5b into fe00452
llllllllll committed Oct 10, 2016
2 parents fe00452 + f4bad5b commit a483b75
Showing 13 changed files with 429 additions and 46 deletions.
63 changes: 63 additions & 0 deletions tests/pipeline/test_alias.py
@@ -0,0 +1,63 @@
import numpy as np

from zipline.testing.predicates import assert_equal
from zipline.pipeline import Classifier, Factor, Filter
from zipline.utils.numpy_utils import float64_dtype, int64_dtype

from .base import BasePipelineTestCase


class WithAlias(object):

def test_alias(self):
f = self.Term()
alias = f.alias('ayy lmao')

f_values = np.random.RandomState(5).randn(5, 5)

self.check_terms(
terms={
'f_alias': alias,
},
expected={
'f_alias': f_values,
},
initial_workspace={f: f_values},
mask=self.build_mask(np.ones((5, 5))),
)

def test_repr(self):
assert_equal(
repr(self.Term().alias('ayy lmao')),
"Aliased%s(Term(...), name='ayy lmao')" % (
self.Term.__base__.__name__,
),
)

def test_short_repr(self):
for name in ('a', 'b'):
assert_equal(
self.Term().alias(name).short_repr(),
name,
)


class TestFactorAlias(WithAlias, BasePipelineTestCase):
class Term(Factor):
dtype = float64_dtype
inputs = ()
window_length = 0


class TestFilterAlias(WithAlias, BasePipelineTestCase):
class Term(Filter):
inputs = ()
window_length = 0


class TestClassifierAlias(WithAlias, BasePipelineTestCase):
class Term(Classifier):
dtype = int64_dtype
inputs = ()
window_length = 0
missing_value = -1
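
These tests exercise the new Term.alias API across factors, filters, and classifiers. For context, a minimal usage sketch (the column choice and names here are illustrative, not part of this commit):

from zipline.pipeline import Pipeline
from zipline.pipeline.data import USEquityPricing

# Aliasing renames a term for display without changing its computation;
# the aliased term's short_repr() is just the given name.
latest_close = USEquityPricing.close.latest.alias('latest_close')
pipe = Pipeline({'close': latest_close})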
64 changes: 61 additions & 3 deletions tests/pipeline/test_engine.py
@@ -33,7 +33,7 @@
from pandas.compat.chainmap import ChainMap
from pandas.util.testing import assert_frame_equal
from six import iteritems, itervalues
from toolz import merge
from toolz import merge, assoc

from zipline.assets.synthetic import make_rotating_equity_info
from zipline.errors import NoFurtherDataError
@@ -66,6 +66,7 @@
from zipline.testing import (
AssetID,
AssetIDPlusDay,
ExplodingObject,
check_arrays,
make_alternating_boolean_array,
make_cascading_boolean_array,
@@ -163,14 +164,14 @@ def compute(self, today, assets, out, *inputs):
out[:] = sum(inputs).sum(axis=0)


class ConstantInputTestCase(WithTradingEnvironment, ZiplineTestCase):
class WithConstantInputs(WithTradingEnvironment):
asset_ids = ASSET_FINDER_EQUITY_SIDS = 1, 2, 3, 4
START_DATE = Timestamp('2014-01-01', tz='utc')
END_DATE = Timestamp('2014-03-01', tz='utc')

@classmethod
def init_class_fixtures(cls):
super(ConstantInputTestCase, cls).init_class_fixtures()
super(WithConstantInputs, cls).init_class_fixtures()
cls.constants = {
# Every day, assume every stock starts at 2, goes down to 1,
# goes up to 4, and finishes at 3.
@@ -192,6 +193,8 @@ def init_class_fixtures(cls):
)
cls.assets = cls.asset_finder.retrieve_all(cls.asset_ids)


class ConstantInputTestCase(WithConstantInputs, ZiplineTestCase):
def test_bad_dates(self):
loader = self.loader
engine = SimplePipelineEngine(
@@ -1315,3 +1318,58 @@ def test_string_classifiers_produce_categoricals(self):
columns=self.asset_finder.retrieve_all(self.asset_finder.sids),
)
assert_frame_equal(result.c.unstack(), expected_final_result)


class PopulateInitialWorkspaceTestCase(WithConstantInputs, ZiplineTestCase):
def test_populate_default_workspace(self):
column = USEquityPricing.low
base_term = column.latest
term = (base_term + 1).alias('term')
composed_term = term + 1
column_value = self.constants[column]
precomputed_value = -column_value

def populate_initial_workspace(initial_workspace,
root_mask_term,
execution_plan,
dates,
assets):
return assoc(
initial_workspace,
term,
full(
(len(dates), len(assets)),
precomputed_value,
dtype=float64,
),
)

def dispatcher(column):
if column is base_term:
# the base_term should never be loaded; its initial refcount
# should be zero.
return ExplodingObject()
return self.loader

engine = SimplePipelineEngine(
dispatcher,
self.dates,
self.asset_finder,
populate_initial_workspace=populate_initial_workspace,
)

results = engine.run_pipeline(
Pipeline({
'term': term,
'composed_term': composed_term,
}),
self.dates[0],
self.dates[-1],
)

self.assertTrue(
(results['term'] == precomputed_value).all(),
)
self.assertTrue(
(results['composed_term'] == (precomputed_value + 1)).all(),
)
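
A note on the helper used in the test above: toolz.assoc returns a new mapping with one extra key set, leaving its input untouched, which is why the hook can build on the workspace it receives without mutating it. A tiny illustration:

from toolz import assoc

base = {'a': 1}
updated = assoc(base, 'b', 2)
# updated == {'a': 1, 'b': 2}; base is still {'a': 1}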
10 changes: 8 additions & 2 deletions tests/pipeline/test_term.py
@@ -195,8 +195,14 @@ def check_output(graph):
self.assertIn(SomeDataSet.bar, resolution_order)
self.assertIn(SomeFactor(), resolution_order)

self.assertEqual(graph.node[SomeDataSet.foo]['extra_rows'], 4)
self.assertEqual(graph.node[SomeDataSet.bar]['extra_rows'], 4)
self.assertEqual(
graph.graph.node[SomeDataSet.foo]['extra_rows'],
4,
)
self.assertEqual(
graph.graph.node[SomeDataSet.bar]['extra_rows'],
4,
)

for foobar in gen_equivalent_factors():
check_output(self.make_execution_plan(to_dict([foobar])))
23 changes: 23 additions & 0 deletions zipline/lib/labelarray.py
@@ -197,6 +197,29 @@ def _from_codes_and_metadata(cls,
ret._missing_value = missing_value
return ret

@classmethod
def from_categorical(cls, categorical, missing_value=None):
"""
Create a LabelArray from a pandas categorical.

Parameters
----------
categorical : pd.Categorical
    The categorical object to convert.
missing_value : bytes, unicode, or None, optional
    The missing value to use for this LabelArray.

Returns
-------
la : LabelArray
    The LabelArray representation of this categorical.
"""
return LabelArray(
categorical,
missing_value,
categorical.categories,
)

@property
def categories(self):
# This is a property because it should be immutable.
21 changes: 21 additions & 0 deletions zipline/pipeline/classifiers/classifier.py
@@ -6,6 +6,7 @@
import re

from numpy import where, isnan, nan, zeros
import pandas as pd

from zipline.lib.labelarray import LabelArray
from zipline.lib.quantiles import quantiles
@@ -23,6 +24,7 @@

from ..filters import ArrayPredicate, NotNullFilter, NullFilter, NumExprFilter
from ..mixins import (
AliasedMixin,
CustomTermMixin,
DownsampledMixin,
LatestMixin,
@@ -303,10 +305,29 @@ def postprocess(self, data):
raise AssertionError("Expected a LabelArray, got %s." % type(data))
return data.as_categorical()

def to_workspace_value(self, result, assets):
"""
Called with the result of a pipeline. This needs to return an object
which can be put into the workspace to continue doing computations.
This is the inverse of :func:`~zipline.pipeline.term.Term.postprocess`.
"""
data = super(Classifier, self).unprocess(result, assets)
if self.dtype == int64_dtype:
return data
assert isinstance(data, pd.Categorical), (
'Expected a Categorical, got %r.' % type(data).__name__
)
return LabelArray.from_categorical(data, self.missing_value)

@classlazyval
def _downsampled_type(self):
return DownsampledMixin.make_downsampled_type(Classifier)

@classlazyval
def _aliased_type(self):
return AliasedMixin.make_aliased_type(Classifier)


class Everything(Classifier):
"""
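
The new to_workspace_value is the inverse of postprocess: postprocess turns a string Classifier's LabelArray into a pd.Categorical for the output frame, and to_workspace_value turns such output back into a LabelArray that can seed a future workspace. A minimal sketch of the roundtrip at the LabelArray level (values are illustrative):

import numpy as np
from zipline.lib.labelarray import LabelArray

# postprocess direction: LabelArray -> pd.Categorical
la = LabelArray(np.array(['a', 'b', 'a'], dtype=object), missing_value=None)
cat = la.as_categorical()

# to_workspace_value direction: pd.Categorical -> LabelArray
roundtrip = LabelArray.from_categorical(cat, missing_value=None)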
65 changes: 42 additions & 23 deletions zipline/pipeline/engine.py
@@ -81,6 +81,14 @@ def run_pipeline(self, pipeline, start_date, end_date):
)


def _default_populate_initial_workspace(initial_workspace,
root_mask_term,
execution_plan,
dates,
assets):
return initial_workspace


class SimplePipelineEngine(object):
"""
PipelineEngine class that computes each term independently.
@@ -96,24 +104,39 @@ class SimplePipelineEngine(object):
asset_finder : zipline.assets.AssetFinder
An AssetFinder instance. We depend on the AssetFinder to determine
which assets are in the top-level universe at any point in time.
populate_initial_workspace : callable, optional
A function which will be used to populate the initial workspace when
computing a pipeline. This function will be passed the
initial_workspace, the root mask term, the execution_plan, the dates
being computed for, and the assets requested, and should return a new
dictionary which will be used as the initial_workspace.
"""
__slots__ = (
'_get_loader',
'_calendar',
'_finder',
'_root_mask_term',
'_root_mask_dates_term',
'_populate_initial_workspace',
'__weakref__',
)

def __init__(self, get_loader, calendar, asset_finder):
def __init__(self,
get_loader,
calendar,
asset_finder,
populate_initial_workspace=None):
self._get_loader = get_loader
self._calendar = calendar
self._finder = asset_finder

self._root_mask_term = AssetExists()
self._root_mask_dates_term = InputDates()

self._populate_initial_workspace = (
populate_initial_workspace or _default_populate_initial_workspace
)

def run_pipeline(self, pipeline, start_date, end_date):
"""
Compute a pipeline.
@@ -179,14 +202,22 @@ def run_pipeline(self, pipeline, start_date, end_date):
root_mask = self._compute_root_mask(start_date, end_date, extra_rows)
dates, assets, root_mask_values = explode(root_mask)

initial_workspace = self._populate_initial_workspace(
{
self._root_mask_term: root_mask_values,
self._root_mask_dates_term: as_column(dates.values)
},
self._root_mask_term,
graph,
dates,
assets,
)

results = self.compute_chunk(
graph,
dates,
assets,
initial_workspace={
self._root_mask_term: root_mask_values,
self._root_mask_dates_term: as_column(dates.values)
},
initial_workspace,
)

return self._to_narrow(
@@ -255,21 +286,6 @@ def _compute_root_mask(self, start_date, end_date, extra_rows):
assert shape[0] * shape[1] != 0, 'root mask cannot be empty'
return ret

def _mask_and_dates_for_term(self, term, workspace, graph, all_dates):
"""
Load mask and mask row labels for term.
"""
mask = term.mask
mask_offset = graph.extra_rows[mask] - graph.extra_rows[term]

# This offset is computed against _root_mask_term because that is what
# determines the shape of the top-level dates array.
dates_offset = (
graph.extra_rows[self._root_mask_term] - graph.extra_rows[term]
)

return workspace[mask][mask_offset:], all_dates[dates_offset:]

@staticmethod
def _inputs_for_term(term, workspace, graph):
"""
@@ -346,7 +362,7 @@ def compute_chunk(self, graph, dates, assets, initial_workspace):

refcounts = graph.initial_refcounts(workspace)

for term in graph.ordered():
for term in graph.execution_order(refcounts):
# `term` may have been supplied in `initial_workspace`, and in the
# future we may pre-compute loadable terms coming from the same
# dataset. In either case, we will already have an entry for this
@@ -356,8 +372,11 @@

# Asset labels are always the same, but date labels vary by how
# many extra rows are needed.
mask, mask_dates = self._mask_and_dates_for_term(
term, workspace, graph, dates
mask, mask_dates = graph.mask_and_dates_for_term(
term,
self._root_mask_term,
workspace,
dates,
)

if isinstance(term, LoadableTerm):
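
Putting the new hook together: a minimal sketch of wiring a custom populate_initial_workspace into an engine. Here get_loader, calendar, asset_finder, and my_term are stand-ins for whatever the surrounding application provides, not names from this commit:

from numpy import float64, full

def populate_initial_workspace(initial_workspace,
                               root_mask_term,
                               execution_plan,
                               dates,
                               assets):
    # Return a new mapping; here we hand the engine a precomputed value
    # for `my_term` (a stand-in) instead of letting the engine compute it.
    workspace = dict(initial_workspace)
    workspace[my_term] = full((len(dates), len(assets)), 0.0, dtype=float64)
    return workspace

engine = SimplePipelineEngine(
    get_loader,
    calendar,
    asset_finder,
    populate_initial_workspace=populate_initial_workspace,
)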
