Merge 079d699 into 8443934
Scott Sanderson authored Jan 30, 2020
2 parents 8443934 + 079d699 commit daaa0a0
Showing 6 changed files with 144 additions and 76 deletions.
74 changes: 57 additions & 17 deletions tests/data/test_fx.py
@@ -65,20 +65,23 @@ def test_scalar_lookup(self):
reader = self.reader

rates = self.FX_RATES_RATE_NAMES
currencies = self.FX_RATES_CURRENCIES
dates = pd.date_range(self.FX_RATES_START_DATE, self.FX_RATES_END_DATE)

cases = itertools.product(rates, currencies, currencies, dates)
quotes = self.FX_RATES_CURRENCIES
bases = self.FX_RATES_CURRENCIES + [None]
dates = pd.date_range(
self.FX_RATES_START_DATE - pd.Timedelta('1 day'),
self.FX_RATES_END_DATE,
)
cases = itertools.product(rates, quotes, bases, dates)

for rate, quote, base, dt in cases:
dts = pd.DatetimeIndex([dt], tz='UTC')
bases = np.array([base])
bases = np.array([base], dtype=object)

result = reader.get_rates(rate, quote, bases, dts)
assert_equal(result.shape, (1, 1))

result_scalar = result[0, 0]
if quote == base:
if dt >= self.FX_RATES_START_DATE and quote == base:
assert_equal(result_scalar, 1.0)

expected = self.get_expected_fx_rate_scalar(rate, quote, base, dt)
@@ -93,12 +96,16 @@ def test_scalar_lookup(self):
def test_2d_lookup(self):
rand = np.random.RandomState(42)

dates = pd.date_range(self.FX_RATES_START_DATE, self.FX_RATES_END_DATE)
dates = pd.date_range(
self.FX_RATES_START_DATE - pd.Timedelta('2 days'),
self.FX_RATES_END_DATE
)
rates = self.FX_RATES_RATE_NAMES + [DEFAULT_FX_RATE]
currencies = self.FX_RATES_CURRENCIES
possible_quotes = self.FX_RATES_CURRENCIES
possible_bases = self.FX_RATES_CURRENCIES + [None]

# For every combination of rate name and quote currency...
for rate, quote in itertools.product(rates, currencies):
for rate, quote in itertools.product(rates, possible_quotes):

# Choose N random distinct days...
for ndays in 1, 2, 7, 20:
@@ -107,7 +114,10 @@ def test_2d_lookup(self):

# Choose M random possibly-non-distinct currencies...
for nbases in 1, 2, 10, 200:
bases = rand.choice(currencies, nbases, replace=True)
bases = (
rand.choice(possible_bases, nbases, replace=True)
.astype(object)
)

# ...And check that we get the expected result when querying
# for those dates/currencies.
@@ -119,18 +129,25 @@ def test_2d_lookup(self):
def test_columnar_lookup(self):
rand = np.random.RandomState(42)

dates = pd.date_range(self.FX_RATES_START_DATE, self.FX_RATES_END_DATE)
dates = pd.date_range(
self.FX_RATES_START_DATE - pd.Timedelta('2 days'),
self.FX_RATES_END_DATE,
)
rates = self.FX_RATES_RATE_NAMES + [DEFAULT_FX_RATE]
currencies = self.FX_RATES_CURRENCIES
possible_quotes = self.FX_RATES_CURRENCIES
possible_bases = self.FX_RATES_CURRENCIES + [None]
reader = self.reader

# For every combination of rate name and quote currency...
for rate, quote in itertools.product(rates, currencies):
for rate, quote in itertools.product(rates, possible_quotes):
for N in 1, 2, 10, 200:
# Choose N (date, base) pairs randomly with replacement.
dts_raw = rand.choice(dates, N, replace=True)
dts = pd.DatetimeIndex(dts_raw, tz='utc').sort_values()
bases = rand.choice(currencies, N, replace=True)
dts = pd.DatetimeIndex(dts_raw, tz='utc')
bases = (
rand.choice(possible_bases, N, replace=True)
.astype(object)
)

# ... And check that we get the expected result when querying
# for those dates/currencies.
@@ -175,27 +192,50 @@ def test_load_everything(self):
assert_equal(london_result, london_rates.values)

def test_read_before_start_date(self):
# Reads from before the start of our data should emit NaN. We do this
# because, for some Pipeline loaders, it's hard to put a lower bound on
# input asof dates, so we end up making queries for asof_dates that
# might be before the start of FX data. When that happens, we want to
# emit NaN, but we don't want to fail.
for bad_date in (self.FX_RATES_START_DATE - pd.Timedelta('1 day'),
self.FX_RATES_START_DATE - pd.Timedelta('1000 days')):

for rate in self.FX_RATES_RATE_NAMES:
quote = 'USD'
bases = np.array(['CAD'], dtype=object)
dts = pd.DatetimeIndex([bad_date])
with self.assertRaises(ValueError):
self.reader.get_rates(rate, quote, bases, dts)
result = self.reader.get_rates(rate, quote, bases, dts)
assert_equal(result.shape, (1, 1))
assert_equal(np.nan, result[0, 0])

def test_read_after_end_date(self):
# Reads from **after** the end of our data, on the other hand, should
# fail. We can always upper bound the relevant asofs that we're
# interested in, and having fx rates forward-fill past the end of data
# is confusing and takes a while to debug.
for bad_date in (self.FX_RATES_END_DATE + pd.Timedelta('1 day'),
self.FX_RATES_END_DATE + pd.Timedelta('1000 days')):

for rate in self.FX_RATES_RATE_NAMES:
quote = 'USD'
bases = np.array(['CAD'], dtype=object)
dts = pd.DatetimeIndex([bad_date])

with self.assertRaises(ValueError):
self.reader.get_rates(rate, quote, bases, dts)

with self.assertRaises(ValueError):
self.reader.get_rates_columnar(rate, quote, bases, dts)

def test_read_unknown_base(self):
for rate in self.FX_RATES_RATE_NAMES:
quote = 'USD'
for unknown_base in 'XXX', None:
bases = np.array([unknown_base], dtype=object)
dts = pd.DatetimeIndex([self.FX_RATES_START_DATE])
result = self.reader.get_rates(rate, quote, bases, dts)[0, 0]
assert_equal(result, np.nan)


class InMemoryFXReaderTestCase(_FXReaderTestCase):

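Taken together, the tests above pin down the reader contract: reads before the start of data and lookups for unknown or None base currencies yield NaN, while reads past the end of data raise ValueError. A minimal standalone sketch of that per-scalar expectation (mirroring get_expected_fx_rate_scalar in the fixtures change further down; the function name and DataFrame layout here are illustrative, not part of this commit):

import numpy as np
import pandas as pd

def expected_fx_rate(rates_frame, dt, base):
    # rates_frame: DataFrame of rates indexed by date, one column per base
    # currency, for a fixed (rate name, quote currency) pair.
    if base is None or base not in rates_frame.columns:
        return np.nan  # unknown or missing base currency -> NaN
    if dt < rates_frame.index[0]:
        return np.nan  # asof date before the start of data -> NaN
    if dt > rates_frame.index[-1]:
        raise ValueError("asof date is past the end of the data")
    # Forward-fill: take the most recent stored rate at or before dt.
    row = rates_frame.index.searchsorted(dt, side='right') - 1
    return rates_frame[base].values[row]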
13 changes: 10 additions & 3 deletions zipline/data/fx/base.py
@@ -4,6 +4,7 @@
import pandas as pd

from zipline.utils.sentinel import sentinel
from zipline.lib._factorize import factorize_strings

DEFAULT_FX_RATE = sentinel('DEFAULT_FX_RATE')

@@ -127,15 +128,21 @@ def get_rates_columnar(self, rate, quote, bases, dts):
may appear multiple times.
dts : np.DatetimeIndex
Datetimes for which to load rates. The same value may appear
multiple times, but datetimes must be sorted in ascending order and
localized to UTC.
multiple times, and datetimes do not need to be sorted: they are
deduplicated and sorted internally via `np.unique`.
"""
if len(bases) != len(dts):
raise ValueError(
"len(bases) ({}) != len(dts) ({})".format(len(bases), len(dts))
)

unique_bases, bases_ix = np.unique(bases, return_inverse=True)
# NOTE: We use `factorize_strings` rather than `np.unique` here because
# `np.unique` fails when `bases` contains `None` alongside strings;
# `factorize_strings` treats `None` as a missing value.
bases_ix, unique_bases, _ = factorize_strings(
bases,
missing_value=None,
sort=True,
)
unique_dts, dts_ix = np.unique(dts.values, return_inverse=True)
rates_2d = self.get_rates(
rate,
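For context, get_rates_columnar reduces its N (base, dt) pairs to a single 2-D get_rates call over the unique dates and bases, then gathers one scalar per input pair. A rough standalone sketch of that reduction (np.unique is used on the bases here purely for illustration; as noted above, the real method uses factorize_strings so that None entries are tolerated, and the final gather step is cut off in the hunk shown):

import numpy as np
import pandas as pd

def columnar_lookup_sketch(get_rates, rate, quote, bases, dts):
    # Deduplicate bases and dts, keeping inverse indices so the output can
    # be reassembled in the original order.
    unique_bases, bases_ix = np.unique(bases, return_inverse=True)
    unique_dts, dts_ix = np.unique(dts.values, return_inverse=True)

    # One 2-D query over the unique grid: rows are dates, columns are bases.
    rates_2d = get_rates(
        rate,
        quote,
        unique_bases,
        pd.DatetimeIndex(unique_dts, tz='UTC'),
    )

    # Gather one scalar per requested (dt, base) pair.
    return rates_2d[dts_ix, bases_ix]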
77 changes: 38 additions & 39 deletions zipline/data/fx/hdf5.py
@@ -104,6 +104,7 @@
from zipline.utils.numpy_utils import bytes_array_to_native_str_object_array

from .base import FXRateReader, DEFAULT_FX_RATE
from .utils import check_dts, is_sorted_ascending

HDF5_FX_VERSION = 0

@@ -189,7 +190,7 @@ def get_rates(self, rate, quote, bases, dts):
if rate == DEFAULT_FX_RATE:
rate = self._default_rate

self._check_dts(self.dts, dts)
check_dts(self.dts, dts)

row_ixs = self.dts.searchsorted(dts, side='right') - 1
col_ixs = self.currencies.get_indexer(bases)
@@ -204,46 +205,48 @@

# OPTIMIZATION: Row indices correspond to dates, which must be in
# sorted order. Rather than reading the entire dataset from h5, we can
# read just the interval from min_row to max_row inclusive.
# read just the interval from min_row to max_row inclusive
#
# We don't bother with a similar optimization for columns because in
# expectation we're going to load most of the array, so it's easier to
# pull all columns and reindex in memory. For
# rows, however, a quick and easy optimization is to pull just the
# slice from min(row_ixs) to max(row_ixs).
min_row = row_ixs[0]
max_row = row_ixs[-1]
rows = dataset[min_row:max_row + 1] # +1 to be inclusive of end

out = rows[row_ixs - min_row][:, col_ixs]
# However, we also need to handle two important edge cases:
#
# 1. row_ixs contains -1 for dts before the start of self.dts.
# 2. col_ixs contains -1 for any currencies we don't know about.
#
# If either of the above cases obtains, we want to return NaN for the
# corresponding output locations.

# get_indexer returns -1 for failed lookups. Fill these in with NaN.
# We handle (1) by reading raw data into a buffer with one extra
# row. When we then apply the row index to permute the raw data into
# the correct order, any rows with values of -1 will pull from the
# extra row, which will always contain NaN.
#
# We handle (2) by overwriting columns with indices of -1 with NaN as a
# postprocessing step.
slice_begin = max(row_ixs[0], 0)
slice_end = max(row_ixs[-1], 0) + 1 # +1 to be inclusive of end date.

# Allocate a buffer full of NaNs with one extra row. See OPTIMIZATION
# notes above.
buf = np.full(
(slice_end - slice_begin + 1, len(self.currencies)),
np.nan,
)

# Read data into all but the last row of the buffer.
dataset.read_direct(
buf[:-1],
np.s_[slice_begin:slice_end],
)

# Permute the rows into place, pulling from the all-NaN extra row for
# row indices of -1.
out = buf[:, col_ixs][row_ixs - slice_begin]

# Fill missing columns with NaN. See OPTIMIZATION notes above.
out[:, col_ixs == -1] = np.nan

return out

def _check_dts(self, stored, requested):
"""Validate that requested dates are in bounds for what we have stored.
"""
request_start, request_end = requested[[0, -1]]
data_start, data_end = stored[[0, -1]]

if request_start < data_start:
raise ValueError(
"Requested fx rates starting at {}, but data starts at {}"
.format(request_start, data_start)
)

if request_end > data_end:
raise ValueError(
"Requested fx rates ending at {}, but data ends at {}"
.format(request_end, data_end)
)

if not is_sorted_ascending(requested):
raise ValueError("Requested fx rates with non-ascending dts.")


class HDF5FXRateWriter(object):
"""Writer class for HDF5 files consumed by HDF5FXRateReader.
@@ -312,7 +315,3 @@ def _write_data_group(self, dts, currencies, data):

def _log_writing(self, *path):
log.debug("Writing {}", '/'.join(path))


def is_sorted_ascending(array):
return (np.maximum.accumulate(array) <= array).all()
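To make the buffer trick described in the comments above concrete, here is a standalone numpy sketch (illustrative only: h5py's read_direct is replaced by plain slicing, and row_ixs is assumed to be sorted ascending, as the reader requires of its input dts):

import numpy as np

def lookup_with_nan_buffer(data, row_ixs, col_ixs):
    # data: (n_dates, n_currencies) array of stored rates.
    # row_ixs / col_ixs: row and column indices, where -1 marks a date
    # before the start of data or an unknown currency, respectively.
    slice_begin = max(row_ixs[0], 0)
    slice_end = max(row_ixs[-1], 0) + 1  # +1 to include the last row.

    # Buffer with one extra all-NaN row at the bottom.
    buf = np.full((slice_end - slice_begin + 1, data.shape[1]), np.nan)
    buf[:-1] = data[slice_begin:slice_end]

    # Because row_ixs is sorted, any -1 entries force slice_begin to 0, so
    # a row index of -1 wraps around to the final (all-NaN) buffer row.
    out = buf[:, col_ixs][row_ixs - slice_begin]

    # Unknown currencies become NaN as a postprocessing step.
    out[:, col_ixs == -1] = np.nan
    return out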
25 changes: 8 additions & 17 deletions zipline/data/fx/in_memory.py
@@ -1,8 +1,10 @@
"""Interface and definitions for foreign exchange rate readers.
"""
from interface import implements
import numpy as np

from .base import FXRateReader, DEFAULT_FX_RATE
from .utils import check_dts


class InMemoryFXRateReader(implements(FXRateReader)):
@@ -34,7 +36,7 @@ def get_rates(self, rate, quote, bases, dts):

df = self._data[rate][quote]

self._check_dts(df.index, dts)
check_dts(df.index, dts)

# Get raw values out of the frame.
#
@@ -51,22 +53,11 @@ def get_rates(self, rate, quote, bases, dts):
values = df.values
row_ixs = df.index.searchsorted(dts, side='right') - 1
col_ixs = df.columns.get_indexer(bases)
return values[row_ixs][:, col_ixs]

def _check_dts(self, stored, requested):
"""Validate that requested dates are in bounds for what we have stored.
"""
request_start, request_end = requested[[0, -1]]
data_start, data_end = stored[[0, -1]]
out = values[:, col_ixs][row_ixs]

if request_start < data_start:
raise ValueError(
"Requested fx rates starting at {}, but data starts at {}"
.format(request_start, data_start)
)
# Handle dates before start and unknown bases.
out[row_ixs == -1] = np.nan
out[:, col_ixs == -1] = np.nan

if request_end > data_end:
raise ValueError(
"Requested fx rates ending at {}, but data ends at {}"
.format(request_end, data_end)
)
return out
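Both readers rely on the same as-of lookup: searchsorted(..., side='right') - 1 picks the most recent stored date at or before each requested date and yields -1 for dates before the start of data, which is then mapped to NaN. A tiny illustration of that indexing (dates chosen arbitrarily):

import pandas as pd

stored = pd.DatetimeIndex(['2020-01-02', '2020-01-03', '2020-01-06'], tz='UTC')
requested = pd.DatetimeIndex(['2020-01-01', '2020-01-03', '2020-01-05'], tz='UTC')

row_ixs = stored.searchsorted(requested, side='right') - 1
print(row_ixs)  # [-1  1  1]: before the start -> -1, exact match -> that
                # row, gap day -> the previous stored date (forward fill).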
23 changes: 23 additions & 0 deletions zipline/data/fx/utils.py
@@ -0,0 +1,23 @@
import numpy as np


def check_dts(stored_dts, requested_dts):
"""
Validate that ``requested_dts`` are valid for querying from an FX reader
that has data for ``stored_dts``.
"""
request_end = requested_dts[-1]
data_end = stored_dts[-1]

if not is_sorted_ascending(requested_dts):
raise ValueError("Requested fx rates with non-ascending dts.")

if request_end > data_end:
raise ValueError(
"Requested fx rates ending at {}, but data ends at {}"
.format(request_end, data_end)
)


def is_sorted_ascending(array):
return (np.maximum.accumulate(array) <= array).all()
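A quick illustration of the monotonicity check above: the running maximum of an ascending array never exceeds the array itself (this assumes is_sorted_ascending from this module is in scope):

import numpy as np

print(is_sorted_ascending(np.array([1, 2, 2, 5])))  # True (ties allowed)
print(is_sorted_ascending(np.array([1, 3, 2])))     # False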
8 changes: 8 additions & 0 deletions zipline/testing/fixtures.py
@@ -2196,10 +2196,18 @@ def write_h5_fx_rates(cls, path):
def get_expected_fx_rate_scalar(cls, rate, quote, base, dt):
"""Get the expected FX rate for the given scalar coordinates.
"""
if base is None:
return np.nan

if rate == DEFAULT_FX_RATE:
rate = cls.FX_RATES_DEFAULT_RATE

col = cls.fx_rates[rate][quote][base]
if dt < col.index[0]:
return np.nan
elif dt > col.index[-1]:
raise ValueError("dt={} > max dt={}".format(dt, col.index[-1]))

# PERF: We call this function a lot in some suites, and get_loc is
# surprisingly expensive, so optimizing it has a meaningful impact on
# overall suite performance. See test_fast_get_loc_ffilled_for
