Skip to content

Commit

Permalink
MAINT: Update TestMiscellaneousAPI to use fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
Scott Sanderson committed May 3, 2018
1 parent 3b96042 commit e01ce26
Showing 1 changed file with 80 additions and 97 deletions.
177 changes: 80 additions & 97 deletions tests/test_algorithm.py
Expand Up @@ -220,16 +220,16 @@ def handle_data(self, data):
range(1, len(output) + 1))


class TestMiscellaneousAPI(WithLogger,
WithSimParams,
WithDataPortal,
ZiplineTestCase):
class TestMiscellaneousAPI(zf.WithMakeAlgo, zf.ZiplineTestCase):

START_DATE = pd.Timestamp('2006-01-03', tz='UTC')
END_DATE = pd.Timestamp('2006-01-04', tz='UTC')
SIM_PARAMS_DATA_FREQUENCY = 'minute'
sids = 1, 2

# FIXME: Pass a benchmark source instead of this.
BENCHMARK_SID = None

@classmethod
def make_equity_info(cls):
return pd.concat((
Expand Down Expand Up @@ -297,13 +297,9 @@ def initialize(algo):
def handle_data(algo, data):
set_cancel_policy(cancel_policy.NeverCancel())
"""

algo = TradingAlgorithm(script=code,
sim_params=self.sim_params,
env=self.env)

algo = self.make_algo(script=code)
with self.assertRaises(SetCancelPolicyPostInit):
algo.run(self.data_portal)
algo.run()

def test_cancel_policy_invalid_param(self):
code = """
Expand All @@ -315,20 +311,15 @@ def initialize(algo):
def handle_data(algo, data):
pass
"""
algo = TradingAlgorithm(script=code,
sim_params=self.sim_params,
env=self.env)

algo = self.make_algo(script=code)
with self.assertRaises(UnsupportedCancelPolicy):
algo.run(self.data_portal)
algo.run()

def test_zipline_api_resolves_dynamically(self):
# Make a dummy algo.
algo = TradingAlgorithm(
algo = self.make_algo(
initialize=lambda context: None,
handle_data=lambda context, data: None,
sim_params=self.sim_params,
env=self.env,
)

# Verify that api methods get resolved dynamically by patching them out
Expand All @@ -354,11 +345,10 @@ def handle_data(context, data):
aapl_dt = data.current(sid(1), "last_traded")
assert_equal(aapl_dt, get_datetime())
"""
algo = TradingAlgorithm(script=algo_text,
sim_params=self.sim_params,
env=self.env)
algo.namespace['assert_equal'] = self.assertEqual
algo.run(self.data_portal)
self.run_algorithm(
script=algo_text,
namespace={'assert_equal': self.assertEqual},
)

def test_datetime_bad_params(self):
algo_text = """
Expand All @@ -371,11 +361,9 @@ def initialize(context):
def handle_data(context, data):
get_datetime(timezone)
"""
algo = self.make_algo(script=algo_text)
with self.assertRaises(TypeError):
algo = TradingAlgorithm(script=algo_text,
sim_params=self.sim_params,
env=self.env)
algo.run(self.data_portal)
algo.run()

def test_get_environment(self):
expected_env = {
Expand All @@ -394,11 +382,7 @@ def initialize(algo):
def handle_data(algo, data):
pass

algo = TradingAlgorithm(initialize=initialize,
handle_data=handle_data,
sim_params=self.sim_params,
env=self.env)
algo.run(self.data_portal)
self.run_algorithm(initialize=initialize, handle_data=handle_data)

def test_get_open_orders(self):
def initialize(algo):
Expand Down Expand Up @@ -444,11 +428,7 @@ def handle_data(algo, data):

algo.minute += 1

algo = TradingAlgorithm(initialize=initialize,
handle_data=handle_data,
sim_params=self.sim_params,
env=self.env)
algo.run(self.data_portal)
self.run_algorithm(initialize=initialize, handle_data=handle_data)

def test_schedule_function_custom_cal(self):
# run a simulation on the CME cal, and schedule a function
Expand Down Expand Up @@ -483,14 +463,11 @@ def log_nyse_close(context, data):
context.nyse_closes.append(get_datetime())
"""

algo = TradingAlgorithm(
algo = self.make_algo(
script=algotext,
sim_params=self.sim_params,
env=self.env,
trading_calendar=get_calendar("CME")
trading_calendar=get_calendar("CME"),
)

algo.run(self.data_portal)
algo.run()

nyse = get_calendar("NYSE")

Expand Down Expand Up @@ -520,15 +497,13 @@ def my_func(context, data):
"""
)

algo = TradingAlgorithm(
algo = self.make_algo(
script=erroring_algotext,
sim_params=self.sim_params,
env=self.env,
trading_calendar=get_calendar('CME'),
)

with self.assertRaises(ScheduleFunctionInvalidCalendar):
algo.run(self.data_portal)
algo.run()

def test_schedule_function(self):
us_eastern = pytz.timezone('US/Eastern')
Expand Down Expand Up @@ -564,13 +539,11 @@ def handle_data(algo, data):
algo.days += 1
algo.date = algo.get_datetime().date()

algo = TradingAlgorithm(
algo = self.make_algo(
initialize=initialize,
handle_data=handle_data,
sim_params=self.sim_params,
env=self.env,
)
algo.run(self.data_portal)
algo.run()

self.assertEqual(algo.func_called, algo.days)

Expand Down Expand Up @@ -602,12 +575,10 @@ def f(context, data):
def g(context, data):
function_stack.append(g)

algo = TradingAlgorithm(
algo = self.make_algo(
initialize=initialize,
handle_data=handle_data,
sim_params=self.sim_params,
create_event_context=CallbackManager(pre, post),
env=self.env,
)
algo.run(self.data_portal)

Expand Down Expand Up @@ -639,7 +610,7 @@ def nop(*args, **kwargs):
return None

self.sim_params.data_frequency = mode
algo = TradingAlgorithm(
algo = self.make_algo(
initialize=nop,
handle_data=nop,
sim_params=self.sim_params,
Expand Down Expand Up @@ -679,7 +650,7 @@ def nop(*args, **kwargs):
self.assertIs(composer, ComposedRule.lazy_and)

def test_asset_lookup(self):
algo = TradingAlgorithm(env=self.env)
algo = self.make_algo()

# this date doesn't matter
start_session = pd.Timestamp("2000-01-01", tz="UTC")
Expand Down Expand Up @@ -749,7 +720,7 @@ def test_asset_lookup(self):
def test_future_symbol(self):
""" Tests the future_symbol API function.
"""
algo = TradingAlgorithm(env=self.env)
algo = self.make_algo()
algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

# Check that we get the correct fields for the CLG06 symbol
Expand Down Expand Up @@ -788,55 +759,67 @@ def test_future_symbol(self):
with self.assertRaises(TypeError):
algo.future_symbol({'foo': 'bar'})


class TestSetSymbolLookupDate(zf.WithMakeAlgo, zf.ZiplineTestCase):
# January 2006
# Su Mo Tu We Th Fr Sa
# 1 2 3 4 5 6 7
# 8 9 10 11 12 13 14
# 15 16 17 18 19 20 21
# 22 23 24 25 26 27 28
# 29 30 31
START_DATE = pd.Timestamp('2006-01-03', tz='UTC')
END_DATE = pd.Timestamp('2006-01-06', tz='UTC')
SIM_PARAMS_START_DATE = pd.Timestamp('2006-01-04', tz='UTC')
SIM_PARAMS_DATA_FREQUENCY = 'daily'
DATA_PORTAL_USE_MINUTE_DATA = False
BENCHMARK_SID = 3

@classmethod
def make_equity_info(cls):
dates = pd.date_range(cls.START_DATE, cls.END_DATE)
assert len(dates) == 4, "Expected four dates."

# Two assets with the same ticker, ending on days[1] and days[3], plus
# a benchmark that spans the whole period.
cls.sids = [1, 2, 3]
cls.asset_starts = [dates[0], dates[2]]
cls.asset_ends = [dates[1], dates[3]]
return pd.DataFrame.from_records([
{'symbol': 'DUP',
'start_date': cls.asset_starts[0],
'end_date': cls.asset_ends[0],
'exchange': 'TEST',
'asset_name': 'FIRST'},
{'symbol': 'DUP',
'start_date': cls.asset_starts[1],
'end_date': cls.asset_ends[1],
'exchange': 'TEST',
'asset_name': 'SECOND'},
{'symbol': 'BENCH',
'start_date': cls.START_DATE,
'end_date': cls.END_DATE,
'exchange': 'TEST',
'asset_name': 'BENCHMARK'},
], index=cls.sids)

def test_set_symbol_lookup_date(self):
"""
Test the set_symbol_lookup_date API method.
"""
# Note we start sid enumeration at i+3 so as not to
# collide with sids [1, 2] added in the setUp() method.
dates = pd.date_range('2013-01-01', freq='2D', periods=2, tz='UTC')
# Create two assets with the same symbol but different
# non-overlapping date ranges.
metadata = pd.DataFrame.from_records(
[
{
'sid': i + 3,
'symbol': 'DUP',
'start_date': date.value,
'end_date': (date + timedelta(days=1)).value,
'exchange': 'TEST',
}
for i, date in enumerate(dates)
]
)
with tmp_trading_env(equities=metadata,
load=self.make_load_function()) as env:
algo = TradingAlgorithm(env=env)

# Set the period end to a date after the period end
# dates for our assets.
algo.sim_params = algo.sim_params.create_new(
algo.sim_params.start_session,
pd.Timestamp('2015-01-01', tz='UTC')
)
set_symbol_lookup_date = zipline.api.set_symbol_lookup_date

# With no symbol lookup date set, we will use the period end date
# for the as_of_date, resulting here in the asset with the earlier
# start date being returned.
result = algo.symbol('DUP')
self.assertEqual(result.symbol, 'DUP')
def initialize(context):
set_symbol_lookup_date(self.asset_ends[0])
self.assertEqual(zipline.api.symbol('DUP').sid, self.sids[0])

# By first calling set_symbol_lookup_date, the relevant asset
# should be returned by lookup_symbol
for i, date in enumerate(dates):
algo.set_symbol_lookup_date(date)
result = algo.symbol('DUP')
self.assertEqual(result.symbol, 'DUP')
self.assertEqual(result.sid, i + 3)
set_symbol_lookup_date(self.asset_ends[1])
self.assertEqual(zipline.api.symbol('DUP').sid, self.sids[1])

with self.assertRaises(UnsupportedDatetimeFormat):
algo.set_symbol_lookup_date('foobar')
set_symbol_lookup_date('foobar')

self.run_algorithm(initialize=initialize)

class TestTransformAlgorithm(WithLogger,
WithDataPortal,
Expand Down

0 comments on commit e01ce26

Please sign in to comment.