Skip to content

Commit

Permalink
DEP: Removes use of 'count'-defined test sources
Browse files Browse the repository at this point in the history
Test sources are now defined by the sim_params period_start and period_end, rather than by the period_start and a defined 'count' of bars. This allows us to consider the sim_params.period_end as the canonical definition of the end of a simulation.
  • Loading branch information
jfkirk committed May 26, 2015
1 parent 862cfbb commit 702b3aa
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 74 deletions.
1 change: 0 additions & 1 deletion tests/test_algorithm.py
Expand Up @@ -135,7 +135,6 @@ def setUp(self):
)
self.source = factory.create_minutely_trade_source(
sids,
trade_count=100,
sim_params=self.sim_params,
concurrent=True,
)
Expand Down
2 changes: 0 additions & 2 deletions tests/test_algorithm_gen.py
Expand Up @@ -134,7 +134,6 @@ def test_generator_dates(self):
algo = TestAlgo(self, sim_params=sim_params)
trade_source = factory.create_daily_trade_source(
[8229],
200,
sim_params
)
algo.set_sources([trade_source])
Expand Down Expand Up @@ -205,7 +204,6 @@ def test_progress(self):
algo = TestAlgo(self, sim_params=sim_params)
trade_source = factory.create_daily_trade_source(
[8229],
3,
sim_params
)
algo.set_sources([trade_source])
Expand Down
1 change: 0 additions & 1 deletion tests/test_finance.py
Expand Up @@ -72,7 +72,6 @@ def test_factory_daily(self):
sim_params = factory.create_simulation_parameters()
trade_source = factory.create_daily_trade_source(
[133],
200,
sim_params
)
prev = None
Expand Down
6 changes: 3 additions & 3 deletions tests/test_transforms.py
Expand Up @@ -41,6 +41,7 @@ def wrapper(context, data):
else:
context.mins_for_days[-1] += 1

hist = context.history(2, '1d', 'close_price')
for n in (1, 2, 3):
if n in data:
if data[n].dt == dt:
Expand All @@ -53,7 +54,7 @@ def wrapper(context, data):
context.price_bars[n].append(np.nan)
context.vol_bars[n].append(0)

context.last_close_prices[n] = context.price_bars[n][-2]
context.last_close_prices[n] = hist[n][0]

if context.warmup < 0:
return f(context, data)
Expand Down Expand Up @@ -101,6 +102,7 @@ def wrapper(self, data_frequency, days=None):
initialize=initialize_with(self, tfm_name, days),
handle_data=handle_data_wrapper(f),
sim_params=sim_params,
identifiers=[1, 2, 3]
)
algo.run(source)

