Skip to content

Commit

Permalink
Merge b1894e0 into aae9f84
Browse files Browse the repository at this point in the history
  • Loading branch information
Scott Sanderson committed Aug 17, 2016
2 parents aae9f84 + b1894e0 commit 8a581c2
Show file tree
Hide file tree
Showing 31 changed files with 2,045 additions and 485 deletions.
2 changes: 1 addition & 1 deletion appveyor.yml
Expand Up @@ -101,7 +101,7 @@ install:
- pip freeze | sort

test_script:
- nosetests
- nosetests -e zipline.utils.numpy_utils
- flake8 zipline tests

branches:
Expand Down
72 changes: 41 additions & 31 deletions tests/pipeline/base.py
@@ -1,24 +1,23 @@
"""
Base class for Pipeline API unittests.
Base class for Pipeline API unit tests.
"""
from functools import wraps

import numpy as np
from numpy import arange, prod
from pandas import date_range, Int64Index, DataFrame
from pandas import DataFrame, Timestamp
from six import iteritems

from zipline.assets.synthetic import make_simple_equity_info
from zipline.pipeline.engine import SimplePipelineEngine
from zipline.pipeline import TermGraph
from zipline.pipeline.term import AssetExists
from zipline.pipeline import ExecutionPlan
from zipline.pipeline.term import AssetExists, InputDates
from zipline.testing import (
check_arrays,
ExplodingObject,
tmp_asset_finder,
)
from zipline.testing.fixtures import (
WithTradingCalendars,
WithAssetFinder,
WithTradingSessions,
ZiplineTestCase,
)

Expand Down Expand Up @@ -53,32 +52,26 @@ def method(self, *args, **kwargs):
with_default_shape = with_defaults(shape=lambda self: self.default_shape)


class BasePipelineTestCase(WithTradingCalendars, ZiplineTestCase):
class BasePipelineTestCase(WithTradingSessions,
WithAssetFinder,
ZiplineTestCase):
START_DATE = Timestamp('2014', tz='UTC')
END_DATE = Timestamp('2014-12-31', tz='UTC')
ASSET_FINDER_EQUITY_SIDS = list(range(20))

@classmethod
def init_class_fixtures(cls):
super(BasePipelineTestCase, cls).init_class_fixtures()

cls.__calendar = date_range('2014', '2015',
freq=cls.trading_calendar.day)
cls.__assets = assets = Int64Index(arange(1, 20))
cls.__tmp_finder_ctx = tmp_asset_finder(
equities=make_simple_equity_info(
assets,
cls.__calendar[0],
cls.__calendar[-1],
)
)
cls.__finder = cls.__tmp_finder_ctx.__enter__()
cls.__mask = cls.__finder.lifetimes(
cls.__calendar[-30:],
cls.default_asset_exists_mask = cls.asset_finder.lifetimes(
cls.nyse_sessions[-30:],
include_start_date=False,
)

@property
def default_shape(self):
"""Default shape for methods that build test data."""
return self.__mask.shape
return self.default_asset_exists_mask.shape

def run_graph(self, graph, initial_workspace, mask=None):
"""
Expand All @@ -103,30 +96,47 @@ def run_graph(self, graph, initial_workspace, mask=None):
"""
engine = SimplePipelineEngine(
lambda column: ExplodingObject(),
self.__calendar,
self.__finder,
self.nyse_sessions,
self.asset_finder,
)
if mask is None:
mask = self.__mask
mask = self.default_asset_exists_mask

dates, assets, mask_values = explode(mask)

initial_workspace.setdefault(AssetExists(), mask_values)
initial_workspace.setdefault(InputDates(), dates)

return engine.compute_chunk(
graph,
dates,
assets,
initial_workspace,
)

def check_terms(self, terms, expected, initial_workspace, mask):
def check_terms(self,
terms,
expected,
initial_workspace,
mask,
check=check_arrays):
"""
Compile the given terms into a TermGraph, compute it with
initial_workspace, and compare the results with ``expected``.
"""
graph = TermGraph(terms)
start_date, end_date = mask.index[[0, -1]]
graph = ExecutionPlan(
terms,
all_dates=self.nyse_sessions,
start_date=start_date,
end_date=end_date,
)

results = self.run_graph(graph, initial_workspace, mask)
for key, (res, exp) in dzip_exact(results, expected).items():
check_arrays(res, exp)
check(res, exp)

return results

def build_mask(self, array):
"""
Expand All @@ -138,13 +148,13 @@ def build_mask(self, array):
array,
# Use the **last** N dates rather than the first N so that we have
# space for lookbacks.
index=self.__calendar[-ndates:],
columns=self.__assets[:nassets],
index=self.nyse_sessions[-ndates:],
columns=self.ASSET_FINDER_EQUITY_SIDS[:nassets],
dtype=bool,
)

@with_default_shape
def arange_data(self, shape, dtype=float):
def arange_data(self, shape, dtype=np.float64):
"""
Build a block of testing data from numpy.arange.
"""
Expand Down

0 comments on commit 8a581c2

Please sign in to comment.