Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add expiring cache. #1130

Merged
merged 2 commits into from Apr 15, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/whatsnew/0.9.1.txt
Expand Up @@ -22,6 +22,11 @@ Enhancements
factor to only compute over stocks for which the filter returns True, rather
than always computing over the entire universe of stocks. (:issue:`1095`)

* Added :class:`zipline.utils.cache.ExpiringCache`.
A cache which wraps entries in a :class:`zipline.utils.cache.CachedObject`,
which manages expiration of entries based on the `dt` supplied to the `get`
method.

Experimental Features
~~~~~~~~~~~~~~~~~~~~~

Expand Down
41 changes: 40 additions & 1 deletion tests/utils/test_cache.py
Expand Up @@ -2,7 +2,7 @@

from pandas import Timestamp, Timedelta

from zipline.utils.cache import CachedObject, Expired
from zipline.utils.cache import CachedObject, Expired, ExpiringCache


class CachedObjectTestCase(TestCase):
Expand All @@ -19,3 +19,42 @@ def test_cached_object(self):
with self.assertRaises(Expired) as e:
obj.unwrap(after)
self.assertEqual(e.exception.args, (expiry,))


class ExpiringCacheTestCase(TestCase):

def test_expiring_cache(self):
expiry_1 = Timestamp('2014')
before_1 = expiry_1 - Timedelta('1 minute')
after_1 = expiry_1 + Timedelta('1 minute')

expiry_2 = Timestamp('2015')
after_2 = expiry_1 + Timedelta('1 minute')

expiry_3 = Timestamp('2016')

cache = ExpiringCache()

cache.set('foo', 1, expiry_1)
cache.set('bar', 2, expiry_2)

self.assertEqual(cache.get('foo', before_1), 1)
# Unwrap on expiry is allowed.
self.assertEqual(cache.get('foo', expiry_1), 1)

with self.assertRaises(KeyError) as e:
self.assertEqual(cache.get('foo', after_1))
self.assertEqual(e.exception.args, ('foo',))

# Should raise same KeyError after deletion.
with self.assertRaises(KeyError) as e:
self.assertEqual(cache.get('foo', before_1))
self.assertEqual(e.exception.args, ('foo',))

# Second value should still exist.
self.assertEqual(cache.get('bar', after_2), 2)

# Should raise similar KeyError on non-existent key.
with self.assertRaises(KeyError) as e:
self.assertEqual(cache.get('baz', expiry_3))
self.assertEqual(e.exception.args, ('baz',))
55 changes: 55 additions & 0 deletions zipline/utils/cache.py
Expand Up @@ -57,3 +57,58 @@ def unwrap(self, dt):
if dt > self.expires:
raise Expired(self.expires)
return self.value


class ExpiringCache(object):
"""
A cache of multiple CachedObjects, which returns the wrapped the value
or raises and deletes the CachedObject if the value has expired.

Parameters
----------
cache : dict-like
An instance of a dict-like object which needs to support at least:
`__del__`, `__getitem__`, `__setitem__`
If `None`, than a dict is used as a default.

Methods
-------
get(self, key, dt)
Get the value of a cached object for the given `key` at `dt`, if the
CachedObject has expired then the object is removed from the cache,
and `KeyError` is raised.

set(self, key, value, expiration_dt)
Add a new `value` to the cache at `dt` wrapped in a CachedObject which
expires at `expiration_dt`.

Usage
-----
>>> from pandas import Timestamp, Timedelta
>>> expires = Timestamp('2014', tz='UTC')
>>> value = 1
>>> cache = ExpiringCache()
>>> cache.set('foo', value, expires)
>>> cache.get('foo', expires - Timedelta('1 minute'))
1
>>> cache.get('foo', expires + Timedelta('1 minute'))
Traceback (most recent call last):
...
KeyError: 'foo'
"""

def __init__(self, cache=None):
if cache is not None:
self._cache = cache
else:
self._cache = {}

def get(self, key, dt):
try:
return self._cache[key].unwrap(dt)
except Expired:
del self._cache[key]
raise KeyError(key)

def set(self, key, value, expiration_dt):
self._cache[key] = CachedObject(value, expiration_dt)