Expand Down Expand Up @@ -131,12 +133,10 @@ def setUpClass(cls):
cls.sim_and_source = {
'minute': (minute_sim_ps, factory.create_minutely_trade_source(
cls.sids,
trade_count=45,
sim_params=minute_sim_ps,
)),
'daily': (daily_sim_ps, factory.create_trade_source(
cls.sids,
trade_count=90,
trade_time_increment=timedelta(days=1),
sim_params=daily_sim_ps,
)),
Expand Down
66 changes: 26 additions & 40 deletions zipline/sources/test_source.py
Expand Up @@ -19,9 +19,9 @@

import pytz

from itertools import cycle
from six.moves import filter, zip
from six.moves import filter
from datetime import datetime, timedelta
import itertools
import numpy as np

from six.moves import range
Expand Down Expand Up @@ -53,9 +53,9 @@ def create_trade(sid, price, amount, datetime, source_id="test_factory"):


@with_environment()
def date_gen(start=datetime(2006, 6, 6, 12, tzinfo=pytz.utc),
def date_gen(start,
end,
delta=timedelta(minutes=1),
count=100,
repeats=None,
env=None):
"""
Expand Down Expand Up @@ -88,7 +88,7 @@ def advance_current(cur):

# yield count trade events, all on trading days, and
# during trading hours.
for i in range(count):
while cur < end:
if repeats:
for j in range(repeats):
yield cur
Expand All @@ -98,22 +98,6 @@ def advance_current(cur):
cur = advance_current(cur)


def mock_prices(count):
"""
Utility to generate a stream of mock prices. By default
cycles through values from 0.0 to 10.0, n times.
"""
return (float(i % 10) + 1.0 for i in range(count))


def mock_volumes(count):
"""
Utility to generate a set of volumes. By default cycles
through values from 100 to 1000, incrementing by 50.
"""
return ((i * 50) % 900 + 100 for i in range(count))


class SpecificEquityTrades(object):
"""
Yields all events in event_list that match the given sid_filter.
Expand All @@ -136,30 +120,30 @@ def __init__(self, *args, **kwargs):
# Default to None for event_list and filter.
self.event_list = kwargs.get('event_list')
self.filter = kwargs.get('filter')

if self.event_list is not None:
# If event_list is provided, extract parameters from there
# This isn't really clean and ultimately I think this
# class should serve a single purpose (either take an
# event_list or autocreate events).
self.count = kwargs.get('count', len(self.event_list))
self.sids = kwargs.get(
'sids',
np.unique([event.sid for event in self.event_list]).tolist())
self.start = kwargs.get('start', self.event_list[0].dt)
self.end = kwargs.get('start', self.event_list[-1].dt)
self.end = kwargs.get('end', self.event_list[-1].dt)
self.delta = kwargs.get(
'delta',
self.event_list[1].dt - self.event_list[0].dt)
self.concurrent = kwargs.get('concurrent', False)

else:
# Unpack config dictionary with default values.
self.count = kwargs.get('count', 500)
self.sids = kwargs.get('sids', [1, 2])
self.start = kwargs.get(
'start',
datetime(2008, 6, 6, 15, tzinfo=pytz.utc))
self.end = kwargs.get(
'end',
datetime(2008, 6, 6, 15, tzinfo=pytz.utc))
self.delta = kwargs.get(
'delta',
timedelta(minutes=1))
Expand Down Expand Up @@ -201,30 +185,32 @@ def create_fresh_generator(self):
if self.concurrent:
# in this context the count is the number of
# trades per sid, not the total.
dates = date_gen(
count=self.count,
date_generator = date_gen(
start=self.start,
end=self.end,
delta=self.delta,
repeats=len(self.sids),
)
else:
dates = date_gen(
count=self.count,
date_generator = date_gen(
start=self.start,
end=self.end,
delta=self.delta
)

prices = mock_prices(self.count)
volumes = mock_volumes(self.count)

sids = cycle(self.sids)

# Combine the iterators into a single iterator of arguments
arg_gen = zip(sids, prices, volumes, dates)

# Convert argument packages into events.
unfiltered = (create_trade(*args, source_id=self.get_hash())
for args in arg_gen)
source_id = self.get_hash()

unfiltered = (
create_trade(
sid=sid,
price=float(i % 10) + 1.0,
amount=(i * 50) % 900 + 100,
datetime=date,
source_id=source_id,
) for (i, date), sid in itertools.product(
enumerate(date_generator), self.sids
)
)

# If we specified a sid filter, filter out elements that don't
# match the filter.
Expand Down
23 changes: 4 additions & 19 deletions zipline/utils/factory.py
Expand Up @@ -224,65 +224,50 @@ def create_returns_from_list(returns, sim_params):
data=returns)


def create_daily_trade_source(sids, trade_count, sim_params,
concurrent=False):
def create_daily_trade_source(sids, sim_params, concurrent=False):
"""
creates trade_count trades for each sid in sids list.
first trade will be on sim_params.period_start, and daily
thereafter for each sid. Thus, two sids should result in two trades per
day.
Important side-effect: sim_params.period_end will be modified
to match the day of the final trade.
"""
return create_trade_source(
sids,
trade_count,
timedelta(days=1),
sim_params,
concurrent=concurrent
)


def create_minutely_trade_source(sids, trade_count, sim_params,
concurrent=False):
def create_minutely_trade_source(sids, sim_params, concurrent=False):
"""
creates trade_count trades for each sid in sids list.
first trade will be on sim_params.period_start, and every minute
thereafter for each sid. Thus, two sids should result in two trades per
minute.
Important side-effect: sim_params.period_end will be modified
to match the day of the final trade.
"""
return create_trade_source(
sids,
trade_count,
timedelta(minutes=1),
sim_params,
concurrent=concurrent
)


def create_trade_source(sids, trade_count,
trade_time_increment, sim_params,
def create_trade_source(sids, trade_time_increment, sim_params,
concurrent=False):

args = tuple()
kwargs = {
'count': trade_count,
'sids': sids,
'start': sim_params.first_open,
'end': sim_params.period_end,
'delta': trade_time_increment,
'filter': sids,
'concurrent': concurrent
}
source = SpecificEquityTrades(*args, **kwargs)

# TODO: do we need to set the trading environment's end to same dt as
# the last trade in the history?
# sim_params.period_end = trade_history[-1].dt

return source


Expand Down
8 changes: 0 additions & 8 deletions zipline/utils/simfactory.py
Expand Up @@ -42,13 +42,6 @@ def create_test_zipline(**config):
else:
order_amount = 100

if 'trade_count' in config:
trade_count = config['trade_count']
else:
# to ensure all orders are filled, we provide one more
# trade than order
trade_count = 101

# -------------------
# Create the Algo
# -------------------
Expand All @@ -72,7 +65,6 @@ def create_test_zipline(**config):
else:
trade_source = factory.create_daily_trade_source(
sid_list,
trade_count,
test_algo.sim_params,
concurrent=concurrent_trades
)
Expand Down

0 comments on commit 702b3aa

Please sign in to comment.