Skip to content

Commit

Permalink
TST: Refactors more tests to use WithTradingSchedule
Browse files Browse the repository at this point in the history
  • Loading branch information
jfkirk committed May 6, 2016
1 parent 6bcc749 commit 147fd10
Show file tree
Hide file tree
Showing 22 changed files with 225 additions and 251 deletions.
29 changes: 16 additions & 13 deletions tests/data/test_minute_bars.py
Expand Up @@ -15,8 +15,6 @@
from datetime import timedelta
import os

from unittest import TestCase

from numpy import (
arange,
array,
Expand Down Expand Up @@ -44,7 +42,9 @@
US_EQUITIES_MINUTES_PER_DAY,
BcolzMinuteWriterColumnMismatch
)
from zipline.utils.calendars import get_calendar, default_nyse_schedule
from zipline.utils.calendars import get_calendar

from zipline.testing.fixtures import WithTradingSchedule, ZiplineTestCase

# Calendar is set to cover several half days, to check a case where half
# days would be read out of order in cases of windows which spanned over
Expand All @@ -53,10 +53,11 @@
TEST_CALENDAR_STOP = Timestamp('2015-12-31', tz='UTC')


class BcolzMinuteBarTestCase(TestCase):
class BcolzMinuteBarTestCase(WithTradingSchedule, ZiplineTestCase):

@classmethod
def setUpClass(cls):
def init_class_fixtures(cls):
super(BcolzMinuteBarTestCase, cls).init_class_fixtures()
trading_days = get_calendar('NYSE').trading_days(
TEST_CALENDAR_START, TEST_CALENDAR_STOP
)
Expand All @@ -65,10 +66,15 @@ def setUpClass(cls):
cls.test_calendar_start = cls.market_opens.index[0]
cls.test_calendar_stop = cls.market_opens.index[-1]

def setUp(self):
def dir_cleanup(self):
self.dir_.cleanup()

def init_instance_fixtures(self):
super(BcolzMinuteBarTestCase, self).init_instance_fixtures()

self.dir_ = TempDirectory()
self.dir_.create()
self.add_instance_callback(callback=self.dir_cleanup)
self.dest = self.dir_.getpath('minute_bars')
os.makedirs(self.dest)
self.writer = BcolzMinuteBarWriter(
Expand All @@ -80,9 +86,6 @@ def setUp(self):
)
self.reader = BcolzMinuteBarReader(self.dest)

def tearDown(self):
self.dir_.cleanup()

def test_write_one_ohlcv(self):
minute = self.market_opens[self.test_calendar_start]
sid = 1
Expand Down Expand Up @@ -699,9 +702,9 @@ def test_unadjusted_minutes_early_close(self):
data = {sids[0]: data_1, sids[1]: data_2}

