Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update test_algorithm.py to use WithMakeAlgo. #2171

Merged
merged 14 commits into from May 8, 2018
211 changes: 100 additions & 111 deletions tests/test_algorithm.py
Expand Up @@ -36,7 +36,6 @@

import zipline.api
from zipline import run_algorithm
from zipline import TradingAlgorithm
from zipline.api import FixedSlippage
from zipline.assets import Equity, Future, Asset
from zipline.assets.continuous_futures import ContinuousFuture
Expand Down Expand Up @@ -105,21 +104,13 @@
tmp_dir,
)
from zipline.testing import RecordBatchBlotter
from zipline.testing.fixtures import (
WithDataPortal,
WithLogger,
WithSimParams,
WithTradingEnvironment,
WithTmpDir,
ZiplineTestCase,
)
import zipline.testing.fixtures as zf
from zipline.test_algorithms import (
access_account_in_init,
access_portfolio_in_init,
AmbitiousStopLimitAlgorithm,
EmptyPositionsAlgorithm,
InvalidOrderAlgorithm,
RecordAlgorithm,
FutureFlipAlgo,
TestOrderAlgorithm,
TestOrderPercentAlgorithm,
Expand Down Expand Up @@ -197,12 +188,27 @@
_multiprocess_can_split_ = False


class TestRecordAlgorithm(WithSimParams, WithDataPortal, ZiplineTestCase):
ASSET_FINDER_EQUITY_SIDS = 133,
class TestRecord(zf.WithMakeAlgo, zf.ZiplineTestCase):
ASSET_FINDER_EQUITY_SIDS = (133,)
SIM_PARAMS_DATA_FREQUENCY = 'daily'
DATA_PORTAL_USE_MINUTE_DATA = False

def test_record_incr(self):
algo = RecordAlgorithm(sim_params=self.sim_params, env=self.env)
output = algo.run(self.data_portal)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A general theme of this PR is that I tried to "inline" algorithms into test sites if they were small and if there were assertions in the test body that depended on the behavior of the algorithm.

For tests that were just "run algorithm X", I didn't see as much value in doing the inlining.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. I agree with the inlining. While debugging (or just reading tests), this will reduce cognitive load.

def initialize(self):
self.incr = 0

def handle_data(self, data):
self.incr += 1
self.record(incr=self.incr)
name = 'name'
self.record(name, self.incr)
zipline.api.record(name, self.incr, 'name2', 2, name3=self.incr)

output = self.run_algorithm(
initialize=initialize,
handle_data=handle_data,
)

np.testing.assert_array_equal(output['incr'].values,
range(1, len(output) + 1))
Expand All @@ -214,16 +220,16 @@ def test_record_incr(self):
range(1, len(output) + 1))


class TestMiscellaneousAPI(WithLogger,
WithSimParams,
WithDataPortal,
ZiplineTestCase):
class TestMiscellaneousAPI(zf.WithMakeAlgo, zf.ZiplineTestCase):

START_DATE = pd.Timestamp('2006-01-03', tz='UTC')
END_DATE = pd.Timestamp('2006-01-04', tz='UTC')
SIM_PARAMS_DATA_FREQUENCY = 'minute'
sids = 1, 2

# FIXME: Pass a benchmark source instead of this.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default behavior of WithMakeAlgo is to use the first equity sid as a benchmark, but that often doesn't work, because if you run an algorithm from START_DATE to END_DATE, then you don't have enough days of historical data to calculate benchmark returns (in particular, you need the close from the day before START_DATE to compute a daily benchmark return for the first day of the backtest).

Setting BENCHMARK_SID to None causes us to use benchmark returns retrieved from the TradingEnvironment, which uses a static set of data that's checked into git. I'm not sure what we want long-term for benchmarks in testing, but this is what we were already doing in all of these tests, and I didn't want to layer in more changes than were necessary. Once we have all these suites creating algorithms in a uniform way, it will be easier to make broad changes to how we handle benchmarks.

BENCHMARK_SID = None

