New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update test_algorithm.py to use WithMakeAlgo. #2171
Changes from 2 commits
372ff1b
22d8bdb
e5228e6
4a4c567
af3c6b6
a52de4c
f745d3d
4a365cd
1ddf6ef
68f0b9b
5e09bbd
524d0ea
acca906
2467856
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,7 +36,6 @@ | |
|
||
import zipline.api | ||
from zipline import run_algorithm | ||
from zipline import TradingAlgorithm | ||
from zipline.api import FixedSlippage | ||
from zipline.assets import Equity, Future, Asset | ||
from zipline.assets.continuous_futures import ContinuousFuture | ||
|
@@ -105,21 +104,13 @@ | |
tmp_dir, | ||
) | ||
from zipline.testing import RecordBatchBlotter | ||
from zipline.testing.fixtures import ( | ||
WithDataPortal, | ||
WithLogger, | ||
WithSimParams, | ||
WithTradingEnvironment, | ||
WithTmpDir, | ||
ZiplineTestCase, | ||
) | ||
import zipline.testing.fixtures as zf | ||
from zipline.test_algorithms import ( | ||
access_account_in_init, | ||
access_portfolio_in_init, | ||
AmbitiousStopLimitAlgorithm, | ||
EmptyPositionsAlgorithm, | ||
InvalidOrderAlgorithm, | ||
RecordAlgorithm, | ||
FutureFlipAlgo, | ||
TestOrderAlgorithm, | ||
TestOrderPercentAlgorithm, | ||
|
@@ -197,12 +188,27 @@ | |
_multiprocess_can_split_ = False | ||
|
||
|
||
class TestRecordAlgorithm(WithSimParams, WithDataPortal, ZiplineTestCase): | ||
ASSET_FINDER_EQUITY_SIDS = 133, | ||
class TestRecord(zf.WithMakeAlgo, zf.ZiplineTestCase): | ||
ASSET_FINDER_EQUITY_SIDS = (133,) | ||
SIM_PARAMS_DATA_FREQUENCY = 'daily' | ||
DATA_PORTAL_USE_MINUTE_DATA = False | ||
|
||
def test_record_incr(self): | ||
algo = RecordAlgorithm(sim_params=self.sim_params, env=self.env) | ||
output = algo.run(self.data_portal) | ||
|
||
def initialize(self): | ||
self.incr = 0 | ||
|
||
def handle_data(self, data): | ||
self.incr += 1 | ||
self.record(incr=self.incr) | ||
name = 'name' | ||
self.record(name, self.incr) | ||
zipline.api.record(name, self.incr, 'name2', 2, name3=self.incr) | ||
|
||
output = self.run_algorithm( | ||
initialize=initialize, | ||
handle_data=handle_data, | ||
) | ||
|
||
np.testing.assert_array_equal(output['incr'].values, | ||
range(1, len(output) + 1)) | ||
|
@@ -214,16 +220,16 @@ def test_record_incr(self): | |
range(1, len(output) + 1)) | ||
|
||
|
||
class TestMiscellaneousAPI(WithLogger, | ||
WithSimParams, | ||
WithDataPortal, | ||
ZiplineTestCase): | ||
class TestMiscellaneousAPI(zf.WithMakeAlgo, zf.ZiplineTestCase): | ||
|
||
START_DATE = pd.Timestamp('2006-01-03', tz='UTC') | ||
END_DATE = pd.Timestamp('2006-01-04', tz='UTC') | ||
SIM_PARAMS_DATA_FREQUENCY = 'minute' | ||
sids = 1, 2 | ||
|
||
# FIXME: Pass a benchmark source instead of this. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The default behavior of Setting BENCHMARK_SID to None causes us to use benchmark returns retrieved from the TradingEnvironment, which uses a static set of data that's checked into git. I'm not sure what we want long-term for benchmarks in testing, but this is what we were already doing in all of these tests, and I didn't want to layer in more changes than were necessary. Once we have all these suites creating algorithms in a uniform way, it will be easier to make broad changes to how we handle benchmarks. |
||
BENCHMARK_SID = None | ||
|
||
@classmethod | ||
def make_equity_info(cls): | ||
return pd.concat(( | ||
|
@@ -291,13 +297,9 @@ def initialize(algo): | |
def handle_data(algo, data): | ||
set_cancel_policy(cancel_policy.NeverCancel()) | ||
""" | ||
|
||
algo = TradingAlgorithm(script=code, | ||
sim_params=self.sim_params, | ||
env=self.env) | ||
|
||
algo = self.make_algo(script=code) | ||
with self.assertRaises(SetCancelPolicyPostInit): | ||
algo.run(self.data_portal) | ||
algo.run() | ||
|
||
def test_cancel_policy_invalid_param(self): | ||
code = """ | ||
|
@@ -309,20 +311,15 @@ def initialize(algo): | |
def handle_data(algo, data): | ||
pass | ||
""" | ||
algo = TradingAlgorithm(script=code, | ||
sim_params=self.sim_params, | ||
env=self.env) | ||
|
||
algo = self.make_algo(script=code) | ||
with self.assertRaises(UnsupportedCancelPolicy): | ||
algo.run(self.data_portal) | ||
algo.run() | ||
|
||
def test_zipline_api_resolves_dynamically(self): | ||
# Make a dummy algo. | ||
algo = TradingAlgorithm( | ||
algo = self.make_algo( | ||
initialize=lambda context: None, | ||
handle_data=lambda context, data: None, | ||
sim_params=self.sim_params, | ||
env=self.env, | ||
) | ||
|
||
# Verify that api methods get resolved dynamically by patching them out | ||
|
@@ -348,11 +345,10 @@ def handle_data(context, data): | |
aapl_dt = data.current(sid(1), "last_traded") | ||
assert_equal(aapl_dt, get_datetime()) | ||
""" | ||
algo = TradingAlgorithm(script=algo_text, | ||
sim_params=self.sim_params, | ||
env=self.env) | ||
algo.namespace['assert_equal'] = self.assertEqual | ||
algo.run(self.data_portal) | ||
self.run_algorithm( | ||
script=algo_text, | ||
namespace={'assert_equal': self.assertEqual}, | ||
) | ||
|
||
def test_datetime_bad_params(self): | ||
algo_text = """ | ||
|
@@ -365,11 +361,9 @@ def initialize(context): | |
def handle_data(context, data): | ||
get_datetime(timezone) | ||
""" | ||
algo = self.make_algo(script=algo_text) | ||
with self.assertRaises(TypeError): | ||
algo = TradingAlgorithm(script=algo_text, | ||
sim_params=self.sim_params, | ||
env=self.env) | ||
algo.run(self.data_portal) | ||
algo.run() | ||
|
||
def test_get_environment(self): | ||
expected_env = { | ||
|
@@ -388,11 +382,7 @@ def initialize(algo): | |
def handle_data(algo, data): | ||
pass | ||
|
||
algo = TradingAlgorithm(initialize=initialize, | ||
handle_data=handle_data, | ||
sim_params=self.sim_params, | ||
env=self.env) | ||
algo.run(self.data_portal) | ||
self.run_algorithm(initialize=initialize, handle_data=handle_data) | ||
|
||
def test_get_open_orders(self): | ||
def initialize(algo): | ||
|
@@ -438,11 +428,7 @@ def handle_data(algo, data): | |
|
||
algo.minute += 1 | ||
|
||
algo = TradingAlgorithm(initialize=initialize, | ||
handle_data=handle_data, | ||
sim_params=self.sim_params, | ||
env=self.env) | ||
algo.run(self.data_portal) | ||
self.run_algorithm(initialize=initialize, handle_data=handle_data) | ||
|
||
def test_schedule_function_custom_cal(self): | ||
# run a simulation on the CME cal, and schedule a function | ||
|
@@ -477,14 +463,11 @@ def log_nyse_close(context, data): | |
context.nyse_closes.append(get_datetime()) | ||
""" | ||
|
||
algo = TradingAlgorithm( | ||
algo = self.make_algo( | ||
script=algotext, | ||
sim_params=self.sim_params, | ||
env=self.env, | ||
trading_calendar=get_calendar("CME") | ||
trading_calendar=get_calendar("CME"), | ||
) | ||
|
||
algo.run(self.data_portal) | ||
algo.run() | ||
|
||
nyse = get_calendar("NYSE") | ||
|
||
|
@@ -514,15 +497,13 @@ def my_func(context, data): | |
""" | ||
) | ||
|
||
algo = TradingAlgorithm( | ||
algo = self.make_algo( | ||
script=erroring_algotext, | ||
sim_params=self.sim_params, | ||
env=self.env, | ||
trading_calendar=get_calendar('CME'), | ||
) | ||
|
||
with self.assertRaises(ScheduleFunctionInvalidCalendar): | ||
algo.run(self.data_portal) | ||
algo.run() | ||
|
||
def test_schedule_function(self): | ||
us_eastern = pytz.timezone('US/Eastern') | ||
|
@@ -558,13 +539,11 @@ def handle_data(algo, data): | |
algo.days += 1 | ||
algo.date = algo.get_datetime().date() | ||
|
||
algo = TradingAlgorithm( | ||
algo = self.make_algo( | ||
initialize=initialize, | ||
handle_data=handle_data, | ||
sim_params=self.sim_params, | ||
env=self.env, | ||
) | ||
algo.run(self.data_portal) | ||
algo.run() | ||
|
||
self.assertEqual(algo.func_called, algo.days) | ||
|
||
|
@@ -596,12 +575,10 @@ def f(context, data): | |
def g(context, data): | ||
function_stack.append(g) | ||
|
||
algo = TradingAlgorithm( | ||
algo = self.make_algo( | ||
initialize=initialize, | ||
handle_data=handle_data, | ||
sim_params=self.sim_params, | ||
create_event_context=CallbackManager(pre, post), | ||
env=self.env, | ||
) | ||
algo.run(self.data_portal) | ||
|
||
|
@@ -633,7 +610,7 @@ def nop(*args, **kwargs): | |
return None | ||
|
||
self.sim_params.data_frequency = mode | ||
algo = TradingAlgorithm( | ||
algo = self.make_algo( | ||
initialize=nop, | ||
handle_data=nop, | ||
sim_params=self.sim_params, | ||
|
@@ -673,7 +650,7 @@ def nop(*args, **kwargs): | |
self.assertIs(composer, ComposedRule.lazy_and) | ||
|
||
def test_asset_lookup(self): | ||
algo = TradingAlgorithm(env=self.env) | ||
algo = self.make_algo() | ||
|
||
# this date doesn't matter | ||
start_session = pd.Timestamp("2000-01-01", tz="UTC") | ||
|
@@ -743,7 +720,7 @@ def test_asset_lookup(self): | |
def test_future_symbol(self): | ||
""" Tests the future_symbol API function. | ||
""" | ||
algo = TradingAlgorithm(env=self.env) | ||
algo = self.make_algo() | ||
algo.datetime = pd.Timestamp('2006-12-01', tz='UTC') | ||
|
||
# Check that we get the correct fields for the CLG06 symbol | ||
|
@@ -782,55 +759,67 @@ def test_future_symbol(self): | |
with self.assertRaises(TypeError): | ||
algo.future_symbol({'foo': 'bar'}) | ||
|
||
|
||
class TestSetSymbolLookupDate(zf.WithMakeAlgo, zf.ZiplineTestCase): | ||
# January 2006 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice choice of a month, with the alignment of the days of the week. Also 👍 on removing the use of |
||
# Su Mo Tu We Th Fr Sa | ||
# 1 2 3 4 5 6 7 | ||
# 8 9 10 11 12 13 14 | ||
# 15 16 17 18 19 20 21 | ||
# 22 23 24 25 26 27 28 | ||
# 29 30 31 | ||
START_DATE = pd.Timestamp('2006-01-03', tz='UTC') | ||
END_DATE = pd.Timestamp('2006-01-06', tz='UTC') | ||
SIM_PARAMS_START_DATE = pd.Timestamp('2006-01-04', tz='UTC') | ||
SIM_PARAMS_DATA_FREQUENCY = 'daily' | ||
DATA_PORTAL_USE_MINUTE_DATA = False | ||
BENCHMARK_SID = 3 | ||
|
||
@classmethod | ||
def make_equity_info(cls): | ||
dates = pd.date_range(cls.START_DATE, cls.END_DATE) | ||
assert len(dates) == 4, "Expected four dates." | ||
|
||
# Two assets with the same ticker, ending on days[1] and days[3], plus | ||
# a benchmark that spans the whole period. | ||
cls.sids = [1, 2, 3] | ||
cls.asset_starts = [dates[0], dates[2]] | ||
cls.asset_ends = [dates[1], dates[3]] | ||
return pd.DataFrame.from_records([ | ||
{'symbol': 'DUP', | ||
'start_date': cls.asset_starts[0], | ||
'end_date': cls.asset_ends[0], | ||
'exchange': 'TEST', | ||
'asset_name': 'FIRST'}, | ||
{'symbol': 'DUP', | ||
'start_date': cls.asset_starts[1], | ||
'end_date': cls.asset_ends[1], | ||
'exchange': 'TEST', | ||
'asset_name': 'SECOND'}, | ||
{'symbol': 'BENCH', | ||
'start_date': cls.START_DATE, | ||
'end_date': cls.END_DATE, | ||
'exchange': 'TEST', | ||
'asset_name': 'BENCHMARK'}, | ||
], index=cls.sids) | ||
|
||
def test_set_symbol_lookup_date(self): | ||
""" | ||
Test the set_symbol_lookup_date API method. | ||
""" | ||
# Note we start sid enumeration at i+3 so as not to | ||
# collide with sids [1, 2] added in the setUp() method. | ||
dates = pd.date_range('2013-01-01', freq='2D', periods=2, tz='UTC') | ||
# Create two assets with the same symbol but different | ||
# non-overlapping date ranges. | ||
metadata = pd.DataFrame.from_records( | ||
[ | ||
{ | ||
'sid': i + 3, | ||
'symbol': 'DUP', | ||
'start_date': date.value, | ||
'end_date': (date + timedelta(days=1)).value, | ||
'exchange': 'TEST', | ||
} | ||
for i, date in enumerate(dates) | ||
] | ||
) | ||
with tmp_trading_env(equities=metadata, | ||
load=self.make_load_function()) as env: | ||
algo = TradingAlgorithm(env=env) | ||
|
||
# Set the period end to a date after the period end | ||
# dates for our assets. | ||
algo.sim_params = algo.sim_params.create_new( | ||
algo.sim_params.start_session, | ||
pd.Timestamp('2015-01-01', tz='UTC') | ||
) | ||
set_symbol_lookup_date = zipline.api.set_symbol_lookup_date | ||
|
||
# With no symbol lookup date set, we will use the period end date | ||
# for the as_of_date, resulting here in the asset with the earlier | ||
# start date being returned. | ||
result = algo.symbol('DUP') | ||
self.assertEqual(result.symbol, 'DUP') | ||
def initialize(context): | ||
set_symbol_lookup_date(self.asset_ends[0]) | ||
self.assertEqual(zipline.api.symbol('DUP').sid, self.sids[0]) | ||
|
||
# By first calling set_symbol_lookup_date, the relevant asset | ||
# should be returned by lookup_symbol | ||
for i, date in enumerate(dates): | ||
algo.set_symbol_lookup_date(date) | ||
result = algo.symbol('DUP') | ||
self.assertEqual(result.symbol, 'DUP') | ||
self.assertEqual(result.sid, i + 3) | ||
set_symbol_lookup_date(self.asset_ends[1]) | ||
self.assertEqual(zipline.api.symbol('DUP').sid, self.sids[1]) | ||
|
||
with self.assertRaises(UnsupportedDatetimeFormat): | ||
algo.set_symbol_lookup_date('foobar') | ||
set_symbol_lookup_date('foobar') | ||
|
||
self.run_algorithm(initialize=initialize) | ||
|
||
class TestTransformAlgorithm(WithLogger, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
WithDataPortal, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A general theme of this PR is that I tried to "inline" algorithms into test sites if they were small and if there were assertions in the test body that depended on the behavior of the algorithm.
For tests that were just "run algorithm X", I didn't see as much value in doing the inlining.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you. I agree with the inlining. While debugging (or just reading tests), this will reduce cognitive load.