Skip to content

Commit

Permalink
TST: Updates the tests for more coverage.
Browse files Browse the repository at this point in the history
Speeds up the NthTradingDayOfMonth and NDaysBeforeLastTradingDayOfMonth
by caching the fully computed day.
  • Loading branch information
llllllllll committed Sep 26, 2014
1 parent 06480e8 commit 5e9c3d3
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 93 deletions.
135 changes: 80 additions & 55 deletions tests/utils/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@
# limitations under the License.
import datetime
import random
from itertools import imap, islice, dropwhile, product
from itertools import islice, dropwhile, product
import operator
from six.moves import range
from six.moves import range, map
from nose_parameterized import parameterized
from unittest import TestCase

from mock import patch
import pandas as pd
import numpy as np

from zipline.finance.trading import TradingEnvironment
from zipline.utils import events as events_module
from zipline.utils.events import (
EventRule,
StatelessRule,
Always,
Never,
Expand Down Expand Up @@ -175,6 +176,16 @@ def should_trigger(self, dt):
self.assertEqual(CountingRule.count, 5)


class TestEventRule(TestCase):
def test_is_abstract(self):
with self.assertRaises(TypeError):
EventRule()

def test_not_implemented(self):
with self.assertRaises(NotImplementedError):
super(Always, Always()).should_trigger('a')


class RuleTestCase(TestCase):
@classmethod
def setUpClass(cls):
Expand All @@ -183,14 +194,15 @@ def setUpClass(cls):

def setUp(self):
# Select a random sample of 5 trading days
index = random.sample(range(len(self.env.trading_days)), 5)
test_dts = [self.env.trading_days[i] for i in index]
self.open_close_times = [
self.env.get_open_and_close(dt) for dt in test_dts
]
self.trading_days = (
self.env.market_minutes_for_day(dt) for dt in test_dts
)
self.trading_days = self._get_random_days(5)

def _get_random_days(self, n):
"""
Returns a random selection n trading days.
"""
index = random.sample(range(len(self.env.trading_days)), n)
test_dts = (self.env.trading_days[i] for i in index)
return (self.env.market_minutes_for_day(dt) for dt in test_dts)

@property
def minutes(self):
Expand All @@ -210,7 +222,6 @@ def test_completeness(self):
if isinstance(v, type)
and issubclass(v, self.class_)
and v is not self.class_
and v not in self.ignored
}
ds = {
k[5:] for k in dir(self)
Expand All @@ -229,14 +240,10 @@ def setUpClass(cls):
super(TestStatelessRules, cls).setUpClass()

cls.class_ = StatelessRule
cls.ignored = {ComposedRule}

cls.sept_minutes = imap(
lambda m: m.to_datetime(),
cls.env.minutes_for_days_in_range(
datetime.date(year=2014, month=9, day=1),
datetime.date(year=2014, month=9, day=30),
)
cls.sept_days = cls.env.days_in_range(
np.datetime64(datetime.date(year=2014, month=9, day=1)),
np.datetime64(datetime.date(year=2014, month=9, day=30)),
)

cls.sept_week = cls.env.minutes_for_days_in_range(
Expand All @@ -246,18 +253,21 @@ def setUpClass(cls):

def test_Always(self):
should_trigger = Always().should_trigger
self.assertTrue(all(imap(should_trigger, self.minutes)))
self.assertTrue(all(map(should_trigger, self.minutes)))

def test_Never(self):
should_trigger = Never().should_trigger
self.assertFalse(any(imap(should_trigger, self.minutes)))
self.assertFalse(any(map(should_trigger, self.minutes)))

def test_InvertedRule(self):
rule = Always()
should_trigger = rule.should_trigger
should_not_trigger = InvertedRule(rule).should_trigger
f = lambda m: should_trigger(m) != should_not_trigger(m)
self.assertTrue(all(imap(f, self.minutes)))
self.assertTrue(all(map(f, self.minutes)))

# Test the syntax.
self.assertIsInstance(~Always(), InvertedRule)

def test_AfterOpen(self):
should_trigger = AfterOpen(minutes=5, hours=1).should_trigger
Expand All @@ -278,8 +288,8 @@ def test_BeforeClose(self):
def test_OnDate(self):
first_day = next(self.trading_days)
should_trigger = OnDate(first_day[0].date()).should_trigger
self.assertTrue(all(imap(should_trigger, first_day)))
self.assertFalse(any(imap(should_trigger, self.minutes)))
self.assertTrue(all(map(should_trigger, first_day)))
self.assertFalse(any(map(should_trigger, self.minutes)))

def _test_before_after_date(self, class_, op):
minutes = list(self.minutes)
Expand All @@ -304,7 +314,7 @@ def test_AtTime(self):
hit = []
f = lambda m: should_trigger(m) == (m.time() == time) \
and (hit.append(None) or True)
self.assertTrue(all(imap(f, self.minutes)))
self.assertTrue(all(map(f, self.minutes)))
# Make sure we actually had a bar that is the time we wanted.
self.assertTrue(hit)

Expand Down Expand Up @@ -367,35 +377,46 @@ def test_NDaysBeforeLastTradingDayOfWeek(self, n):
@parameterized.expand(param_range(30))
def test_NthTradingDayOfMonth(self, n):
should_trigger = NthTradingDayOfMonth(n).should_trigger
n_tdays = 0
prev_day = None
n_tdays = 0
for m in self.sept_minutes:
if not prev_day:
prev_day = m.date()
if prev_day < m.date():
n_tdays += 1

if should_trigger(m):
self.assertEqual(n_tdays, n)
else:
self.assertNotEqual(n_tdays, n)
prev_day = m.date()
for n_tdays, d in enumerate(self.sept_days):
for m in self.env.market_minutes_for_day(d):
if should_trigger(m):
self.assertEqual(n_tdays, n)
else:
self.assertNotEqual(n_tdays, n)

@parameterized.expand(param_range(30))
def test_NDaysBeforeLastTradingDayOfMonth(self, n):
should_trigger = NDaysBeforeLastTradingDayOfMonth(n).should_trigger
for m in self.sept_minutes:
if should_trigger(m):
n_tdays = 0
date = m.date()
next_date = self.env.next_trading_day(date)
while next_date.day > date.day:
date = next_date
next_date = self.env.next_trading_day(date)
n_tdays += 1
for n_days_before, d in enumerate(reversed(self.sept_days)):
for m in self.env.market_minutes_for_day(d):
if should_trigger(m):
self.assertEqual(n_days_before, n)
else:
self.assertNotEqual(n_days_before, n)

self.assertEqual(n_tdays, n)
@parameterized.expand([
('and', operator.and_, lambda t: t._test_composed_and),
('or', operator.or_, lambda t: t._test_composed_or),
('xor', operator.xor, lambda t: t._test_composed_xor),
])
def test_ComposedRule(self, name, composer, tester):
rule1 = Always()
rule2 = Never()

composed = composer(rule1, rule2)
self.assertIsInstance(composed, ComposedRule)
self.assertIs(composed.first, rule1)
self.assertIs(composed.second, rule2)
tester(self)(composed)

def _test_composed_and(self, rule):
self.assertFalse(any(map(rule.should_trigger, self.minutes)))

def _test_composed_or(self, rule):
self.assertTrue(all(map(rule.should_trigger, self.minutes)))

def _test_composed_xor(self, rule):
self.assertTrue(all(map(rule.should_trigger, self.minutes)))


class TestStatefulRules(RuleTestCase):
Expand All @@ -404,7 +425,6 @@ def setUpClass(cls):
super(TestStatefulRules, cls).setUpClass()

cls.class_ = StatefulRule
cls.ignored = {}

@parameterized.expand(param_range(5))
def test_DoNTimes(self, n):
Expand All @@ -414,7 +434,7 @@ def test_DoNTimes(self, n):
for n in range(n):
self.assertTrue(rule.should_trigger(next(min_gen)))

self.assertFalse(any(imap(rule.should_trigger, min_gen)))
self.assertFalse(any(map(rule.should_trigger, min_gen)))

@parameterized.expand(param_range(5))
def test_SkipNTimes(self, n):
Expand All @@ -424,10 +444,15 @@ def test_SkipNTimes(self, n):
for n in range(n):
self.assertFalse(rule.should_trigger(next(min_gen)))

self.assertTrue(any(imap(rule.should_trigger, min_gen)))
self.assertTrue(any(map(rule.should_trigger, min_gen)))

@parameterized.expand(
product(range(5), [('B', 5), ('W', 10), ('M', 50), ('Q', 50)])
)
def test_NTimesPerPeriod(self, n, period_ndays):
period, ndays = period_ndays
self.trading_days = self._get_random_days(ndays)

@parameterized.expand(product(range(5), 'BWMQ'))
def test_NTimesPerPeriod(self, n, period):
rule = NTimesPerPeriod(n=n, freq=period)

minutes = list(self.minutes)
Expand All @@ -445,4 +470,4 @@ def test_NTimesPerPeriod(self, n, period):

def test_RuleFromCallable(self):
rule = RuleFromCallable(lambda dt: True)
self.assertTrue(all(imap(rule.should_trigger, self.minutes)))
self.assertTrue(all(map(rule.should_trigger, self.minutes)))
2 changes: 1 addition & 1 deletion zipline/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class Argument(namedtuple('Argument', ['name', 'default'])):
any_default = AnyDefault()
ignore = Ignore()

def __new__(cls, name, default=ignore):
def __new__(cls, name=ignore, default=ignore):
return super(Argument, cls).__new__(cls, name, default)

def __str__(self):
Expand Down

0 comments on commit 5e9c3d3

Please sign in to comment.