@classmethod
def make_equity_info(cls):
return pd.concat((
Expand Down Expand Up @@ -291,13 +297,9 @@ def initialize(algo):
def handle_data(algo, data):
set_cancel_policy(cancel_policy.NeverCancel())
"""

algo = TradingAlgorithm(script=code,
sim_params=self.sim_params,
env=self.env)

algo = self.make_algo(script=code)
with self.assertRaises(SetCancelPolicyPostInit):
algo.run(self.data_portal)
algo.run()

def test_cancel_policy_invalid_param(self):
code = """
Expand All @@ -309,20 +311,15 @@ def initialize(algo):
def handle_data(algo, data):
pass
"""
algo = TradingAlgorithm(script=code,
sim_params=self.sim_params,
env=self.env)

algo = self.make_algo(script=code)
with self.assertRaises(UnsupportedCancelPolicy):
algo.run(self.data_portal)
algo.run()

def test_zipline_api_resolves_dynamically(self):
# Make a dummy algo.
algo = TradingAlgorithm(
algo = self.make_algo(
initialize=lambda context: None,
handle_data=lambda context, data: None,
sim_params=self.sim_params,
env=self.env,
)

# Verify that api methods get resolved dynamically by patching them out
Expand All @@ -348,11 +345,10 @@ def handle_data(context, data):
aapl_dt = data.current(sid(1), "last_traded")
assert_equal(aapl_dt, get_datetime())
"""
algo = TradingAlgorithm(script=algo_text,
sim_params=self.sim_params,
env=self.env)
algo.namespace['assert_equal'] = self.assertEqual
algo.run(self.data_portal)
self.run_algorithm(
script=algo_text,
namespace={'assert_equal': self.assertEqual},
)

def test_datetime_bad_params(self):
algo_text = """
Expand All @@ -365,11 +361,9 @@ def initialize(context):
def handle_data(context, data):
get_datetime(timezone)
"""
algo = self.make_algo(script=algo_text)
with self.assertRaises(TypeError):
algo = TradingAlgorithm(script=algo_text,
sim_params=self.sim_params,
env=self.env)
algo.run(self.data_portal)
algo.run()

def test_get_environment(self):
expected_env = {
Expand All @@ -388,11 +382,7 @@ def initialize(algo):
def handle_data(algo, data):
pass

algo = TradingAlgorithm(initialize=initialize,
handle_data=handle_data,
sim_params=self.sim_params,
env=self.env)
algo.run(self.data_portal)
self.run_algorithm(initialize=initialize, handle_data=handle_data)

def test_get_open_orders(self):
def initialize(algo):
Expand Down Expand Up @@ -438,11 +428,7 @@ def handle_data(algo, data):

algo.minute += 1

algo = TradingAlgorithm(initialize=initialize,
handle_data=handle_data,
sim_params=self.sim_params,
env=self.env)
algo.run(self.data_portal)
self.run_algorithm(initialize=initialize, handle_data=handle_data)

def test_schedule_function_custom_cal(self):
# run a simulation on the CME cal, and schedule a function
Expand Down Expand Up @@ -477,14 +463,11 @@ def log_nyse_close(context, data):
context.nyse_closes.append(get_datetime())
"""

algo = TradingAlgorithm(
algo = self.make_algo(
script=algotext,
sim_params=self.sim_params,
env=self.env,
trading_calendar=get_calendar("CME")
trading_calendar=get_calendar("CME"),
)

algo.run(self.data_portal)
algo.run()

nyse = get_calendar("NYSE")

Expand Down Expand Up @@ -514,15 +497,13 @@ def my_func(context, data):
"""
)

algo = TradingAlgorithm(
algo = self.make_algo(
script=erroring_algotext,
sim_params=self.sim_params,
env=self.env,
trading_calendar=get_calendar('CME'),
)

with self.assertRaises(ScheduleFunctionInvalidCalendar):
algo.run(self.data_portal)
algo.run()

def test_schedule_function(self):
us_eastern = pytz.timezone('US/Eastern')
Expand Down Expand Up @@ -558,13 +539,11 @@ def handle_data(algo, data):
algo.days += 1
algo.date = algo.get_datetime().date()

algo = TradingAlgorithm(
algo = self.make_algo(
initialize=initialize,
handle_data=handle_data,
sim_params=self.sim_params,
env=self.env,
)
algo.run(self.data_portal)
algo.run()

self.assertEqual(algo.func_called, algo.days)

Expand Down Expand Up @@ -596,12 +575,10 @@ def f(context, data):
def g(context, data):
function_stack.append(g)

algo = TradingAlgorithm(
algo = self.make_algo(
initialize=initialize,
handle_data=handle_data,
sim_params=self.sim_params,
create_event_context=CallbackManager(pre, post),
env=self.env,
)
algo.run(self.data_portal)

Expand Down Expand Up @@ -633,7 +610,7 @@ def nop(*args, **kwargs):
return None

self.sim_params.data_frequency = mode
algo = TradingAlgorithm(
algo = self.make_algo(
initialize=nop,
handle_data=nop,
sim_params=self.sim_params,
Expand Down Expand Up @@ -673,7 +650,7 @@ def nop(*args, **kwargs):
self.assertIs(composer, ComposedRule.lazy_and)

def test_asset_lookup(self):
algo = TradingAlgorithm(env=self.env)
algo = self.make_algo()

# this date doesn't matter
start_session = pd.Timestamp("2000-01-01", tz="UTC")
Expand Down Expand Up @@ -743,7 +720,7 @@ def test_asset_lookup(self):
def test_future_symbol(self):
""" Tests the future_symbol API function.
"""
algo = TradingAlgorithm(env=self.env)
algo = self.make_algo()
algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

# Check that we get the correct fields for the CLG06 symbol
Expand Down Expand Up @@ -782,55 +759,67 @@ def test_future_symbol(self):
with self.assertRaises(TypeError):
algo.future_symbol({'foo': 'bar'})


class TestSetSymbolLookupDate(zf.WithMakeAlgo, zf.ZiplineTestCase):
# January 2006
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice choice of a month, with the alignment of the days of the week.

Also 👍 on removing the use of timedelta in constructing dates.

# Su Mo Tu We Th Fr Sa
# 1 2 3 4 5 6 7
# 8 9 10 11 12 13 14
# 15 16 17 18 19 20 21
# 22 23 24 25 26 27 28
# 29 30 31
START_DATE = pd.Timestamp('2006-01-03', tz='UTC')
END_DATE = pd.Timestamp('2006-01-06', tz='UTC')
SIM_PARAMS_START_DATE = pd.Timestamp('2006-01-04', tz='UTC')
SIM_PARAMS_DATA_FREQUENCY = 'daily'
DATA_PORTAL_USE_MINUTE_DATA = False
BENCHMARK_SID = 3

@classmethod
def make_equity_info(cls):
dates = pd.date_range(cls.START_DATE, cls.END_DATE)
assert len(dates) == 4, "Expected four dates."

# Two assets with the same ticker, ending on dates[1] and dates[3], plus
# a benchmark that spans the whole period.
cls.sids = [1, 2, 3]
cls.asset_starts = [dates[0], dates[2]]
cls.asset_ends = [dates[1], dates[3]]
return pd.DataFrame.from_records([
{'symbol': 'DUP',
'start_date': cls.asset_starts[0],
'end_date': cls.asset_ends[0],
'exchange': 'TEST',
'asset_name': 'FIRST'},
{'symbol': 'DUP',
'start_date': cls.asset_starts[1],
'end_date': cls.asset_ends[1],
'exchange': 'TEST',
'asset_name': 'SECOND'},
{'symbol': 'BENCH',
'start_date': cls.START_DATE,
'end_date': cls.END_DATE,
'exchange': 'TEST',
'asset_name': 'BENCHMARK'},
], index=cls.sids)

def test_set_symbol_lookup_date(self):
"""
Test the set_symbol_lookup_date API method.
"""
# Note we start sid enumeration at i+3 so as not to
# collide with sids [1, 2] added in the setUp() method.
dates = pd.date_range('2013-01-01', freq='2D', periods=2, tz='UTC')
# Create two assets with the same symbol but different
# non-overlapping date ranges.
metadata = pd.DataFrame.from_records(
[
{
'sid': i + 3,
'symbol': 'DUP',
'start_date': date.value,
'end_date': (date + timedelta(days=1)).value,
'exchange': 'TEST',
}
for i, date in enumerate(dates)
]
)
with tmp_trading_env(equities=metadata,
load=self.make_load_function()) as env:
algo = TradingAlgorithm(env=env)

# Set the period end to a date after the period end
# dates for our assets.
algo.sim_params = algo.sim_params.create_new(
algo.sim_params.start_session,
pd.Timestamp('2015-01-01', tz='UTC')
)
set_symbol_lookup_date = zipline.api.set_symbol_lookup_date

# With no symbol lookup date set, we will use the period end date
# for the as_of_date, resulting here in the asset with the earlier
# start date being returned.
result = algo.symbol('DUP')
self.assertEqual(result.symbol, 'DUP')
def initialize(context):
set_symbol_lookup_date(self.asset_ends[0])
self.assertEqual(zipline.api.symbol('DUP').sid, self.sids[0])

# By first calling set_symbol_lookup_date, the relevant asset
# should be returned by lookup_symbol
for i, date in enumerate(dates):
algo.set_symbol_lookup_date(date)
result = algo.symbol('DUP')
self.assertEqual(result.symbol, 'DUP')
self.assertEqual(result.sid, i + 3)
set_symbol_lookup_date(self.asset_ends[1])
self.assertEqual(zipline.api.symbol('DUP').sid, self.sids[1])

with self.assertRaises(UnsupportedDatetimeFormat):
algo.set_symbol_lookup_date('foobar')
set_symbol_lookup_date('foobar')

self.run_algorithm(initialize=initialize)

class TestTransformAlgorithm(WithLogger,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TestTransformAlgorithm was kind of a mess. Several of the tests in this suite turned out to be complicated no-ops (e.g., the tests for ordering futures ran an algorithm that only ordered equities). Also, none of these tests were about the old transform API, which is what the name of this suite originally referenced. It ultimately seemed easier and clearer to just break this up and re-write the tests worth keeping. Most of the re-written tests now live in test_ordering.py.

WithDataPortal,
Expand Down
12 changes: 0 additions & 12 deletions zipline/test_algorithms.py
Expand Up @@ -241,18 +241,6 @@ def handle_data(self, data):
pass


class RecordAlgorithm(TradingAlgorithm):
def initialize(self):
self.incr = 0

def handle_data(self, data):
self.incr += 1
self.record(incr=self.incr)
name = 'name'
self.record(name, self.incr)
record(name, self.incr, 'name2', 2, name3=self.incr)


class TestOrderAlgorithm(TradingAlgorithm):
def initialize(self):
self.incr = 0
Expand Down