Merge pull request #2609 from quantopian/currency-fixes
Currency Improvements
Scott Sanderson committed Jan 10, 2020
2 parents 5e61943 + 09fb188 commit 3ea5916
Showing 14 changed files with 306 additions and 164 deletions.
47 changes: 40 additions & 7 deletions tests/data/test_daily_bars.py
@@ -36,6 +36,7 @@
from toolz import merge
from trading_calendars import get_calendar

from zipline.currency import MISSING_CURRENCY_CODE
from zipline.data.bar_reader import (
NoDataAfterDate,
NoDataBeforeDate,
@@ -59,7 +60,7 @@
expected_bar_values_2d,
make_bar_data,
)
from zipline.testing import seconds_to_timestamp
from zipline.testing import seconds_to_timestamp, powerset
from zipline.testing.fixtures import (
WithAssetFinder,
WithBcolzEquityDailyBarReader,
@@ -522,13 +523,45 @@ def test_get_last_traded_dt(self):
)

def test_listing_currency(self):
assets = np.array(list(self.assets))
# TODO: Test loading codes for missing assets.
results = self.daily_bar_reader.currency_codes(assets)
expected = self.make_equity_daily_bar_currency_codes(
self.DAILY_BARS_TEST_QUERY_COUNTRY_CODE, assets,
# Test loading on all assets.
all_assets = np.array(list(self.assets))
all_results = self.daily_bar_reader.currency_codes(all_assets)
all_expected = self.make_equity_daily_bar_currency_codes(
self.DAILY_BARS_TEST_QUERY_COUNTRY_CODE, all_assets,
).values
assert_equal(results, expected)
assert_equal(all_results, all_expected)

# Check all possible subsets of assets.
for indices in map(list, powerset(range(len(all_assets)))):
# Empty queries aren't currently supported.
if not indices:
continue
assets = all_assets[indices]
results = self.daily_bar_reader.currency_codes(assets)
expected = all_expected[indices]

assert_equal(results, expected)

def test_listing_currency_for_nonexistent_asset(self):
reader = self.daily_bar_reader

valid_sid = max(self.assets)
valid_currency = reader.currency_codes(np.array([valid_sid]))[0]
invalid_sids = [-1, -2]

# XXX: We currently require at least one valid sid here, because the
# MultiCountryDailyBarReader needs one valid sid to be able to dispatch
# to a child reader. We could probably make that work, but there are no
# real-world cases where we expect to get all-invalid currency queries,
# so it's unclear whether we should do work to explicitly support such
# queries.
mixed = np.array(invalid_sids + [valid_sid])
result = self.daily_bar_reader.currency_codes(mixed)
expected = np.array(
[MISSING_CURRENCY_CODE] * 2 + [valid_currency],
dtype='S3'
)
assert_equal(result, expected)


class BcolzDailyBarTestCase(WithBcolzEquityDailyBarReader, _DailyBarsTestCase):
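For context, the new subset loop in test_listing_currency relies on powerset, now imported from zipline.testing. A minimal sketch of an itertools-style helper with the behavior the test assumes (the real zipline.testing implementation may differ):

from itertools import chain, combinations

def powerset(values):
    """Yield every subset of ``values`` as a tuple, starting with ()."""
    return chain.from_iterable(
        combinations(values, i) for i in range(len(values) + 1)
    )

# The test maps each subset to a list of positional indices, skipping the
# empty subset because empty currency_codes queries aren't supported.
list(powerset(range(3)))
# [(), (0,), (1,), (2,), (0, 1), (0, 2), (1, 2), (0, 1, 2)]
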
72 changes: 5 additions & 67 deletions tests/data/test_fx.py
@@ -1,11 +1,9 @@
import itertools

import h5py
import pandas as pd
import numpy as np

from zipline.data.fx import DEFAULT_FX_RATE
from zipline.data.fx.hdf5 import HDF5FXRateReader, HDF5FXRateWriter

from zipline.testing.predicates import assert_equal
import zipline.testing.fixtures as zp_fixtures
@@ -59,33 +57,6 @@ def make_fx_rates(cls, fields, currencies, sessions):
'tokyo_mid': cls.tokyo_mid_rates,
}

@classmethod
def get_expected_rate_scalar(cls, rate, quote, base, dt):
"""Get the expected FX rate for the given scalar coordinates.
"""
if rate == DEFAULT_FX_RATE:
rate = cls.FX_RATES_DEFAULT_RATE

col = cls.fx_rates[rate][quote][base]
# PERF: We call this function a lot in this suite, and get_loc is
# surprisingly expensive, so optimizing it has a meaningful impact on
# overall suite performance. See test_fast_get_loc_ffilled_for
# assurance that this behaves the same as get_loc.
ix = fast_get_loc_ffilled(col.index.values, dt.asm8)
return col.values[ix]

@classmethod
def get_expected_rates(cls, rate, quote, bases, dts):
"""Get an array of expected FX rates for the given indices.
"""
out = np.empty((len(dts), len(bases)), dtype='float64')

for i, dt in enumerate(dts):
for j, base in enumerate(bases):
out[i, j] = cls.get_expected_rate_scalar(rate, quote, base, dt)

return out

@property
def reader(self):
raise NotImplementedError("Must be implemented by test suite.")
@@ -110,7 +81,7 @@ def test_scalar_lookup(self):
if quote == base:
assert_equal(result_scalar, 1.0)

expected = self.get_expected_rate_scalar(rate, quote, base, dt)
expected = self.get_expected_fx_rate_scalar(rate, quote, base, dt)
assert_equal(result_scalar, expected)

def test_vectorized_lookup(self):
@@ -135,7 +106,7 @@ def test_vectorized_lookup(self):
# ...And check that we get the expected result when querying
# for those dates/currencies.
result = self.reader.get_rates(rate, quote, bases, dts)
expected = self.get_expected_rates(rate, quote, bases, dts)
expected = self.get_expected_fx_rates(rate, quote, bases, dts)

assert_equal(result, expected)

@@ -205,47 +176,14 @@ class HDF5FXReaderTestCase(zp_fixtures.WithTmpDir,
@classmethod
def init_class_fixtures(cls):
super(HDF5FXReaderTestCase, cls).init_class_fixtures()

path = cls.tmpdir.getpath('fx_rates.h5')

# Set by WithFXRates.
sessions = cls.fx_rates_sessions

# Write in-memory data to h5 file.
with h5py.File(path, 'w') as h5_file:
writer = HDF5FXRateWriter(h5_file)
fx_data = ((rate, quote, quote_frame.values)
for rate, rate_dict in cls.fx_rates.items()
for quote, quote_frame in rate_dict.items())

writer.write(
dts=sessions.values,
currencies=np.array(cls.FX_RATES_CURRENCIES, dtype='S3'),
data=fx_data,
)

h5_file = cls.enter_class_context(h5py.File(path, 'r'))
cls.h5_fx_reader = HDF5FXRateReader(
h5_file,
default_rate=cls.FX_RATES_DEFAULT_RATE,
)
cls.h5_fx_reader = cls.write_h5_fx_rates(path)

@property
def reader(self):
return self.h5_fx_reader


def fast_get_loc_ffilled(dts, dt):
"""
Equivalent to dts.get_loc(dt, method='ffill'), but with reasonable
microperformance.
"""
ix = dts.searchsorted(dt, side='right') - 1
if ix < 0:
raise KeyError(dt)
return ix


class FastGetLocTestCase(zp_fixtures.ZiplineTestCase):

def test_fast_get_loc_ffilled(self):
@@ -258,12 +196,12 @@ def test_fast_get_loc_ffilled(self):
])

for dt in pd.date_range('2014-01-02', '2014-01-08'):
result = fast_get_loc_ffilled(dts.values, dt.asm8)
result = zp_fixtures.fast_get_loc_ffilled(dts.values, dt.asm8)
expected = dts.get_loc(dt, method='ffill')
assert_equal(result, expected)

with self.assertRaises(KeyError):
dts.get_loc(pd.Timestamp('2014-01-01'), method='ffill')

with self.assertRaises(KeyError):
fast_get_loc_ffilled(dts, pd.Timestamp('2014-01-01'))
zp_fixtures.fast_get_loc_ffilled(dts, pd.Timestamp('2014-01-01'))
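The helpers deleted from this file (get_expected_rate_scalar, get_expected_rates, fast_get_loc_ffilled, and the HDF5 writer boilerplate) appear to have moved into zipline.testing.fixtures, since the tests now call self.get_expected_fx_rate_scalar, self.get_expected_fx_rates, cls.write_h5_fx_rates, and zp_fixtures.fast_get_loc_ffilled. A self-contained sketch of the forward-fill lookup the tests exercise, assuming the moved helper keeps the deleted implementation:

import numpy as np
import pandas as pd

def fast_get_loc_ffilled(dts, dt):
    # Same idea as pd.DatetimeIndex.get_loc(dt, method='ffill'), but using
    # a plain searchsorted over the underlying datetime64 values, which is
    # much cheaper inside the suite's nested expected-rate loops.
    ix = dts.searchsorted(dt, side='right') - 1
    if ix < 0:
        raise KeyError(dt)
    return ix

dts = pd.to_datetime(['2014-01-02', '2014-01-06', '2014-01-09'])
print(fast_get_loc_ffilled(dts.values, pd.Timestamp('2014-01-07').asm8))  # 1
print(dts.get_loc(pd.Timestamp('2014-01-07'), method='ffill'))            # 1, pandas' own ffill lookup, as in the test above
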
12 changes: 12 additions & 0 deletions tests/utils/test_numpy_utils.py
@@ -19,8 +19,10 @@
from toolz import curry
from toolz.curried.operator import ne

from zipline.testing.predicates import assert_equal
from zipline.utils.functional import mapall as lazy_mapall
from zipline.utils.numpy_utils import (
bytes_array_to_native_str_object_array,
is_float,
is_int,
is_datetime,
@@ -92,3 +94,13 @@ def test_is_datetime(self):

for bad_value in everything_but(datetime, CASES):
self.assertFalse(is_datetime(bad_value))


class ArrayUtilsTestCase(TestCase):

def test_bytes_array_to_native_str_object_array(self):
a = array([b'abc', b'def'], dtype='S3')
result = bytes_array_to_native_str_object_array(a)
expected = array(['abc', 'def'], dtype=object)

assert_equal(result, expected)
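The new test exercises bytes_array_to_native_str_object_array, which converts a fixed-width numpy bytes array (dtype 'S…') into an object array of native strings. A sketch of the behavior the test expects (the real zipline.utils.numpy_utils implementation may differ, e.g. to handle Python 2):

import numpy as np

def bytes_array_to_native_str_object_array(a):
    # Decode the fixed-width bytes entries to str, then widen the dtype to
    # object so the result holds native Python strings.
    return a.astype(str).astype(object)

a = np.array([b'abc', b'def'], dtype='S3')
print(bytes_array_to_native_str_object_array(a))
# ['abc' 'def']  -> dtype=object, matching the expected array in the test
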
57 changes: 18 additions & 39 deletions zipline/currency.py
@@ -1,31 +1,11 @@
from functools import partial, total_ordering

from functools import total_ordering
from iso4217 import Currency as ISO4217Currency

import numpy as np

_ALL_CURRENCIES = {}


def strs_to_sids(strs, category_num):
"""TODO: Improve this.
"""
out = np.full(len(strs), category_num << 50, dtype='i8')
casted_buffer = np.ndarray(
shape=out.shape,
dtype='S6',
buffer=out,
strides=out.strides,
)
casted_buffer[:] = np.array(strs, dtype='S6')
return out


def str_to_sid(str_, category_num):
return strs_to_sids([str_], category_num)[0]


iso_currency_to_sid = partial(str_to_sid, category_num=3)
# Special sentinel used to represent unknown or missing currencies.
MISSING_CURRENCY_CODE = 'XXX'


@total_ordering
@@ -48,15 +28,20 @@ def __new__(cls, code):
try:
return _ALL_CURRENCIES[code]
except KeyError:
try:
iso_currency = ISO4217Currency(code)
except ValueError:
raise ValueError(
"{!r} is not a valid currency code.".format(code)
)
# This isn't a real currency.
if code == MISSING_CURRENCY_CODE:
name = "NO CURRENCY"
else:
try:
name = ISO4217Currency(code).currency_name
except ValueError:
raise ValueError(
"{!r} is not a valid currency code.".format(code)
)

obj = _ALL_CURRENCIES[code] = super(Currency, cls).__new__(cls)
obj._currency = iso_currency
obj._sid = iso_currency_to_sid(iso_currency.value)
obj._code = code
obj._name = name
return obj

@property
Expand All @@ -67,7 +52,7 @@ def code(self):
-------
code : str
"""
return self._currency.value
return self._code

@property
def name(self):
@@ -77,13 +62,7 @@ def name(self):
-------
name : str
"""
return self._currency.currency_name

@property
def sid(self):
"""Unique integer identifier for this currency.
"""
return self._sid
return self._name

def __eq__(self, other):
if type(self) != type(other):
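Putting the currency.py changes together: instances are interned in _ALL_CURRENCIES, names now come straight from the iso4217 package, the sid machinery is gone, and the 'XXX' sentinel gets a placeholder name. A quick usage sketch of the new behavior (printed values are illustrative):

from zipline.currency import Currency, MISSING_CURRENCY_CODE

usd = Currency('USD')
print(usd.code)   # 'USD'
print(usd.name)   # 'US Dollar' (as reported by the iso4217 package)

# Construction is memoized via _ALL_CURRENCIES, so equal codes share one
# instance.
assert Currency('USD') is usd

# The missing-currency sentinel is constructible but gets a placeholder
# name instead of an ISO 4217 lookup.
assert Currency(MISSING_CURRENCY_CODE).name == 'NO CURRENCY'

# Any other unknown code raises ValueError.
try:
    Currency('ZZZ')
except ValueError as exc:
    print(exc)    # "'ZZZ' is not a valid currency code."
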
14 changes: 12 additions & 2 deletions zipline/data/bcolz_daily_bars.py
@@ -34,6 +34,7 @@
from toolz import compose
from trading_calendars import get_calendar

from zipline.currency import MISSING_CURRENCY_CODE
from zipline.data.session_bars import CurrencyAwareSessionBarReader
from zipline.data.bar_reader import (
NoDataAfterDate,
@@ -706,5 +707,14 @@ def get_value(self, sid, dt, field):
return price

def currency_codes(self, sids):
# TODO: Better handling for this.
return np.full(len(sids), b'USD', dtype='S3')
# XXX: This is pretty inefficient. This reader doesn't really support
# country codes, so we always either return USD or
# MISSING_CURRENCY_CODE if we don't know about the sid at all.
first_rows = self._first_rows
out = []
for sid in sids:
if sid in first_rows:
out.append('USD')
else:
out.append(MISSING_CURRENCY_CODE)
return np.array(out, dtype='S3')
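As the XXX note says, the per-sid loop is simple but inefficient. If that ever matters, the same mapping can be written in vectorized form; the sketch below is not the committed code, and the helper name and known_sids parameter are illustrative:

import numpy as np

MISSING_CURRENCY_CODE = 'XXX'

def currency_codes_vectorized(sids, known_sids):
    # b'USD' for sids the reader has data for, b'XXX' otherwise, returned
    # as an 'S3' array like the reader's loop above.
    known = np.isin(sids, list(known_sids))
    return np.where(known, 'USD', MISSING_CURRENCY_CODE).astype('S3')

print(currency_codes_vectorized(np.array([1, 99, 2]), {1, 2}))
# [b'USD' b'XXX' b'USD']

Whether that micro-optimization is worth it depends on query sizes; the committed loop keeps the dispatch logic obvious.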