start_minute_loc = \
default_nyse_schedule.all_execution_minutes.get_loc(minutes[0])
self.trading_schedule.all_execution_minutes.get_loc(minutes[0])
minute_locs = [
default_nyse_schedule.all_execution_minutes.get_loc(minute)
self.trading_schedule.all_execution_minutes.get_loc(minute)
- start_minute_loc
for minute in minutes
]
Expand All @@ -723,7 +726,7 @@ def test_adjust_non_trading_minutes(self):
'close': arange(1, 781),
'volume': arange(1, 781)
}
dts = array(default_nyse_schedule.execution_minutes_for_days_in_range(
dts = array(self.trading_schedule.execution_minutes_for_days_in_range(
start_day, end_day
))
self.writer.write_cols(sid, dts, cols)
Expand Down Expand Up @@ -767,7 +770,7 @@ def test_adjust_non_trading_minutes_half_days(self):
'close': arange(1, 601),
'volume': arange(1, 601)
}
dts = array(default_nyse_schedule.execution_minutes_for_days_in_range(
dts = array(self.trading_schedule.execution_minutes_for_days_in_range(
start_day, end_day
))
self.writer.write_cols(sid, dts, cols)
Expand Down
15 changes: 6 additions & 9 deletions tests/pipeline/base.py
Expand Up @@ -2,7 +2,6 @@
Base class for Pipeline API unittests.
"""
from functools import wraps
from unittest import TestCase

import numpy as np
from numpy import arange, prod
Expand All @@ -18,10 +17,10 @@
ExplodingObject,
tmp_asset_finder,
)
from zipline.testing.fixtures import ZiplineTestCase, WithTradingSchedule

from zipline.utils.functional import dzip_exact
from zipline.utils.pandas_utils import explode
from zipline.utils.calendars import default_nyse_schedule


def with_defaults(**default_funcs):
Expand Down Expand Up @@ -51,12 +50,14 @@ def method(self, *args, **kwargs):
with_default_shape = with_defaults(shape=lambda self: self.default_shape)


class BasePipelineTestCase(TestCase):
class BasePipelineTestCase(WithTradingSchedule, ZiplineTestCase):

@classmethod
def setUpClass(cls):
def init_class_fixtures(cls):
super(BasePipelineTestCase, cls).init_class_fixtures()

cls.__calendar = date_range('2014', '2015',
freq=default_nyse_schedule.day)
freq=cls.trading_schedule.day)
cls.__assets = assets = Int64Index(arange(1, 20))
cls.__tmp_finder_ctx = tmp_asset_finder(
equities=make_simple_equity_info(
Expand All @@ -71,10 +72,6 @@ def setUpClass(cls):
include_start_date=False,
)

@classmethod
def tearDownClass(cls):
cls.__tmp_finder_ctx.__exit__()

@property
def default_shape(self):
"""Default shape for methods that build test data."""
Expand Down
9 changes: 4 additions & 5 deletions tests/pipeline/test_engine.py
Expand Up @@ -72,7 +72,6 @@
ZiplineTestCase,
)
from zipline.utils.memoize import lazyval
from zipline.utils.calendars import default_nyse_schedule


class RollingSumDifference(CustomFactor):
Expand Down Expand Up @@ -814,7 +813,7 @@ def init_class_fixtures(cls):
cls.dates = date_range(
cls.start,
cls.end,
freq=default_nyse_schedule.day,
freq=cls.trading_schedule.day,
tz='UTC',
)
cls.assets = cls.asset_finder.retrieve_all(cls.asset_ids)
Expand Down Expand Up @@ -973,7 +972,7 @@ def write_nans(self, df):
def test_SMA(self):
engine = SimplePipelineEngine(
lambda column: self.pipeline_loader,
default_nyse_schedule.all_execution_days,
self.trading_schedule.all_execution_days,
self.asset_finder,
)
window_length = 5
Expand Down Expand Up @@ -1027,7 +1026,7 @@ def test_drawdown(self):
# valuable.
engine = SimplePipelineEngine(
lambda column: self.pipeline_loader,
default_nyse_schedule.all_execution_days,
self.trading_schedule.all_execution_days,
self.asset_finder,
)
window_length = 5
Expand Down Expand Up @@ -1071,7 +1070,7 @@ class ParameterizedFactorTestCase(WithTradingEnvironment, ZiplineTestCase):
@classmethod
def init_class_fixtures(cls):
super(ParameterizedFactorTestCase, cls).init_class_fixtures()
day = default_nyse_schedule.day
day = cls.trading_schedule.day

cls.dates = dates = date_range(
'2015-02-01',
Expand Down
4 changes: 2 additions & 2 deletions tests/pipeline/test_factor.py
Expand Up @@ -87,8 +87,8 @@ class Mask(Filter):

class FactorTestCase(BasePipelineTestCase):

def setUp(self):
super(FactorTestCase, self).setUp()
def init_instance_fixtures(self):
super(FactorTestCase, self).init_instance_fixtures()
self.f = F()

def test_bad_input(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/pipeline/test_filter.py
Expand Up @@ -75,8 +75,8 @@ class Mask(Filter):

class FilterTestCase(BasePipelineTestCase):

def setUp(self):
super(FilterTestCase, self).setUp()
def init_instance_fixtures(self):
super(FilterTestCase, self).init_instance_fixtures()
self.f = SomeFactor()
self.g = SomeOtherFactor()

Expand Down
23 changes: 8 additions & 15 deletions tests/risk/test_risk_cumulative.py
Expand Up @@ -13,31 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import datetime
import numpy as np
import pytz
import zipline.finance.risk as risk
from zipline.utils import factory

from zipline.finance.trading import SimulationParameters, TradingEnvironment
from zipline.utils.calendars import default_nyse_schedule
from zipline.testing.fixtures import WithTradingEnvironment, ZiplineTestCase

from zipline.finance.trading import SimulationParameters
from . import answer_key
ANSWER_KEY = answer_key.ANSWER_KEY


class TestRisk(unittest.TestCase):

@classmethod
def setUpClass(cls):
cls.env = TradingEnvironment()
class TestRisk(WithTradingEnvironment, ZiplineTestCase):

@classmethod
def tearDownClass(cls):
del cls.env
def init_instance_fixtures(self):
super(TestRisk, self).init_instance_fixtures()

def setUp(self):
start_date = datetime.datetime(
year=2006,
month=1,
Expand All @@ -51,7 +44,7 @@ def setUp(self):
self.sim_params = SimulationParameters(
period_start=start_date,
period_end=end_date,
trading_schedule=default_nyse_schedule,
trading_schedule=self.trading_schedule,
)

self.algo_returns_06 = factory.create_returns_from_list(
Expand All @@ -62,7 +55,7 @@ def setUp(self):
self.cumulative_metrics_06 = risk.RiskMetricsCumulative(
self.sim_params,
treasury_curves=self.env.treasury_curves,
trading_schedule=default_nyse_schedule,
trading_schedule=self.trading_schedule,
)

for dt, returns in answer_key.RETURNS_DATA.iterrows():
Expand Down

0 comments on commit 147fd10

Please sign in to comment.