diff --git a/tests/data/test_us_equity_pricing.py b/tests/data/test_us_equity_pricing.py
index e53cd7a9fd..7772fab03d 100644
--- a/tests/data/test_us_equity_pricing.py
+++ b/tests/data/test_us_equity_pricing.py
@@ -74,6 +74,7 @@
     index=arange(1, 7),
     columns=['start_date', 'end_date'],
 ).astype(datetime64)
+EQUITY_INFO['symbol'] = [chr(ord('A') + n) for n in range(len(EQUITY_INFO))]
 
 TEST_QUERY_ASSETS = EQUITY_INFO.index
 
diff --git a/tests/pipeline/test_us_equity_pricing_loader.py b/tests/pipeline/test_us_equity_pricing_loader.py
index 5d8eaafb1b..527bec6116 100644
--- a/tests/pipeline/test_us_equity_pricing_loader.py
+++ b/tests/pipeline/test_us_equity_pricing_loader.py
@@ -91,6 +91,7 @@
     index=arange(1, 7),
     columns=['start_date', 'end_date'],
 ).astype(datetime64)
+EQUITY_INFO['symbol'] = [chr(ord('A') + n) for n in range(len(EQUITY_INFO))]
 
 TEST_QUERY_ASSETS = EQUITY_INFO.index
 
diff --git a/tests/resources/example_data.tar.gz b/tests/resources/example_data.tar.gz
index 3187857baf..23e134952b 100644
Binary files a/tests/resources/example_data.tar.gz and b/tests/resources/example_data.tar.gz differ
diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py
index dc9f890b06..781d60086b 100644
--- a/tests/test_algorithm.py
+++ b/tests/test_algorithm.py
@@ -781,7 +781,10 @@ class TestTransformAlgorithm(WithLogger,
 
     @classmethod
     def make_futures_info(cls):
-        return pd.DataFrame.from_dict({3: {'multiplier': 10}}, 'index')
+        return pd.DataFrame.from_dict(
+            {3: {'multiplier': 10, 'symbol': 'F'}},
+            orient='index',
+        )
 
     @classmethod
     def make_equity_daily_bar_data(cls):
@@ -989,6 +992,7 @@ def test_minute_data(self, algo_class):
             'start_date': start_session,
             'end_date': period_end + timedelta(days=1)
         }] * 2)
+        equities['symbol'] = ['A', 'B']
         with TempDirectory() as tempdir, \
                 tmp_trading_env(equities=equities) as env:
             sim_params = SimulationParameters(
@@ -2849,6 +2853,7 @@ def test_set_max_order_count(self):
         metadata = pd.DataFrame.from_dict(
             {
                 1: {
+                    'symbol': 'SYM',
                     'start_date': start,
                     'end_date': start + timedelta(days=6)
                 },
@@ -2976,6 +2981,7 @@ def handle_data(algo, data):
 
     def test_asset_date_bounds(self):
         metadata = pd.DataFrame([{
+            'symbol': 'SYM',
             'start_date': self.sim_params.start_session,
             'end_date': '2020-01-01',
         }])
@@ -2995,6 +3001,7 @@ def test_asset_date_bounds(self):
         algo.run(data_portal)
 
         metadata = pd.DataFrame([{
+            'symbol': 'SYM',
             'start_date': '1989-01-01',
             'end_date': '1990-01-01',
         }])
@@ -3015,6 +3022,7 @@ def test_asset_date_bounds(self):
         algo.run(data_portal)
 
         metadata = pd.DataFrame([{
+            'symbol': 'SYM',
             'start_date': '2020-01-01',
             'end_date': '2021-01-01',
         }])
diff --git a/tests/test_assets.py b/tests/test_assets.py
index d9340b308a..b81fac5702 100644
--- a/tests/test_assets.py
+++ b/tests/test_assets.py
@@ -18,6 +18,7 @@
 """
 from contextlib import contextmanager
 from datetime import datetime, timedelta
+from functools import partial
 import pickle
 import sys
 from types import GetSetDescriptorType
@@ -25,7 +26,6 @@
 import uuid
 import warnings
 
-from nose.tools import raises
 from nose_parameterized import parameterized
 from numpy import full, int32, int64
 import pandas as pd
@@ -39,7 +39,6 @@
     Future,
     AssetDBWriter,
     AssetFinder,
-    AssetFinderCachedEquities,
 )
 from zipline.assets.synthetic import (
     make_commodity_future_info,
@@ -341,7 +340,6 @@ def test_repr(self):
         self.assertIn("tick_size=0.01", reprd)
         self.assertIn("multiplier=500", reprd)
 
-    @raises(AssertionError)
     def test_reduce(self):
         assert_equal(
             pickle.loads(pickle.dumps(self.future)).to_dict(),
@@ -485,6 +483,97 @@ def test_lookup_symbol_fuzzy(self):
         self.assertEqual(2, finder.lookup_symbol('BRK_A', None, fuzzy=True))
         self.assertEqual(2, finder.lookup_symbol('BRK_A', dt, fuzzy=True))
 
+    def test_lookup_symbol_change_ticker(self):
+        T = partial(pd.Timestamp, tz='utc')
+        metadata = pd.DataFrame.from_records(
+            [
+                # sid 0
+                {
+                    'symbol': 'A',
+                    'start_date': T('2014-01-01'),
+                    'end_date': T('2014-01-05'),
+                },
+                {
+                    'symbol': 'B',
+                    'start_date': T('2014-01-06'),
+                    'end_date': T('2014-01-10'),
+                },
+
+                # sid 1
+                {
+                    'symbol': 'C',
+                    'start_date': T('2014-01-01'),
+                    'end_date': T('2014-01-05'),
+                },
+                {
+                    'symbol': 'A',  # claiming the unused symbol 'A'
+                    'start_date': T('2014-01-06'),
+                    'end_date': T('2014-01-10'),
+                },
+            ],
+            index=[0, 0, 1, 1],
+        )
+        self.write_assets(equities=metadata)
+        finder = self.asset_finder
+
+        # note: these assertions walk forward in time, starting at assertions
+        # about ownership before the start_date and ending with assertions
+        # after the end_date; new assertions should be inserted in the correct
+        # locations
+
+        # no one held 'A' before 01
+        with self.assertRaises(SymbolNotFound):
+            finder.lookup_symbol('A', T('2013-12-31'))
+
+        # no one held 'C' before 01
+        with self.assertRaises(SymbolNotFound):
+            finder.lookup_symbol('C', T('2013-12-31'))
+
+        for asof in pd.date_range('2014-01-01', '2014-01-05', tz='utc'):
+            # from 01 through 05 sid 0 held 'A'
+            assert_equal(
+                finder.lookup_symbol('A', asof),
+                finder.retrieve_asset(0),
+                msg=str(asof),
+            )
+
+            # from 01 through 05 sid 1 held 'C'
+            assert_equal(
+                finder.lookup_symbol('C', asof),
+                finder.retrieve_asset(1),
+                msg=str(asof),
+            )
+
+        # no one held 'B' before 06
+        with self.assertRaises(SymbolNotFound):
+            finder.lookup_symbol('B', T('2014-01-05'))
+
+        # sid 1 stopped holding 'C' after 05; however, no one else has
+        # claimed it yet, so it still maps to sid 1
+        assert_equal(
+            finder.lookup_symbol('C', T('2014-01-07')),
+            finder.retrieve_asset(1),
+        )
+
+        for asof in pd.date_range('2014-01-06', '2014-01-11', tz='utc'):
+            # from 06 through 10 sid 0 held 'B'
+            # we test through the 11th because sid 0 is the last to hold 'B'
+            # so it should ffill
+            assert_equal(
+                finder.lookup_symbol('B', asof),
+                finder.retrieve_asset(0),
+                msg=str(asof),
+            )
+
+            # from 06 through 10 sid 1 held 'A'
+            # we test through the 11th because sid 1 is the last to hold 'A'
+            # so it should ffill
+            assert_equal(
+                finder.lookup_symbol('A', asof),
+                finder.retrieve_asset(1),
+                msg=str(asof),
+            )
+
     def test_lookup_symbol(self):
         # Incrementing by two so that start and end dates for each
@@ -519,27 +608,7 @@ def test_lookup_symbol(self):
             self.assertEqual(result.symbol, 'EXISTING')
             self.assertEqual(result.sid, i)
 
-    def test_lookup_symbol_from_multiple_valid(self):
-        # This test asserts that we resolve conflicts in accordance with the
-        # following rules when we have multiple assets holding the same symbol
-        # at the same time:
-
-        # If multiple SIDs exist for symbol S at time T, return the candidate
-        # SID whose start_date is highest. (200 cases)
-
-        # If multiple SIDs exist for symbol S at time T, the best candidate
-        # SIDs share the highest start_date, return the SID with the highest
-        # end_date. (34 cases)
-
-        # It is the opinion of the author (ssanderson) that we should consider
-        # this malformed input and fail here. But this is the current indended
-        # behavior of the code, and I accidentally broke it while refactoring.
-        # These will serve as regression tests until the time comes that we
-        # decide to enforce this as an error.
-
-        # See https://github.com/quantopian/zipline/issues/837 for more
-        # details.
-
+    def test_fail_to_write_overlapping_data(self):
         df = pd.DataFrame.from_records(
             [
                 {
@@ -568,22 +637,22 @@ def test_lookup_symbol_from_multiple_valid(self):
             ]
         )
 
-        self.write_assets(equities=df)
-
-        def check(expected_sid, date):
-            result = self.asset_finder.lookup_symbol(
-                'MULTIPLE', date,
-            )
-            self.assertEqual(result.symbol, 'MULTIPLE')
-            self.assertEqual(result.sid, expected_sid)
+        with self.assertRaises(ValueError) as e:
+            self.write_assets(equities=df)
 
-        # Sids 1 and 2 are eligible here. We should get asset 2 because it
-        # has the later end_date.
-        check(2, pd.Timestamp('2010-12-31'))
-
-        # Sids 1, 2, and 3 are eligible here. We should get sid 3 because
-        # it has a later start_date
-        check(3, pd.Timestamp('2011-01-01'))
+        self.assertEqual(
+            str(e.exception),
+            "Ambiguous ownership for 1 symbol, multiple assets held the"
+            " following symbols:\n"
+            "MULTIPLE:\n"
+            "  intersections: (('2010-01-01 00:00:00', '2012-01-01 00:00:00'),"
+            " ('2011-01-01 00:00:00', '2012-01-01 00:00:00'))\n"
+            "      start_date   end_date\n"
+            "  sid\n"
+            "  1   2010-01-01 2012-01-01\n"
+            "  2   2010-01-01 2013-01-01\n"
+            "  3   2011-01-01 2012-01-01"
+        )
 
     def test_lookup_generic(self):
         """
@@ -1000,14 +1069,6 @@ def test_error_message_plurality(self,
         )
 
 
-class AssetFinderCachedEquitiesTestCase(AssetFinderTestCase):
-    asset_finder_type = AssetFinderCachedEquities
-
-    def write_assets(self, **kwargs):
-        super(AssetFinderCachedEquitiesTestCase, self).write_assets(**kwargs)
-        self.asset_finder.rehash_equities()
-
-
 class TestFutureChain(WithAssetFinder, ZiplineTestCase):
     @classmethod
     def make_futures_info(cls):
@@ -1259,15 +1320,23 @@ def test_check_version(self):
         version_table = self.metadata.tables['version_info']
 
         # This should not raise an error
-        check_version_info(version_table, ASSET_DB_VERSION)
+        check_version_info(self.engine, version_table, ASSET_DB_VERSION)
 
         # This should fail because the version is too low
         with self.assertRaises(AssetDBVersionError):
-            check_version_info(version_table, ASSET_DB_VERSION - 1)
+            check_version_info(
+                self.engine,
+                version_table,
+                ASSET_DB_VERSION - 1,
+            )
 
         # This should fail because the version is too high
         with self.assertRaises(AssetDBVersionError):
-            check_version_info(version_table, ASSET_DB_VERSION + 1)
+            check_version_info(
+                self.engine,
+                version_table,
+                ASSET_DB_VERSION + 1,
+            )
 
     def test_write_version(self):
         version_table = self.metadata.tables['version_info']
@@ -1279,24 +1348,24 @@ def test_write_version(self):
         # This should fail because the table has no version info and is,
         # therefore, consdered v0
         with self.assertRaises(AssetDBVersionError):
-            check_version_info(version_table, -2)
+            check_version_info(self.engine, version_table, -2)
 
         # This should not raise an error because the version has been written
-        write_version_info(version_table, -2)
-        check_version_info(version_table, -2)
+        write_version_info(self.engine, version_table, -2)
+        check_version_info(self.engine, version_table, -2)
 
         # Assert that the version is in the table and correct
         self.assertEqual(sa.select((version_table.c.version,)).scalar(), -2)
 
         # Assert that trying to overwrite the version fails
         with self.assertRaises(sa.exc.IntegrityError):
-            write_version_info(version_table, -3)
+            write_version_info(self.engine, version_table, -3)
 
     def test_finder_checks_version(self):
         version_table = self.metadata.tables['version_info']
         version_table.delete().execute()
-        write_version_info(version_table, -2)
-        check_version_info(version_table, -2)
+        write_version_info(self.engine, version_table, -2)
+        check_version_info(self.engine, version_table, -2)
 
         # Assert that trying to build a finder with a bad db raises an error
         with self.assertRaises(AssetDBVersionError):
@@ -1304,8 +1373,8 @@ def test_finder_checks_version(self):
 
         # Change the version number of the db to the correct version
         version_table.delete().execute()
-        write_version_info(version_table, ASSET_DB_VERSION)
-        check_version_info(version_table, ASSET_DB_VERSION)
+        write_version_info(self.engine, version_table, ASSET_DB_VERSION)
+        check_version_info(self.engine, version_table, ASSET_DB_VERSION)
 
         # Now that the versions match, this Finder should succeed
         AssetFinder(engine=self.engine)
@@ -1318,7 +1387,7 @@ def test_downgrade(self):
         downgrade(self.engine, 3)
         metadata = sa.MetaData(conn)
         metadata.reflect(bind=self.engine)
-        check_version_info(metadata.tables['version_info'], 3)
+        check_version_info(conn, metadata.tables['version_info'], 3)
         self.assertFalse('exchange_full' in metadata.tables)
 
         # now go all the way to v0
@@ -1328,7 +1397,7 @@ def test_downgrade(self):
         metadata = sa.MetaData(conn)
         metadata.reflect(bind=self.engine)
         version_table = metadata.tables['version_info']
-        check_version_info(version_table, 0)
+        check_version_info(self.engine, version_table, 0)
 
         # Check some of the v1-to-v0 downgrades
         self.assertTrue('futures_contracts' in metadata.tables)
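The ownership rule that test_lookup_symbol_change_ticker exercises can be restated outside of zipline: a sid owns a symbol from its mapping's start date until the next owner's start date, and the last owner forward-fills indefinitely. A minimal standalone sketch of that rule (plain Python, not zipline API; the data mirrors the sids and dates in the test above):

    import pandas as pd

    # symbol -> list of (start_date, sid) ownership periods, sorted by start
    ownership = {
        'A': [(pd.Timestamp('2014-01-01'), 0), (pd.Timestamp('2014-01-06'), 1)],
        'B': [(pd.Timestamp('2014-01-06'), 0)],
        'C': [(pd.Timestamp('2014-01-01'), 1)],
    }

    def lookup(symbol, as_of):
        periods = [p for p in ownership.get(symbol, []) if p[0] <= as_of]
        if not periods:
            raise KeyError('no owner for %r on %s' % (symbol, as_of))
        # the latest start that is <= as_of wins; the last owner ffills
        return max(periods)[1]

    assert lookup('A', pd.Timestamp('2014-01-02')) == 0
    assert lookup('A', pd.Timestamp('2014-01-07')) == 1  # sid 1 claimed 'A'
    assert lookup('B', pd.Timestamp('2014-01-11')) == 0  # ffill past end_date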
""" for bundle in sorted(bundles_module.bundles.keys()): + if bundle.startswith('.'): + # hide the test data + continue try: ingestions = sorted( (str(bundles_module.from_bundle_ingest_dirname(ing)) diff --git a/zipline/assets/__init__.py b/zipline/assets/__init__.py index cdba9766a4..be455da552 100644 --- a/zipline/assets/__init__.py +++ b/zipline/assets/__init__.py @@ -23,7 +23,6 @@ from .assets import ( AssetFinder, AssetConvertible, - AssetFinderCachedEquities ) from .asset_db_schema import ASSET_DB_VERSION from .asset_writer import AssetDBWriter @@ -35,7 +34,6 @@ 'Equity', 'Future', 'AssetFinder', - 'AssetFinderCachedEquities', 'AssetConvertible', 'make_asset_array', 'CACHE_FILE_TEMPLATE' diff --git a/zipline/assets/_assets.pyx b/zipline/assets/_assets.pyx index 4ac39e72f1..fb636c1d44 100644 --- a/zipline/assets/_assets.pyx +++ b/zipline/assets/_assets.pyx @@ -58,26 +58,35 @@ cdef class Asset: cdef readonly object exchange - def __cinit__(self, - int sid, # sid is required - object symbol="", - object asset_name="", - object start_date=None, - object end_date=None, - object first_traded=None, - object auto_close_date=None, - object exchange="", - *args, - **kwargs): - - self.sid = sid - self.sid_hash = hash(sid) - self.symbol = symbol - self.asset_name = asset_name - self.exchange = exchange - self.start_date = start_date - self.end_date = end_date - self.first_traded = first_traded + _kwargnames = frozenset({ + 'sid', + 'symbol', + 'asset_name', + 'start_date', + 'end_date', + 'first_traded', + 'auto_close_date', + 'exchange', + }) + + def __init__(self, + int sid, # sid is required + object symbol="", + object asset_name="", + object start_date=None, + object end_date=None, + object first_traded=None, + object auto_close_date=None, + object exchange=""): + + self.sid = sid + self.sid_hash = hash(sid) + self.symbol = symbol + self.asset_name = asset_name + self.exchange = exchange + self.start_date = start_date + self.end_date = end_date + self.first_traded = first_traded self.auto_close_date = auto_close_date def __int__(self): @@ -127,9 +136,9 @@ cdef class Asset: def __str__(self): if self.symbol: - return 'Asset(%d [%s])' % (self.sid, self.symbol) + return '%s(%d [%s])' % (type(self).__name__, self.sid, self.symbol) else: - return 'Asset(%d)' % self.sid + return '%s(%d)' % (type(self).__name__, self.sid) def __repr__(self): attrs = ('symbol', 'asset_name', 'exchange', @@ -213,12 +222,6 @@ cdef class Asset: cdef class Equity(Asset): - def __str__(self): - if self.symbol: - return 'Equity(%d [%s])' % (self.sid, self.symbol) - else: - return 'Equity(%d)' % self.sid - def __repr__(self): attrs = ('symbol', 'asset_name', 'exchange', 'start_date', 'end_date', 'first_traded', 'auto_close_date') @@ -270,26 +273,52 @@ cdef class Future(Asset): cdef readonly object tick_size cdef readonly float multiplier - def __cinit__(self, - int sid, # sid is required - object symbol="", - object root_symbol="", - object asset_name="", - object start_date=None, - object end_date=None, - object notice_date=None, - object expiration_date=None, - object auto_close_date=None, - object first_traded=None, - object exchange="", - object tick_size="", - float multiplier=1): - - self.root_symbol = root_symbol - self.notice_date = notice_date + _kwargnames = frozenset({ + 'sid', + 'symbol', + 'root_symbol', + 'asset_name', + 'start_date', + 'end_date', + 'notice_date', + 'expiration_date', + 'auto_close_date', + 'first_traded', + 'exchange', + 'tick_size', + 'multiplier', + }) + + def __init__(self, + int sid, # 
sid is required + object symbol="", + object root_symbol="", + object asset_name="", + object start_date=None, + object end_date=None, + object notice_date=None, + object expiration_date=None, + object auto_close_date=None, + object first_traded=None, + object exchange="", + object tick_size="", + float multiplier=1.0): + + super().__init__( + sid, + symbol=symbol, + asset_name=asset_name, + start_date=start_date, + end_date=end_date, + first_traded=first_traded, + auto_close_date=auto_close_date, + exchange=exchange, + ) + self.root_symbol = root_symbol + self.notice_date = notice_date self.expiration_date = expiration_date - self.tick_size = tick_size - self.multiplier = multiplier + self.tick_size = tick_size + self.multiplier = multiplier if auto_close_date is None: if notice_date is None: @@ -299,12 +328,6 @@ cdef class Future(Asset): else: self.auto_close_date = min(notice_date, expiration_date) - def __str__(self): - if self.symbol: - return 'Future(%d [%s])' % (self.sid, self.symbol) - else: - return 'Future(%d)' % self.sid - def __repr__(self): attrs = ('symbol', 'root_symbol', 'asset_name', 'exchange', 'start_date', 'end_date', 'first_traded', 'notice_date', diff --git a/zipline/assets/asset_db_migrations.py b/zipline/assets/asset_db_migrations.py index 22d2180df6..2c56badd24 100644 --- a/zipline/assets/asset_db_migrations.py +++ b/zipline/assets/asset_db_migrations.py @@ -50,7 +50,7 @@ def downgrade(engine, desired_version): # Execute the downgrades in order for downgrade_key in downgrade_keys: - _downgrade_methods[downgrade_key](op, version_info_table) + _downgrade_methods[downgrade_key](op, engine, version_info_table) # Re-enable foreign keys _pragma_foreign_keys(conn, True) @@ -96,10 +96,10 @@ def _(f): @do(op.setitem(_downgrade_methods, destination)) @wraps(f) - def wrapper(op, version_info_table): + def wrapper(op, engine, version_info_table): version_info_table.delete().execute() # clear the version f(op) - write_version_info(version_info_table, destination) + write_version_info(engine, version_info_table, destination) return wrapper return _ @@ -227,3 +227,74 @@ def _downgrade_v4(op): op.create_index('ix_equities_company_symbol', table_name='equities', columns=['company_symbol']) + + +@downgrades(5) +def _downgrade_v5(op): + op.create_table( + '_new_equities', + sa.Column( + 'sid', + sa.Integer, + unique=True, + nullable=False, + primary_key=True, + ), + sa.Column('symbol', sa.Text), + sa.Column('company_symbol', sa.Text), + sa.Column('share_class_symbol', sa.Text), + sa.Column('fuzzy_symbol', sa.Text), + sa.Column('asset_name', sa.Text), + sa.Column('start_date', sa.Integer, default=0, nullable=False), + sa.Column('end_date', sa.Integer, nullable=False), + sa.Column('first_traded', sa.Integer), + sa.Column('auto_close_date', sa.Integer), + sa.Column('exchange', sa.Text), + sa.Column('exchange_full', sa.Text) + ) + + op.execute( + """ + insert into _new_equities + select + equities.sid as sid, + sym.symbol as symbol, + sym.company_symbol as company_symbol, + sym.share_class_symbol as share_class_symbol, + sym.company_symbol || sym.share_class_symbol as fuzzy_symbol, + equities.asset_name as asset_name, + equities.start_date as start_date, + equities.end_date as end_date, + equities.first_traded as first_traded, + equities.auto_close_date as auto_close_date, + equities.exchange as exchange, + equities.exchange_full as exchange_full + from + equities + inner join + (select + * + from + equity_symbol_mappings + group by + equity_symbol_mappings.sid + order by + 
+                 equity_symbol_mappings.end_date desc) sym
+        on
+            equities.sid == sym.sid
+        """,
+    )
+    op.drop_table('equity_symbol_mappings')
+    op.drop_table('equities')
+    op.rename_table('_new_equities', 'equities')
+    # we need to make sure the indices have the proper names after the rename
+    op.create_index(
+        'ix_equities_company_symbol',
+        'equities',
+        ['company_symbol'],
+    )
+    op.create_index(
+        'ix_equities_fuzzy_symbol',
+        'equities',
+        ['fuzzy_symbol'],
+    )
diff --git a/zipline/assets/asset_db_schema.py b/zipline/assets/asset_db_schema.py
index 6499c45f4b..ea54479707 100644
--- a/zipline/assets/asset_db_schema.py
+++ b/zipline/assets/asset_db_schema.py
@@ -6,167 +6,178 @@
 # assets database
 # NOTE: When upgrading this remember to add a downgrade in:
 # .asset_db_migrations
-ASSET_DB_VERSION = 4
-
-
-def generate_asset_db_metadata(bind=None):
-    # NOTE: When modifying this schema, update the ASSET_DB_VERSION value
-    metadata = sa.MetaData(bind=bind)
-    _version_table_schema(metadata)
-    _equities_table_schema(metadata)
-    _futures_exchanges_schema(metadata)
-    _futures_root_symbols_schema(metadata)
-    _futures_contracts_schema(metadata)
-    _asset_router_schema(metadata)
-    return metadata
-
+ASSET_DB_VERSION = 5
 
 # A frozenset of the names of all tables in the assets db
 # NOTE: When modifying this schema, update the ASSET_DB_VERSION value
 asset_db_table_names = frozenset({
     'asset_router',
     'equities',
+    'equity_symbol_mappings',
     'futures_contracts',
     'futures_exchanges',
     'futures_root_symbols',
     'version_info',
 })
 
+metadata = sa.MetaData()
 
-def _equities_table_schema(metadata):
-    # NOTE: When modifying this schema, update the ASSET_DB_VERSION value
-    return sa.Table(
-        'equities',
-        metadata,
-        sa.Column(
-            'sid',
-            sa.Integer,
-            unique=True,
-            nullable=False,
-            primary_key=True,
-        ),
-        sa.Column('symbol', sa.Text),
-        sa.Column('company_symbol', sa.Text, index=True),
-        sa.Column('share_class_symbol', sa.Text),
-        sa.Column('fuzzy_symbol', sa.Text, index=True),
-        sa.Column('asset_name', sa.Text),
-        sa.Column('start_date', sa.Integer, default=0, nullable=False),
-        sa.Column('end_date', sa.Integer, nullable=False),
-        sa.Column('first_traded', sa.Integer),
-        sa.Column('auto_close_date', sa.Integer),
-        sa.Column('exchange', sa.Text),
-        sa.Column('exchange_full', sa.Text)
-    )
-
-
-def _futures_exchanges_schema(metadata):
-    # NOTE: When modifying this schema, update the ASSET_DB_VERSION value
-    return sa.Table(
-        'futures_exchanges',
-        metadata,
-        sa.Column(
-            'exchange',
-            sa.Text,
-            unique=True,
-            nullable=False,
-            primary_key=True,
-        ),
-        sa.Column('timezone', sa.Text),
-    )
-
-
-def _futures_root_symbols_schema(metadata):
-    # NOTE: When modifying this schema, update the ASSET_DB_VERSION value
-    return sa.Table(
-        'futures_root_symbols',
-        metadata,
-        sa.Column(
-            'root_symbol',
-            sa.Text,
-            unique=True,
-            nullable=False,
-            primary_key=True,
-        ),
-        sa.Column('root_symbol_id', sa.Integer),
-        sa.Column('sector', sa.Text),
-        sa.Column('description', sa.Text),
-        sa.Column(
-            'exchange',
-            sa.Text,
-            sa.ForeignKey('futures_exchanges.exchange'),
-        ),
-    )
-
-
-def _futures_contracts_schema(metadata):
-    # NOTE: When modifying this schema, update the ASSET_DB_VERSION value
-    return sa.Table(
-        'futures_contracts',
-        metadata,
-        sa.Column(
-            'sid',
-            sa.Integer,
-            unique=True,
-            nullable=False,
-            primary_key=True,
-        ),
-        sa.Column('symbol', sa.Text, unique=True, index=True),
-        sa.Column(
-            'root_symbol',
-            sa.Text,
-            sa.ForeignKey('futures_root_symbols.root_symbol'),
-            index=True
-        ),
-        sa.Column('asset_name', sa.Text),
-        sa.Column('start_date', sa.Integer, default=0, nullable=False),
-        sa.Column('end_date', sa.Integer, nullable=False),
-        sa.Column('first_traded', sa.Integer),
-        sa.Column(
-            'exchange',
-            sa.Text,
-            sa.ForeignKey('futures_exchanges.exchange'),
-        ),
-        sa.Column('notice_date', sa.Integer, nullable=False),
-        sa.Column('expiration_date', sa.Integer, nullable=False),
-        sa.Column('auto_close_date', sa.Integer, nullable=False),
-        sa.Column('multiplier', sa.Float),
-        sa.Column('tick_size', sa.Float),
-    )
-
-
-def _asset_router_schema(metadata):
-    # NOTE: When modifying this schema, update the ASSET_DB_VERSION value
-    return sa.Table(
-        'asset_router',
-        metadata,
-        sa.Column(
-            'sid',
-            sa.Integer,
-            unique=True,
-            nullable=False,
-            primary_key=True),
-        sa.Column('asset_type', sa.Text),
-    )
-
-
-def _version_table_schema(metadata):
-    # NOTE: When modifying this schema, update the ASSET_DB_VERSION value
-    return sa.Table(
-        'version_info',
-        metadata,
-        sa.Column(
-            'id',
-            sa.Integer,
-            unique=True,
-            nullable=False,
-            primary_key=True,
-        ),
-        sa.Column(
-            'version',
-            sa.Integer,
-            unique=True,
-            nullable=False,
-        ),
-        # This constraint ensures a single entry in this table
-        sa.CheckConstraint('id <= 1'),
-    )
+equities = sa.Table(
+    'equities',
+    metadata,
+    sa.Column(
+        'sid',
+        sa.Integer,
+        unique=True,
+        nullable=False,
+        primary_key=True,
+    ),
+    sa.Column('asset_name', sa.Text),
+    sa.Column('start_date', sa.Integer, default=0, nullable=False),
+    sa.Column('end_date', sa.Integer, nullable=False),
+    sa.Column('first_traded', sa.Integer),
+    sa.Column('auto_close_date', sa.Integer),
+    sa.Column('exchange', sa.Text),
+    sa.Column('exchange_full', sa.Text)
+)
+
+equity_symbol_mappings = sa.Table(
+    'equity_symbol_mappings',
+    metadata,
+    sa.Column(
+        'id',
+        sa.Integer,
+        unique=True,
+        nullable=False,
+        primary_key=True,
+    ),
+    sa.Column(
+        'sid',
+        sa.Integer,
+        sa.ForeignKey(equities.c.sid),
+        nullable=False,
+        index=True,
+    ),
+    sa.Column(
+        'symbol',
+        sa.Text,
+        nullable=False,
+    ),
+    sa.Column(
+        'company_symbol',
+        sa.Text,
+        index=True,
+    ),
+    sa.Column(
+        'share_class_symbol',
+        sa.Text,
+    ),
+    sa.Column(
+        'start_date',
+        sa.Integer,
+        nullable=False,
+    ),
+    sa.Column(
+        'end_date',
+        sa.Integer,
+        nullable=False,
+    ),
+)
+
+futures_exchanges = sa.Table(
+    'futures_exchanges',
+    metadata,
+    sa.Column(
+        'exchange',
+        sa.Text,
+        unique=True,
+        nullable=False,
+        primary_key=True,
+    ),
+    sa.Column('timezone', sa.Text),
+)
+
+futures_root_symbols = sa.Table(
+    'futures_root_symbols',
+    metadata,
+    sa.Column(
+        'root_symbol',
+        sa.Text,
+        unique=True,
+        nullable=False,
+        primary_key=True,
+    ),
+    sa.Column('root_symbol_id', sa.Integer),
+    sa.Column('sector', sa.Text),
+    sa.Column('description', sa.Text),
+    sa.Column(
+        'exchange',
+        sa.Text,
+        sa.ForeignKey('futures_exchanges.exchange'),
+    ),
+)
+
+futures_contracts = sa.Table(
+    'futures_contracts',
+    metadata,
+    sa.Column(
+        'sid',
+        sa.Integer,
+        unique=True,
+        nullable=False,
+        primary_key=True,
+    ),
+    sa.Column('symbol', sa.Text, unique=True, index=True),
+    sa.Column(
+        'root_symbol',
+        sa.Text,
+        sa.ForeignKey('futures_root_symbols.root_symbol'),
+        index=True
+    ),
+    sa.Column('asset_name', sa.Text),
+    sa.Column('start_date', sa.Integer, default=0, nullable=False),
+    sa.Column('end_date', sa.Integer, nullable=False),
+    sa.Column('first_traded', sa.Integer),
+    sa.Column(
+        'exchange',
+        sa.Text,
+        sa.ForeignKey('futures_exchanges.exchange'),
+    ),
+    sa.Column('notice_date', sa.Integer, nullable=False),
+    sa.Column('expiration_date', sa.Integer, nullable=False),
+    sa.Column('auto_close_date', sa.Integer, nullable=False),
+    sa.Column('multiplier', sa.Float),
+    sa.Column('tick_size', sa.Float),
+)
+
+asset_router = sa.Table(
+    'asset_router',
+    metadata,
+    sa.Column(
+        'sid',
+        sa.Integer,
+        unique=True,
+        nullable=False,
+        primary_key=True),
+    sa.Column('asset_type', sa.Text),
+)
+
+version_info = sa.Table(
+    'version_info',
+    metadata,
+    sa.Column(
+        'id',
+        sa.Integer,
+        unique=True,
+        nullable=False,
+        primary_key=True,
+    ),
+    sa.Column(
+        'version',
+        sa.Integer,
+        unique=True,
+        nullable=False,
+    ),
+    # This constraint ensures a single entry in this table
+    sa.CheckConstraint('id <= 1'),
+)
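The key structural change in v5 is visible above: equities no longer carries symbol columns, and the new equity_symbol_mappings table holds one dated row per (sid, symbol) claim. A pared-down sketch of the layout (column set reduced for brevity; not the full schema):

    import sqlalchemy as sa

    md = sa.MetaData()
    equities = sa.Table(
        'equities', md,
        sa.Column('sid', sa.Integer, primary_key=True),
        sa.Column('asset_name', sa.Text),
    )
    equity_symbol_mappings = sa.Table(
        'equity_symbol_mappings', md,
        sa.Column('id', sa.Integer, primary_key=True),
        sa.Column('sid', sa.Integer, sa.ForeignKey(equities.c.sid)),
        sa.Column('symbol', sa.Text, nullable=False),
        sa.Column('start_date', sa.Integer, nullable=False),
        sa.Column('end_date', sa.Integer, nullable=False),
    )

    engine = sa.create_engine('sqlite://')
    md.create_all(engine)
    with engine.begin() as conn:
        conn.execute(equities.insert().values(sid=0, asset_name='Example'))
        # one equity holding two tickers over disjoint (nanosecond) ranges
        conn.execute(equity_symbol_mappings.insert(), [
            {'sid': 0, 'symbol': 'A', 'start_date': 0, 'end_date': 5},
            {'sid': 0, 'symbol': 'B', 'start_date': 6, 'end_date': 10},
        ])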
diff --git a/zipline/assets/asset_writer.py b/zipline/assets/asset_writer.py
index b125916923..0fa6dc981c 100644
--- a/zipline/assets/asset_writer.py
+++ b/zipline/assets/asset_writer.py
@@ -23,16 +23,40 @@
 
 from zipline.errors import AssetDBVersionError
 from zipline.assets.asset_db_schema import (
-    generate_asset_db_metadata,
-    asset_db_table_names,
     ASSET_DB_VERSION,
+    asset_db_table_names,
+    asset_router,
+    equities as equities_table,
+    equity_symbol_mappings,
+    futures_contracts as futures_contracts_table,
+    futures_exchanges,
+    futures_root_symbols,
+    metadata,
+    version_info,
 )
+from zipline.utils.range import from_tuple, intersecting_ranges
+
 
 # Define a namedtuple for use with the load_data and _load_data methods
-AssetData = namedtuple('AssetData', 'equities futures exchanges root_symbols')
+AssetData = namedtuple(
+    'AssetData', (
+        'equities',
+        'equities_mappings',
+        'futures',
+        'exchanges',
+        'root_symbols',
+    ),
+)
 
 SQLITE_MAX_VARIABLE_NUMBER = 999
 
+symbol_columns = frozenset({
+    'symbol',
+    'company_symbol',
+    'share_class_symbol',
+})
+mapping_columns = symbol_columns | {'start_date', 'end_date'}
+
+
 # Default values for the equities DataFrame
 _equities_defaults = {
     'symbol': None,
@@ -77,7 +101,7 @@
 }
 
 # Fuzzy symbol delimiters that may break up a company symbol and share class
-_delimited_symbol_delimiter_regex = r'[./\-_]'
+_delimited_symbol_delimiters_regex = re.compile(r'[./\-_]')
 _delimited_symbol_default_triggers = frozenset({np.nan, None, ''})
 
 
@@ -94,16 +118,22 @@ def split_delimited_symbol(symbol):
 
     Returns
     -------
-    ( str, str , str )
-        A tuple of ( company_symbol, share_class_symbol, fuzzy_symbol)
+    company_symbol : str
+        The company part of the symbol.
+    share_class_symbol : str
+        The share class part of a symbol.
     """
     # return blank strings for any bad fuzzy symbols, like NaN or None
     if symbol in _delimited_symbol_default_triggers:
-        return ('', '', '')
+        return '', ''
 
-    split_list = re.split(pattern=_delimited_symbol_delimiter_regex,
-                          string=symbol,
-                          maxsplit=1)
+    symbol = symbol.upper()
+
+    split_list = re.split(
+        pattern=_delimited_symbol_delimiters_regex,
+        string=symbol,
+        maxsplit=1,
+    )
 
     # Break the list up in to its two components, the company symbol and the
     # share class symbol
@@ -113,12 +143,7 @@ def split_delimited_symbol(symbol):
     else:
         share_class_symbol = ''
 
-    # Strip all fuzzy characters from the symbol to get the fuzzy symbol
-    fuzzy_symbol = re.sub(pattern=_delimited_symbol_delimiter_regex,
-                          repl='',
-                          string=symbol)
-
-    return (company_symbol, share_class_symbol, fuzzy_symbol)
+    return company_symbol, share_class_symbol
 
 
 def _generate_output_dataframe(data_subset, defaults):
@@ -154,19 +179,83 @@ def _generate_output_dataframe(data_subset, defaults):
 
     # Get those columns which we need but
     # for which no data has been supplied.
-    need = desired_cols - cols
-
-    # Combine the users supplied data with our required columns.
-    output = pd.concat(
-        (data_subset, pd.DataFrame(
-            {k: defaults[k] for k in need},
-            data_subset.index,
-        )),
-        axis=1,
-        copy=False
+    for col in desired_cols - cols:
+        # write the default value for any missing columns
+        data_subset[col] = defaults[col]
+
+    return data_subset
+
+
+def _check_asset_group(group):
+    row = group.iloc[0]
+    row.start_date = group.start_date.min()
+    row.end_date = group.end_date.max()
+    row.drop(list(symbol_columns), inplace=True)
+    return row
+
+
+def _format_range(r):
+    return (
+        str(pd.Timestamp(r.start, unit='ns')),
+        str(pd.Timestamp(r.stop, unit='ns')),
     )
-    return output
+
+
+def _split_symbol_mappings(df):
+    """Split out the symbol: sid mappings from the raw data.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        The dataframe with multiple rows for each symbol: sid pair.
+
+    Returns
+    -------
+    asset_info : pd.DataFrame
+        The asset info with one row per asset.
+    symbol_mappings : pd.DataFrame
+        The dataframe of just symbol: sid mappings. The index will be
+        the sid, then there will be three columns: symbol, start_date, and
+        end_date.
+    """
+    mappings = df[list(mapping_columns)]
+    ambiguous = {}
+    for symbol in mappings.symbol.unique():
+        persymbol = mappings[mappings.symbol == symbol]
+        intersections = list(intersecting_ranges(map(
+            from_tuple,
+            zip(persymbol.start_date, persymbol.end_date),
+        )))
+        if intersections:
+            ambiguous[symbol] = (
+                intersections,
+                persymbol[['start_date', 'end_date']].astype('datetime64[ns]'),
+            )
+
+    if ambiguous:
+        raise ValueError(
+            'Ambiguous ownership for %d symbol%s, multiple assets held the'
+            ' following symbols:\n%s' % (
+                len(ambiguous),
+                '' if len(ambiguous) == 1 else 's',
+                '\n'.join(
+                    '%s:\n  intersections: %s\n  %s' % (
+                        symbol,
+                        tuple(map(_format_range, intersections)),
+                        # indent the dataframe string
+                        '\n  '.join(str(df).splitlines()),
+                    )
+                    for symbol, (intersections, df) in sorted(
+                        ambiguous.items(),
+                        key=first,
+                    )
+                ),
+            )
+        )
+    return (
+        df.groupby(level=0).apply(_check_asset_group),
+        df[list(mapping_columns)],
+    )
 
 
 def _dt_to_epoch_ns(dt_series):
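_split_symbol_mappings rejects input where two rows claim the same symbol over overlapping date ranges. A rough standalone equivalent of that check, standing in for intersecting_ranges from zipline.utils.range (treating ranges as half-open (start, stop) tuples):

    def intersections(ranges):
        # after sorting, adjacent ranges overlap iff the next start
        # precedes the prior stop
        ranges = sorted(ranges)
        return [
            (a, b)
            for a, b in zip(ranges, ranges[1:])
            if b[0] < a[1]
        ]

    # sids 1, 2, and 3 all claim the symbol over 2010-2012: ambiguous
    assert intersections([(2010, 2012), (2010, 2013), (2011, 2012)])
    # disjoint claims are accepted
    assert not intersections([(2010, 2012), (2012, 2014)])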
@@ -190,12 +279,14 @@
     return index.view(np.int64)
 
 
-def check_version_info(version_table, expected_version):
+def check_version_info(conn, version_table, expected_version):
     """
     Checks for a version value in the version table.
 
     Parameters
     ----------
+    conn : sa.Connection
+        The connection to use to perform the check.
     version_table : sa.Table
         The version table of the asset database
     expected_version : int
@@ -208,7 +299,9 @@
     """
 
     # Read the version out of the table
-    version_from_table = sa.select((version_table.c.version,)).scalar()
+    version_from_table = conn.execute(
+        sa.select((version_table.c.version,)),
+    ).scalar()
 
     # A db without a version is considered v0
     if version_from_table is None:
@@ -220,19 +313,21 @@
                                 expected_version=expected_version)
 
 
-def write_version_info(version_table, version_value):
+def write_version_info(conn, version_table, version_value):
     """
     Inserts the version value in to the version table.
 
     Parameters
     ----------
+    conn : sa.Connection
+        The connection to use to execute the insert.
     version_table : sa.Table
         The version table of the asset database
     version_value : int
         The version to write in to the database
 
     """
-    sa.insert(version_table, values={'version': version_value}).execute()
+    conn.execute(sa.insert(version_table, values={'version': version_value}))
 
 
 class _empty(object):
@@ -269,9 +364,6 @@ def write(self,
 
               symbol : str
                   The ticker symbol for this equity.
-              fuzzy_symbol : str, optional
-                  The fuzzy symbol for this equity. This is the symbol
-                  without any delimiting characters like '.' or '_'.
               asset_name : str
                   The full name for this asset.
               start_date : datetime
@@ -351,7 +443,7 @@ def write(self,
         """
         with self.engine.begin() as txn:
             # Create SQL tables if they do not exist.
-            metadata = self.init_db(txn)
+            self.init_db(txn)
 
             # Get the data to add to SQL.
             data = self._load_data(
@@ -362,51 +454,74 @@ def write(self,
             )
             # Write the data to SQL.
             self._write_df_to_table(
-                metadata.tables['futures_exchanges'],
+                futures_exchanges,
                 data.exchanges,
                 txn,
                 chunk_size,
             )
             self._write_df_to_table(
-                metadata.tables['futures_root_symbols'],
+                futures_root_symbols,
                 data.root_symbols,
                 txn,
                 chunk_size,
             )
-            asset_router = metadata.tables['asset_router']
             self._write_assets(
-                asset_router,
-                metadata.tables['futures_contracts'],
                 'future',
                 data.futures,
                 txn,
                 chunk_size,
             )
             self._write_assets(
-                asset_router,
-                metadata.tables['equities'],
                 'equity',
                 data.equities,
                 txn,
                 chunk_size,
+                mapping_data=data.equities_mappings,
             )
 
-    def _write_df_to_table(self, tbl, df, txn, chunk_size):
+    def _write_df_to_table(self, tbl, df, txn, chunk_size, idx_label=None):
         df.to_sql(
             tbl.name,
             txn.connection,
-            index_label=first(tbl.primary_key.columns).name,
+            index_label=(
+                idx_label
+                if idx_label is not None else
+                first(tbl.primary_key.columns).name
+            ),
             if_exists='append',
             chunksize=chunk_size,
         )
 
     def _write_assets(self,
-                      asset_router,
-                      tbl,
                       asset_type,
                       assets,
                       txn,
-                      chunk_size):
+                      chunk_size,
+                      mapping_data=None):
+        if asset_type == 'future':
+            tbl = futures_contracts_table
+            if mapping_data is not None:
+                raise TypeError('no mapping data expected for futures')
+
+        elif asset_type == 'equity':
+            tbl = equities_table
+            if mapping_data is None:
+                raise TypeError('mapping data required for equities')
+            # write the symbol mapping data.
+            self._write_df_to_table(
+                equity_symbol_mappings,
+                mapping_data,
+                txn,
+                chunk_size,
+                idx_label='sid',
+            )
+
+        else:
+            raise ValueError(
+                "asset_type must be in {'future', 'equity'}, got: %s" %
+                asset_type,
+            )
+
         self._write_df_to_table(tbl, assets, txn, chunk_size)
 
         pd.DataFrame({
@@ -459,17 +574,14 @@ def init_db(self, txn=None):
                 txn = stack.enter_context(self.engine.begin())
 
             tables_already_exist = self._all_tables_present(txn)
-            metadata = generate_asset_db_metadata(bind=txn)
 
             # Create the SQL tables if they do not already exist.
-            metadata.create_all(checkfirst=True)
+            metadata.create_all(txn, checkfirst=True)
 
-            version_info = metadata.tables['version_info']
             if tables_already_exist:
-                check_version_info(version_info, ASSET_DB_VERSION)
+                check_version_info(txn, version_info, ASSET_DB_VERSION)
             else:
-                write_version_info(version_info, ASSET_DB_VERSION)
-            return metadata
+                write_version_info(txn, version_info, ASSET_DB_VERSION)
 
     def _normalize_equities(self, equities):
         # HACK: If 'company_name' is provided, map it to asset_name
@@ -490,16 +602,13 @@ def _normalize_equities(self, equities):
         tuple_series = equities_output['symbol'].apply(split_delimited_symbol)
         split_symbols = pd.DataFrame(
             tuple_series.tolist(),
-            columns=['company_symbol', 'share_class_symbol', 'fuzzy_symbol'],
+            columns=['company_symbol', 'share_class_symbol'],
             index=tuple_series.index
         )
-        equities_output = equities_output.join(split_symbols)
+        equities_output = pd.concat((equities_output, split_symbols), axis=1)
 
         # Upper-case all symbol data
-        for col in ('symbol',
-                    'company_symbol',
-                    'share_class_symbol',
-                    'fuzzy_symbol'):
+        for col in symbol_columns:
             equities_output[col] = equities_output[col].str.upper()
 
         # Convert date columns to UNIX Epoch integers (nanoseconds)
@@ -509,7 +618,7 @@ def _normalize_equities(self, equities):
                     'auto_close_date'):
             equities_output[col] = _dt_to_epoch_ns(equities_output[col])
 
-        return equities_output
+        return _split_symbol_mappings(equities_output)
 
     def _normalize_futures(self, futures):
         futures_output = _generate_output_dataframe(
@@ -544,7 +653,7 @@ def _load_data(self, equities, futures, exchanges, root_symbols):
             if id_col in df.columns:
                 df.set_index(id_col, inplace=True)
 
-        equities_output = self._normalize_equities(equities)
+        equities_output, equities_mappings = self._normalize_equities(equities)
         futures_output = self._normalize_futures(futures)
 
         exchanges_output = _generate_output_dataframe(
@@ -559,6 +668,7 @@ def _load_data(self, equities, futures, exchanges, root_symbols):
 
         return AssetData(
             equities=equities_output,
+            equities_mappings=equities_mappings,
             futures=futures_output,
             exchanges=exchanges_output,
             root_symbols=root_symbols_output,
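The reworked split_delimited_symbol above now returns only the two stored pieces; the fuzzy symbol is derived by concatenation wherever it is needed. Its behavior, restated as a small self-contained demo (same delimiter pattern as _delimited_symbol_delimiters_regex):

    import re

    delimiters = re.compile(r'[./\-_]')

    def split_symbol(symbol):
        # upper-case first, then split on the first delimiter only
        parts = delimiters.split(symbol.upper(), maxsplit=1)
        return parts[0], parts[1] if len(parts) > 1 else ''

    assert split_symbol('BRK.A') == ('BRK', 'A')
    assert split_symbol('brk_a') == ('BRK', 'A')   # case-folded first
    assert split_symbol('AAPL') == ('AAPL', '')
    # the fuzzy form is just company + share class with delimiters dropped
    assert ''.join(split_symbol('BRK.A')) == 'BRKA'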
+ """ + # clear the lazyval caches, the next access will requery + try: + del type(self).symbol_ownership_map[self] + except KeyError: + pass + try: + del type(self).fuzzy_symbol_ownership_map[self] + except KeyError: + pass + + @lazyval + def symbol_ownership_map(self): + rows = sa.select(self.equity_symbol_mappings.c).execute().fetchall() + + mappings = {} + for row in rows: + mappings.setdefault( + (row.company_symbol, row.share_class_symbol), + [], + ).append( + SymbolOwnership( + pd.Timestamp(row.start_date, unit='ns', tz='utc'), + pd.Timestamp(row.end_date, unit='ns', tz='utc'), + row.sid, + row.symbol, + ), + ) + + return valmap( + lambda v: tuple( + SymbolOwnership( + a.start, + b.start, + a.sid, + a.symbol, + ) for a, b in sliding_window( + 2, + concatv( + sorted(v), + # concat with a fake ownership object to make the last + # end date be max timestamp + [SymbolOwnership( + pd.Timestamp.max.tz_localize('utc'), + None, + None, + None, + )], + ), + ) + ), + mappings, + factory=lambda: mappings, + ) + + @lazyval + def fuzzy_symbol_ownership_map(self): + fuzzy_mappings = {} + for (cs, scs), owners in iteritems(self.symbol_ownership_map): + fuzzy_owners = fuzzy_mappings.setdefault( + cs + scs, + [], + ) + fuzzy_owners.extend(owners) + fuzzy_owners.sort() + return fuzzy_mappings def lookup_asset_types(self, sids): """ @@ -326,6 +430,50 @@ def _select_assets_by_sid(asset_tbl, sids): def _select_asset_by_symbol(asset_tbl, symbol): return sa.select([asset_tbl]).where(asset_tbl.c.symbol == symbol) + def _lookup_most_recent_symbols(self, sids): + symbol_cols = self.equity_symbol_mappings.c + + symbols = { + row.sid: {c: row[c] for c in symbol_columns} + for row in self.engine.execute( + sa.select( + (symbol_cols.sid,) + + tuple(map(op.getitem(symbol_cols), symbol_columns)), + ).where( + symbol_cols.sid.in_(map(int, sids)), + ).order_by( + symbol_cols.end_date.desc(), + ).group_by( + symbol_cols.sid, + ) + ).fetchall() + } + + if len(symbols) != len(sids): + raise EquitiesNotFound( + sids=set(sids) - set(symbols), + plural=True, + ) + return symbols + + def _retrieve_asset_dicts(self, sids, asset_tbl, querying_equities): + if not sids: + return + + if querying_equities: + def mkdict(row, + symbols=self._lookup_most_recent_symbols(sids)): + return merge(row, symbols[row['sid']]) + else: + mkdict = dict + + for assets in group_into_chunks(sids): + # Load misses from the db. + query = self._select_assets_by_sid(asset_tbl, assets) + + for row in query.execute().fetchall(): + yield _convert_asset_timestamp_fields(mkdict(row)) + def _retrieve_assets(self, sids, asset_tbl, asset_type): """ Internal function for loading assets from a table. @@ -354,14 +502,18 @@ def _retrieve_assets(self, sids, asset_tbl, asset_type): cache = self._asset_cache hits = {} - for assets in group_into_chunks(sids): - # Load misses from the db. 
@@ -326,6 +430,50 @@ def _select_assets_by_sid(asset_tbl, sids):
     def _select_asset_by_symbol(asset_tbl, symbol):
         return sa.select([asset_tbl]).where(asset_tbl.c.symbol == symbol)
 
+    def _lookup_most_recent_symbols(self, sids):
+        symbol_cols = self.equity_symbol_mappings.c
+
+        symbols = {
+            row.sid: {c: row[c] for c in symbol_columns}
+            for row in self.engine.execute(
+                sa.select(
+                    (symbol_cols.sid,) +
+                    tuple(map(op.getitem(symbol_cols), symbol_columns)),
+                ).where(
+                    symbol_cols.sid.in_(map(int, sids)),
+                ).order_by(
+                    symbol_cols.end_date.desc(),
+                ).group_by(
+                    symbol_cols.sid,
+                )
+            ).fetchall()
+        }
+
+        if len(symbols) != len(sids):
+            raise EquitiesNotFound(
+                sids=set(sids) - set(symbols),
+                plural=True,
+            )
+        return symbols
+
+    def _retrieve_asset_dicts(self, sids, asset_tbl, querying_equities):
+        if not sids:
+            return
+
+        if querying_equities:
+            def mkdict(row,
+                       symbols=self._lookup_most_recent_symbols(sids)):
+                return merge(row, symbols[row['sid']])
+        else:
+            mkdict = dict
+
+        for assets in group_into_chunks(sids):
+            # Load misses from the db.
+            query = self._select_assets_by_sid(asset_tbl, assets)
+
+            for row in query.execute().fetchall():
+                yield _convert_asset_timestamp_fields(mkdict(row))
+
     def _retrieve_assets(self, sids, asset_tbl, asset_type):
         """
         Internal function for loading assets from a table.
@@ -354,14 +502,18 @@ def _retrieve_assets(self, sids, asset_tbl, asset_type):
         cache = self._asset_cache
         hits = {}
 
-        for assets in group_into_chunks(sids):
-            # Load misses from the db.
-            query = self._select_assets_by_sid(asset_tbl, assets)
+        querying_equities = issubclass(asset_type, Equity)
+        filter_kwargs = (
+            _filter_equity_kwargs
+            if querying_equities else
+            _filter_future_kwargs
+        )
 
-            for row in imap(dict, query.execute().fetchall()):
-                asset = asset_type(**_convert_asset_timestamp_fields(row))
-                sid = asset.sid
-                hits[sid] = cache[sid] = asset
+        rows = self._retrieve_asset_dicts(sids, asset_tbl, querying_equities)
+        for row in rows:
+            sid = row['sid']
+            asset = asset_type(**filter_kwargs(row))
+            hits[sid] = cache[sid] = asset
 
         # If we get here, it means something in our code thought that a
         # particular sid was an equity/future and called this function with a
@@ -369,166 +521,152 @@ def _retrieve_assets(self, sids, asset_tbl, asset_type):
         # an error in our code, not a user-input error.
         misses = tuple(set(sids) - viewkeys(hits))
         if misses:
-            if asset_type == Equity:
+            if querying_equities:
                 raise EquitiesNotFound(sids=misses)
             else:
                 raise FutureContractsNotFound(sids=misses)
         return hits
 
-    def _get_fuzzy_candidates(self, fuzzy_symbol):
-        candidates = sa.select(
-            (self.equities.c.sid,)
-        ).where(self.equities.c.fuzzy_symbol == fuzzy_symbol).order_by(
-            self.equities.c.start_date.desc(),
-            self.equities.c.end_date.desc()
-        ).execute().fetchall()
-        return candidates
-
-    def _get_fuzzy_candidates_in_range(self, fuzzy_symbol, ad_value):
-        candidates = sa.select(
-            (self.equities.c.sid,)
-        ).where(
-            sa.and_(
-                self.equities.c.fuzzy_symbol == fuzzy_symbol,
-                self.equities.c.start_date <= ad_value,
-                self.equities.c.end_date >= ad_value
-            )
-        ).order_by(
-            self.equities.c.start_date.desc(),
-            self.equities.c.end_date.desc(),
-        ).execute().fetchall()
-        return candidates
-
-    def _get_split_candidates_in_range(self,
-                                       company_symbol,
-                                       share_class_symbol,
-                                       ad_value):
-        candidates = sa.select(
-            (self.equities.c.sid,)
-        ).where(
-            sa.and_(
-                self.equities.c.company_symbol == company_symbol,
-                self.equities.c.share_class_symbol == share_class_symbol,
-                self.equities.c.start_date <= ad_value,
-                self.equities.c.end_date >= ad_value
-            )
-        ).order_by(
-            self.equities.c.start_date.desc(),
-            self.equities.c.end_date.desc(),
-        ).execute().fetchall()
-        return candidates
-
-    def _get_split_candidates(self, company_symbol, share_class_symbol):
-        candidates = sa.select(
-            (self.equities.c.sid,)
-        ).where(
-            sa.and_(
-                self.equities.c.company_symbol == company_symbol,
-                self.equities.c.share_class_symbol == share_class_symbol
-            )
-        ).order_by(
-            self.equities.c.start_date.desc(),
-            self.equities.c.end_date.desc(),
-        ).execute().fetchall()
-        return candidates
-
-    def _resolve_no_matching_candidates(self,
-                                        company_symbol,
-                                        share_class_symbol,
-                                        ad_value):
-        candidates = sa.select((self.equities.c.sid,)).where(
-            sa.and_(
-                self.equities.c.company_symbol == company_symbol,
-                self.equities.c.share_class_symbol ==
+    def _lookup_symbol_strict(self, symbol, as_of_date):
+        # split the symbol into the components, if there are no
+        # company/share class parts then share_class_symbol will be empty
+        company_symbol, share_class_symbol = split_delimited_symbol(symbol)
+        try:
+            owners = self.symbol_ownership_map[
+                company_symbol,
                 share_class_symbol,
-                self.equities.c.start_date <= ad_value),
-        ).order_by(
-            self.equities.c.end_date.desc(),
-        ).execute().fetchall()
-        return candidates
+            ]
+            assert owners, 'empty owners list for %r' % symbol
+        except KeyError:
+            # no equity has ever held this symbol
+            raise SymbolNotFound(symbol=symbol)
+
+        if not as_of_date:
+            if len(owners) > 1:
+                # more than one equity has held this ticker, this is
+                # ambiguous without the date
+                raise MultipleSymbolsFound(
+                    symbol=symbol,
+                    options=set(map(
+                        compose(self.retrieve_asset, attrgetter('sid')),
+                        owners,
+                    )),
+                )
 
-    def _get_best_candidate(self, candidates):
-        return self._retrieve_equity(candidates[0]['sid'])
+            # exactly one equity has ever held this symbol, we may resolve
+            # without the date
+            return self.retrieve_asset(owners[0].sid)
 
-    def _get_equities_from_candidates(self, candidates):
-        sids = map(itemgetter('sid'), candidates)
-        results = self.retrieve_equities(sids)
-        return [results[sid] for sid in sids]
+        for start, end, sid, _ in owners:
+            if start <= as_of_date < end:
+                # find the equity that owned it on the given asof date
+                return self.retrieve_asset(sid)
 
-    def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
-        """
-        Return matching Equity of name symbol in database.
+        # no equity held the ticker on the given asof date
+        raise SymbolNotFound(symbol=symbol)
 
-        If multiple Equities are found and as_of_date is not set,
-        raises MultipleSymbolsFound.
+    def _lookup_symbol_fuzzy(self, symbol, as_of_date):
+        symbol = symbol.upper()
+        company_symbol, share_class_symbol = split_delimited_symbol(symbol)
+        try:
+            owners = self.fuzzy_symbol_ownership_map[
+                company_symbol + share_class_symbol
+            ]
+            assert owners, 'empty owners list for %r' % symbol
+        except KeyError:
+            # no equity has ever held a symbol matching the fuzzy symbol
+            raise SymbolNotFound(symbol=symbol)
 
-        If no Equity was active at as_of_date raises SymbolNotFound.
-        """
-        company_symbol, share_class_symbol, fuzzy_symbol = \
-            split_delimited_symbol(symbol)
-        if as_of_date:
-            # Format inputs
-            as_of_date = pd.Timestamp(as_of_date).normalize()
-            ad_value = as_of_date.value
-
-            if fuzzy:
-                # Search for a single exact match on the fuzzy column
-                candidates = self._get_fuzzy_candidates_in_range(fuzzy_symbol,
-                                                                 ad_value)
-
-                # If exactly one SID exists for fuzzy_symbol, return that sid
-                if len(candidates) == 1:
-                    return self._get_best_candidate(candidates)
-
-            # Search for exact matches of the split-up company_symbol and
-            # share_class_symbol
-            candidates = self._get_split_candidates_in_range(
-                company_symbol,
-                share_class_symbol,
-                ad_value
+        if not as_of_date:
+            if len(owners) == 1:
+                # only one valid match
+                return self.retrieve_asset(owners[0].sid)
+
+            options = []
+            for _, _, sid, sym in owners:
+                if sym == symbol:
+                    # there are multiple options, look for exact matches
+                    options.append(self.retrieve_asset(sid))
+
+            if len(options) == 1:
+                # there was only one exact match
+                return options[0]
+
+            # there is more than one exact match for this fuzzy symbol
+            raise MultipleSymbolsFound(
+                symbol=symbol,
+                options=set(options),
             )
 
-            # If exactly one SID exists for symbol, return that symbol
-            # If multiple SIDs exist for symbol, return latest start_date with
-            # end_date as a tie-breaker
-            if candidates:
-                return self._get_best_candidate(candidates)
-
-            # If no SID exists for symbol, return SID with the
-            # highest-but-not-over end_date
-            elif not candidates:
-                candidates = self._resolve_no_matching_candidates(
-                    company_symbol,
-                    share_class_symbol,
-                    ad_value
-                )
-                if candidates:
-                    return self._get_best_candidate(candidates)
+        options = []
+        for start, end, sid, sym in owners:
+            if start <= as_of_date < end:
+                # see which fuzzy symbols were owned on the asof date.
+                options.append((sid, sym))
+
+        if not options:
+            # no equity owned the fuzzy symbol on the date requested
+            raise SymbolNotFound(symbol=symbol)
+
+        if len(options) == 1:
+            # there was only one owner, return it
+            return self.retrieve_asset(options[0][0])
+
+        for sid, sym in options:
+            if sym == symbol:
+                # look for an exact match on the asof date
+                return self.retrieve_asset(sid)
+
+        # multiple equities held tickers matching the fuzzy ticker but
+        # there are no exact matches
+        raise MultipleSymbolsFound(
+            symbol=symbol,
+            options=set(map(
+                compose(self.retrieve_asset, itemgetter(0)),
+                options,
+            )),
+        )
 
-            raise SymbolNotFound(symbol=symbol)
+    def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
+        """Lookup an equity by symbol.
 
-        else:
-            # If this is a fuzzy look-up, check if there is exactly one match
-            # for the fuzzy symbol
-            if fuzzy:
-                candidates = self._get_fuzzy_candidates(fuzzy_symbol)
-                if len(candidates) == 1:
-                    return self._get_best_candidate(candidates)
-
-            candidates = self._get_split_candidates(company_symbol,
-                                                    share_class_symbol)
-            if len(candidates) == 1:
-                return self._get_best_candidate(candidates)
-            elif not candidates:
-                raise SymbolNotFound(symbol=symbol)
-            else:
-                raise MultipleSymbolsFound(
-                    symbol=symbol,
-                    options=self._get_equities_from_candidates(candidates)
-                )
+        Parameters
+        ----------
+        symbol : str
+            The ticker symbol to resolve.
+        as_of_date : datetime or None
+            Look up the last owner of this symbol as of this datetime.
+            If ``as_of_date`` is None, then this can only resolve the equity
+            if exactly one equity has ever owned the ticker.
+        fuzzy : bool, optional
+            Should fuzzy symbol matching be used? Fuzzy symbol matching
+            attempts to resolve differences in representations for
+            shareclasses. For example, some people may represent the ``A``
+            shareclass of ``BRK`` as ``BRK.A``, where others could write
+            ``BRK_A``.
+
+        Returns
+        -------
+        equity : Equity
+            The equity that held ``symbol`` on the given ``as_of_date``, or
+            the only equity to hold ``symbol`` if ``as_of_date`` is None.
+
+        Raises
+        ------
+        SymbolNotFound
+            Raised when no equity has ever held the given symbol.
+        MultipleSymbolsFound
+            Raised when no ``as_of_date`` is given and more than one equity
+            has held ``symbol``. This is also raised when ``fuzzy=True`` and
+            there are multiple candidates for the given ``symbol`` on the
+            ``as_of_date``.
+        """
+        if fuzzy:
+            return self._lookup_symbol_fuzzy(symbol, as_of_date)
+        return self._lookup_symbol_strict(symbol, as_of_date)
 
     def lookup_future_symbol(self, symbol):
-        """ Return the Future object for a given symbol.
+        """Lookup a future contract by symbol.
 
         Parameters
         ----------
@@ -537,8 +675,8 @@ def lookup_future_symbol(self, symbol):
 
         Returns
         -------
-        Future
-            A Future object.
+        future : Future
+            The future contract referenced by ``symbol``.
 
         Raises
         ------
@@ -946,100 +1084,6 @@ class NotAssetConvertible(ValueError):
     pass
 
 
-class AssetFinderCachedEquities(AssetFinder):
-    """
-    An extension to AssetFinder that preloads all equities from equities table
-    into memory and does lookups from there.
-
-    To have any changes in the underlying assets db reflected by this asset
-    finder one must manually call the ``rehash_equities`` method.
-    """
-
-    def __init__(self, engine):
-        super(AssetFinderCachedEquities, self).__init__(engine)
-        self._fuzzy_symbol_cache = {}
-        self._company_share_class_cache = {}
-
-        self.rehash_equities()
-
-    def rehash_equities(self):
-        """Reload the underlying assets db into the in memory cache.
- """ - for equity in sa.select(self.equities.c).execute().fetchall(): - company_symbol = equity['company_symbol'] - share_class_symbol = equity['share_class_symbol'] - fuzzy_symbol = equity['fuzzy_symbol'] - asset = self._convert_row_to_equity(equity) - self._company_share_class_cache.setdefault( - (company_symbol, share_class_symbol), - [] - ).append(asset) - self._fuzzy_symbol_cache.setdefault( - fuzzy_symbol, - [], - ).append(asset) - - def _convert_row_to_equity(self, row): - """ - Converts a SQLAlchemy equity row to an Equity object. - """ - return Equity(**_convert_asset_timestamp_fields(dict(row))) - - def _get_fuzzy_candidates(self, fuzzy_symbol): - return self._fuzzy_symbol_cache.get(fuzzy_symbol, ()) - - def _get_fuzzy_candidates_in_range(self, fuzzy_symbol, ad_value): - return only_active_assets( - ad_value, - self._get_fuzzy_candidates(fuzzy_symbol), - ) - - def _get_split_candidates(self, company_symbol, share_class_symbol): - return self._company_share_class_cache.get( - (company_symbol, share_class_symbol), - (), - ) - - def _get_split_candidates_in_range(self, - company_symbol, - share_class_symbol, - ad_value): - return sorted( - only_active_assets( - ad_value, - self._get_split_candidates(company_symbol, share_class_symbol), - ), - key=lambda x: (x.start_date, x.end_date), - reverse=True, - ) - - def _resolve_no_matching_candidates(self, - company_symbol, - share_class_symbol, - ad_value): - equities = self._get_split_candidates( - company_symbol, - share_class_symbol - ) - partial_candidates = [] - for equity in equities: - if equity.start_date.value <= ad_value: - partial_candidates.append(equity) - if partial_candidates: - partial_candidates = sorted( - partial_candidates, - key=lambda x: x.end_date, - reverse=True - ) - return partial_candidates - - def _get_best_candidate(self, candidates): - return candidates[0] - - def _get_equities_from_candidates(self, candidates): - return candidates - - def was_active(reference_date_value, asset): """ Whether or not `asset` was active at the time corresponding to diff --git a/zipline/data/bundles/yahoo.py b/zipline/data/bundles/yahoo.py index 8a50025366..db9ce29840 100644 --- a/zipline/data/bundles/yahoo.py +++ b/zipline/data/bundles/yahoo.py @@ -6,6 +6,7 @@ import requests from zipline.utils.cli import maybe_show_progress +from .core import register def _cachpath(symbol, type_): @@ -169,3 +170,24 @@ def _pricing_iter(): adjustment_writer.write(splits=splits, dividends=dividends) return ingest + + +# bundle used when creating test data +register( + '.test', + yahoo_equities( + ( + 'AMD', + 'CERN', + 'COST', + 'DELL', + 'GPS', + 'INTC', + 'MMM', + 'AAPL', + 'MSFT', + ), + pd.Timestamp('2004-01-02', tz='utc'), + pd.Timestamp('2015-01-01', tz='utc'), + ), +) diff --git a/zipline/errors.py b/zipline/errors.py index 3f2afa2095..84db348f42 100644 --- a/zipline/errors.py +++ b/zipline/errors.py @@ -281,7 +281,7 @@ class MultipleSymbolsFound(ZiplineError): as_of_date' argument to to specify when the date symbol-lookup should be valid. 
diff --git a/zipline/errors.py b/zipline/errors.py
index 3f2afa2095..84db348f42 100644
--- a/zipline/errors.py
+++ b/zipline/errors.py
@@ -281,7 +281,7 @@ class MultipleSymbolsFound(ZiplineError):
 as_of_date' argument to to specify when the date symbol-lookup
 should be valid.
 
-Possible options:{options}
+Possible options: {options}
 """.strip()
diff --git a/zipline/testing/predicates.py b/zipline/testing/predicates.py
index e2f4078bc9..4a0f1919af 100644
--- a/zipline/testing/predicates.py
+++ b/zipline/testing/predicates.py
@@ -36,7 +36,11 @@
 )
 import numpy as np
 import pandas as pd
-from pandas.util.testing import assert_frame_equal
+from pandas.util.testing import (
+    assert_frame_equal,
+    assert_panel_equal,
+    assert_series_equal,
+)
 from six import iteritems, viewkeys, PY2
 from toolz import dissoc, keyfilter
 import toolz.curried.operator as op
@@ -393,18 +397,49 @@ def assert_array_equal(result,
         raise AssertionError('\n'.join((str(e), _fmt_path(path))))
 
 
-@assert_equal.register(pd.DataFrame, pd.DataFrame)
-def assert_dataframe_equal(result, expected, path=(), msg='', **kwargs):
-    try:
-        assert_frame_equal(
-            result,
-            expected,
-            **filter_kwargs(assert_frame_equal, kwargs)
-        )
-    except AssertionError as e:
-        raise AssertionError(
-            _fmt_msg(msg) + '\n'.join((str(e), _fmt_path(path))),
-        )
+def _register_assert_ndframe_equal(type_, assert_eq):
+    """Register a new check for an ndframe object.
+
+    Parameters
+    ----------
+    type_ : type
+        The class to register an ``assert_equal`` dispatch for.
+    assert_eq : callable[type_, type_]
+        The function which asserts that the two ndframes are equal.
+
+    Returns
+    -------
+    assert_ndframe_equal : callable[type_, type_]
+        The wrapped function registered with ``assert_equal``.
+    """
+    @assert_equal.register(type_, type_)
+    def assert_ndframe_equal(result, expected, path=(), msg='', **kwargs):
+        try:
+            assert_eq(
+                result,
+                expected,
+                **filter_kwargs(assert_eq, kwargs)
+            )
+        except AssertionError as e:
+            raise AssertionError(
+                _fmt_msg(msg) + '\n'.join((str(e), _fmt_path(path))),
+            )
+
+    return assert_ndframe_equal
+
+
+assert_frame_equal = _register_assert_ndframe_equal(
+    pd.DataFrame,
+    assert_frame_equal,
+)
+assert_panel_equal = _register_assert_ndframe_equal(
+    pd.Panel,
+    assert_panel_equal,
+)
+assert_series_equal = _register_assert_ndframe_equal(
+    pd.Series,
+    assert_series_equal,
+)
 
 
 @assert_equal.register(Adjustment, Adjustment)
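With ``_register_assert_ndframe_equal``, one generic ``assert_equal`` entry
point now dispatches to the matching pandas comparator for frames, panels,
and series. A small sketch of the behavior these registrations add:

.. code-block:: python

    import pandas as pd

    from zipline.testing.predicates import assert_equal

    # dispatches to pandas' assert_series_equal under the hood
    assert_equal(pd.Series([1.0, 2.0]), pd.Series([1.0, 2.0]))

    # mismatches raise AssertionError carrying the pandas diff plus
    # the msg and path context added by the wrapper
    assert_equal(
        pd.DataFrame({'a': [1, 2]}),
        pd.DataFrame({'a': [1, 3]}),
        msg='frames differ',
    )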
diff --git a/zipline/utils/functional.py b/zipline/utils/functional.py
index 791be00403..716e8777d1 100644
--- a/zipline/utils/functional.py
+++ b/zipline/utils/functional.py
@@ -1,8 +1,9 @@
+from functools import reduce
 from pprint import pformat
 
 from six import viewkeys
 from six.moves import map, zip
-from toolz import curry
+from toolz import curry, flip
 
 
 @curry
@@ -332,17 +333,60 @@ def decorator(f):
 with_doc = set_attribute('__doc__')
 
 
-def let(a):
-    """Box a value to be bound in a for binding.
+def foldr(f, seq, default=_no_default):
+    """Fold a function over a sequence with right associativity.
+
+    Parameters
+    ----------
+    f : callable[any, any]
+        The function to reduce the sequence with.
+        The first argument will be the element of the sequence; the second
+        argument will be the accumulator.
+    seq : iterable[any]
+        The sequence to reduce.
+    default : any, optional
+        The starting value to reduce with. If not provided, the sequence
+        cannot be empty, and the last value of the sequence will be used.
+
+    Returns
+    -------
+    folded : any
+        The folded value.
+
+    Notes
+    -----
+    This function works by reducing the list in a right associative way.
+
+    For example, imagine we are folding with ``operator.add`` or ``+``:
+
+    .. code-block:: python
+
+       foldr(add, seq) -> seq[0] + (seq[1] + (seq[2] + (... + (seq[-1] + default))))
+
+    In the more general case with an arbitrary function, ``foldr`` will expand
+    like so:
 
-    Examples
-    --------
     .. code-block:: python
 
-        [f(y, y) for x in xs for y in let(g(x)) if p(y)]
+       foldr(f, seq) -> f(seq[0], f(seq[1], f(seq[2], ... f(seq[-1], default))))
 
-    Here, ``y`` is available in both the predicate and the expression
-    of the comprehension. We can see that this allows us to cache the work
-    of computing ``g(x)`` even within the expression.
+    For a more in-depth discussion of left and right folds, see:
+    `https://en.wikipedia.org/wiki/Fold_(higher-order_function)`_
+    The images in that page are very good for showing the differences between
+    ``foldr`` and ``foldl`` (``reduce``).
+
+    .. note::
+
+       For performance reasons it is best to pass a strict (non-lazy)
+       sequence, for example, a list.
+
+    See Also
+    --------
+    :func:`functools.reduce`
+    :func:`sum`
     """
-    return a,
+    return reduce(
+        flip(f),
+        reversed(seq),
+        *(default,) if default is not _no_default else ()
+    )
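The expansions in the docstring are easy to verify with a non-associative
operator; subtraction makes the grouping visible. A quick sketch against the
left fold for contrast:

.. code-block:: python

    from functools import reduce
    from operator import sub

    from zipline.utils.functional import foldr

    foldr(sub, [1, 2, 3], 0)   # 1 - (2 - (3 - 0)) == 2
    reduce(sub, [1, 2, 3], 0)  # ((0 - 1) - 2) - 3 == -6

    # without a default, the last element seeds the fold
    foldr(sub, [1, 2, 3])      # 1 - (2 - 3) == 2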
diff --git a/zipline/utils/paths.py b/zipline/utils/paths.py
index 0e87d94d06..160de9013d 100644
--- a/zipline/utils/paths.py
+++ b/zipline/utils/paths.py
@@ -19,7 +19,7 @@ def hidden(path):
     path : str
         A filepath.
     """
-    return path.startswith('.')
+    return os.path.split(path)[1].startswith('.')
 
 
 def ensure_directory(path):
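The one-line change above makes hiddenness a property of the basename rather
than of the whole path, which previously misclassified relative paths. A
sketch of the before/after behavior:

.. code-block:: python

    from zipline.utils.paths import hidden

    hidden('/home/user/.zipline')      # True: basename starts with '.'
    hidden('.cache')                   # True
    hidden('./data/daily_bars.bcolz')  # False; the old code returned True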
diff --git a/zipline/utils/range.py b/zipline/utils/range.py
new file mode 100644
index 0000000000..138b092200
--- /dev/null
+++ b/zipline/utils/range.py
@@ -0,0 +1,308 @@
+import operator as op
+
+from six import PY2
+from toolz import peek
+
+from zipline.utils.functional import foldr
+
+
+if PY2:
+    class range(object):
+        """Lazy range object with constant time containment check.
+
+        The arguments are the same as ``range``.
+        """
+        __slots__ = 'start', 'stop', 'step'
+
+        def __init__(self, stop, *args):
+            if len(args) > 2:
+                raise TypeError(
+                    'range takes at most 3 arguments (%d given)' % len(args)
+                )
+
+            if not args:
+                self.start = 0
+                self.stop = stop
+                self.step = 1
+            else:
+                self.start = stop
+                self.stop = args[0]
+                try:
+                    self.step = args[1]
+                except IndexError:
+                    self.step = 1
+
+        def __iter__(self):
+            n = self.start
+            stop = self.stop
+            step = self.step
+            # walk in the direction of ``step`` so that negative steps
+            # terminate correctly
+            cmp_ = op.lt if step > 0 else op.gt
+            while cmp_(n, stop):
+                yield n
+                n += step
+
+        _ops = (
+            (op.ge, op.gt),
+            (op.le, op.lt),
+        )
+
+        def __contains__(self, other, _ops=_ops):
+            start = self.start
+            step = self.step
+            cmp_start, cmp_stop = _ops[step > 0]
+            return (
+                cmp_start(start, other) and
+                cmp_stop(other, self.stop) and
+                (other - start) % step == 0
+            )
+
+        del _ops
+
+        def __repr__(self):
+            return '%s(%s, %s%s)' % (
+                type(self).__name__,
+                self.start,
+                self.stop,
+                (', ' + str(self.step)) if self.step != 1 else '',
+            )
+
+        def __hash__(self):
+            return hash((type(self), self.start, self.stop, self.step))
+
+        def __eq__(self, other):
+            """
+            Examples
+            --------
+            >>> range(1) == range(1)
+            True
+            >>> range(0, 5, 2) == range(0, 5, 2)
+            True
+            >>> range(5, 0, -2) == range(5, 0, -2)
+            True
+
+            >>> range(1) == range(2)
+            False
+            >>> range(0, 5, 2) == range(0, 5, 3)
+            False
+            """
+            return all(
+                getattr(self, attr) == getattr(other, attr)
+                for attr in self.__slots__
+            )
+else:
+    range = range
+
+
+def from_tuple(tup):
+    """Convert a tuple into a range with error handling.
+
+    Parameters
+    ----------
+    tup : tuple (len 2 or 3)
+        The tuple to turn into a range.
+
+    Returns
+    -------
+    range : range
+        The range from the tuple.
+
+    Raises
+    ------
+    ValueError
+        Raised when the tuple length is not 2 or 3.
+    """
+    if len(tup) not in (2, 3):
+        raise ValueError(
+            'tuple must contain 2 or 3 elements, not: %d (%r)' % (
+                len(tup),
+                tup,
+            ),
+        )
+    return range(*tup)
+
+
+def maybe_from_tuple(tup_or_range):
+    """Convert a tuple into a range but pass ranges through silently.
+
+    This is useful to ensure that input is a range so that attributes may
+    be accessed with ``.start`` and ``.stop``, and so that containment
+    checks are constant time.
+
+    Parameters
+    ----------
+    tup_or_range : tuple or range
+        A tuple to pass to from_tuple or a range to return.
+
+    Returns
+    -------
+    range : range
+        The input, converted to a range if it was a tuple.
+
+    Raises
+    ------
+    ValueError
+        Raised when the input is not a tuple or a range. ValueError is also
+        raised if the input is a tuple whose length is not 2 or 3.
+    """
+    if isinstance(tup_or_range, tuple):
+        return from_tuple(tup_or_range)
+    elif isinstance(tup_or_range, range):
+        return tup_or_range
+
+    raise ValueError(
+        'maybe_from_tuple expects a tuple or range, got %r: %r' % (
+            type(tup_or_range).__name__,
+            tup_or_range,
+        ),
+    )
+
+
+def _check_steps(a, b):
+    """Check that the steps of ``a`` and ``b`` are both 1.
+
+    Parameters
+    ----------
+    a : range
+        The first range to check.
+    b : range
+        The second range to check.
+
+    Raises
+    ------
+    ValueError
+        Raised when either step is not 1.
+    """
+    if a.step != 1:
+        raise ValueError('a.step must be equal to 1, got: %s' % a.step)
+    if b.step != 1:
+        raise ValueError('b.step must be equal to 1, got: %s' % b.step)
+
+
+def overlap(a, b):
+    """Check if two ranges overlap.
+
+    Parameters
+    ----------
+    a : range
+        The first range.
+    b : range
+        The second range.
+
+    Returns
+    -------
+    overlaps : bool
+        Do these ranges overlap.
+
+    Notes
+    -----
+    This function does not support ranges with step != 1.
+    """
+    _check_steps(a, b)
+    return a.stop >= b.start and b.stop >= a.start
+
+
+def merge(a, b):
+    """Merge two ranges with step == 1.
+
+    Parameters
+    ----------
+    a : range
+        The first range.
+    b : range
+        The second range.
+
+    Returns
+    -------
+    merged : range
+        The smallest range spanning both ``a`` and ``b``.
+    """
+    _check_steps(a, b)
+    return range(min(a.start, b.start), max(a.stop, b.stop))
+
+
+def _combine(n, rs):
+    """Helper for ``group_ranges``: merge ``n`` into the head of the
+    already-grouped ranges ``rs`` if the two overlap.
+    """
+    try:
+        r, rs = peek(rs)
+    except StopIteration:
+        yield n
+        return
+
+    if overlap(n, r):
+        yield merge(n, r)
+        next(rs)
+        for r in rs:
+            yield r
+    else:
+        yield n
+        for r in rs:
+            yield r
+
+
+def group_ranges(ranges):
+    """Group any overlapping ranges into a single range.
+
+    Parameters
+    ----------
+    ranges : iterable[ranges]
+        A sorted sequence of ranges to group.
+
+    Returns
+    -------
+    grouped : iterable[ranges]
+        A sorted sequence of ranges with overlapping ranges merged together.
+    """
+    return foldr(_combine, ranges, ())
+
+
+def sorted_diff(rs, ss):
+    """Lazily compute the difference of two sequences of ranges sorted by
+    ``start``, yielding the ranges of ``rs`` that do not appear in ``ss``.
+    """
+    try:
+        r, rs = peek(rs)
+    except StopIteration:
+        return
+
+    try:
+        s, ss = peek(ss)
+    except StopIteration:
+        for r in rs:
+            yield r
+        return
+
+    rtup = (r.start, r.stop)
+    stup = (s.start, s.stop)
+    if rtup == stup:
+        next(rs)
+        next(ss)
+    elif rtup < stup:
+        yield next(rs)
+    else:
+        next(ss)
+
+    for t in sorted_diff(rs, ss):
+        yield t
+
+
+def intersecting_ranges(ranges):
+    """Return any ranges that intersect.
+
+    Parameters
+    ----------
+    ranges : iterable[ranges]
+        A sequence of ranges to check for intersections.
+
+    Returns
+    -------
+    intersections : iterable[ranges]
+        A sequence of all of the ranges that intersected in ``ranges``.
+
+    Examples
+    --------
+    >>> ranges = [range(0, 1), range(2, 5), range(4, 7)]
+    >>> list(intersecting_ranges(ranges))
+    [range(2, 5), range(4, 7)]
+
+    >>> ranges = [range(0, 1), range(2, 3)]
+    >>> list(intersecting_ranges(ranges))
+    []
+
+    >>> ranges = [range(0, 1), range(1, 2)]
+    >>> list(intersecting_ranges(ranges))
+    [range(0, 1), range(1, 2)]
+    """
+    ranges = sorted(ranges, key=op.attrgetter('start'))
+    return sorted_diff(ranges, group_ranges(ranges))
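Taken together, the module composes like this; a short sketch over
hypothetical intervals:

.. code-block:: python

    from zipline.utils.range import (
        from_tuple,
        group_ranges,
        intersecting_ranges,
        merge,
        overlap,
    )

    a = from_tuple((0, 5))
    b = from_tuple((3, 8))

    overlap(a, b)  # True: a.stop >= b.start and b.stop >= a.start
    merge(a, b)    # range(0, 8)

    # group_ranges expects its input sorted by start
    list(group_ranges([range(0, 2), range(1, 4), range(6, 9)]))
    # [range(0, 4), range(6, 9)]

    # the ranges that collided with some other range in the input
    list(intersecting_ranges([range(0, 2), range(1, 4), range(6, 9)]))
    # [range(0, 2), range(1, 4)]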