Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 84 additions & 24 deletions miniutils/caching.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import functools
#from contextlib import contextmanager
from threading import RLock
#import inspect
from functools import partial


class CachedCollection:
IGNORED_GETS = ['get', 'union', 'intersection', 'difference', 'copy']
IGNORED_GETS = ['get', 'union', 'intersection', 'difference', 'copy', 'keys', 'values', 'items']

def __init__(self, value, on_update, container_self, allow_update):
self.collection = value
Expand All @@ -15,11 +14,11 @@ def __init__(self, value, on_update, container_self, allow_update):
def __getitem__(self, item):
return self.collection[item]

def __missing__(self, key):
if not self.allow_update:
raise AttributeError("Attempted to perform an action (probably add) an unknown key")
self.collection.__missing__(key)
self.on_update()
# def __missing__(self, key): # This isn't a dict subclass, it's a wrapper, so this method will never get called
# if not self.allow_update:
# raise AttributeError("Attempted to perform an action (probably add) an unknown key")
# self.collection.__missing__(key)
# self.on_update()

def __setitem__(self, key, value):
if not self.allow_update:
Expand Down Expand Up @@ -104,9 +103,9 @@ def __init__(self, *affects, settable=False, threadsafe=True, is_collection=Fals
self.f = None
CachedProperty.caches.append(self)

def __call__(self, f):
def __call__(self, f, name=None):
self.f = f
self.name = name = f.__name__
self.name = name = name or f.__name__
flag_name = '_need_' + name
cache_name = '_' + name

Expand Down Expand Up @@ -144,7 +143,8 @@ def inner_getter(inner_self):
return getattr(inner_self, cache_name)

def inner_deleter(inner_self):
assert getattr(inner_self, flag_name, True) or hasattr(inner_self, cache_name)
# assert not getattr(inner_self, flag_name, True) or hasattr(inner_self, cache_name)
# raise AttributeError("{} does not have a value for attribute {}".format(inner_self, name))
setattr(inner_self, flag_name, True)
if hasattr(inner_self, cache_name):
delattr(inner_self, cache_name)
Expand All @@ -168,16 +168,76 @@ def inner_setter(inner_self, value):
return property(fget=inner_getter, fset=inner_setter, fdel=inner_deleter, doc=self.f.__doc__)


# def _get_class_that_defined_method(method):
# """https://stackoverflow.com/questions/3589311/get-defining-class-of-unbound-method-object-in-python-3"""
# if inspect.ismethod(method):
# for cls in inspect.getmro(method.__self__.__class__):
# if cls.__dict__.get(method.__name__) is method:
# return cls
# method = method.__func__ # fallback to __qualname__ parsing
# if inspect.isfunction(method):
# cls = getattr(inspect.getmodule(method),
# method.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0])
# if isinstance(cls, type):
# return cls
# return getattr(method, '__objclass__', None) # handle special descriptor objects
class _LazyIndexable:
def __init__(self, getter_closure, on_modified, settable=False, values=None):
self._cache = dict(values or {})
self._closure = getter_closure
self._on_modified = on_modified
self.settable = settable

def __getitem__(self, item):
if item not in self._cache:
self._cache[item] = self._closure(item)
self._on_modified()
return self._cache[item]

def __setitem__(self, key, value):
if not self.settable:
raise AttributeError("{} is not settable".format(self))
self._cache[key] = value
self._on_modified()

def __delitem__(self, key):
del self._cache[key]
self._on_modified()

@property
def __doc__(self):
return self._closure.__doc__

def update(self, new_values):
if not self.settable:
raise AttributeError("{} is not settable".format(self))
self._cache.update(new_values)
self._on_modified()


class LazyDictionary:
caches = []

def __init__(self, *affects, allow_collection_mutation=True):
"""Marks this indexable property to be a cached dictionary. Delete this property to remove the cached value and force it to be rerun.

:param affects: Strings that list the names of the other properties in this class that are directly invalidated
when this property's value is altered
:param allow_collection_mutation: Whether or not the returned collection should allow its values to be altered
"""
self.affected_properties = affects
self.allow_mutation = allow_collection_mutation

def __call__(self, f, name=None):
self.f = f
self.name = name = name or f.__name__
cache_name = '_' + name

def reset_dependents(inner_self):
for affected in self.affected_properties:
delattr(inner_self, affected)

@functools.wraps(f)
def inner_getter(inner_self):
if not hasattr(inner_self, cache_name):
new_indexable = _LazyIndexable(functools.wraps(f)(partial(f, inner_self)),
partial(reset_dependents, inner_self),
self.allow_mutation)
setattr(inner_self, cache_name, new_indexable)
return getattr(inner_self, cache_name)

def inner_deleter(inner_self):
if hasattr(inner_self, cache_name):
delattr(inner_self, cache_name)
# If we make this recursion conditional on the cache existing, we prevent dependency cycles from
# breaking the code
reset_dependents(inner_self)

return property(fget=inner_getter, fdel=inner_deleter, doc=self.f.__doc__)
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
tqdm
pycontracts
coloredlogs
astor
coveralls
coverage
numpy
Expand Down
71 changes: 70 additions & 1 deletion tests/test_cached_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from miniutils.caching import CachedProperty
from miniutils.caching import CachedProperty, LazyDictionary
from miniutils.capture_output import captured_output


Expand Down Expand Up @@ -105,6 +105,32 @@ def target(self):
return True


class WithCachedDict:
def __init__(self):
self.calls = []

@CachedProperty()
def a(self):
self.calls.append('a')
return self.b + 2

@CachedProperty('a')
def b(self):
self.calls.append('b')
return self.f[1] + self.f[2]

@LazyDictionary('b', allow_collection_mutation=True)
def f(self, x):
self.calls.append('f({})'.format(x))
return x ** 2

@LazyDictionary(allow_collection_mutation=False)
def g(self, x):
"""G docstring"""
self.calls.append('g({})'.format(x))
return x ** 2


class TestCachedProperty(TestCase):
def test_matrix(self):
np.random.seed(0)
Expand Down Expand Up @@ -208,6 +234,15 @@ def test_dict_update(self):
self.assertEqual(i.basic_dict['new_key'], -1)
self.assertTrue(i._need_target)

def test_dict_contains(self):
i = CollectionProperties()
self.assertIn('a', i.basic_dict)
self.assertNotIn('d', i.basic_dict)
self.assertRaises(KeyError, lambda: i.basic_dict['x'])
self.assertIn('a', i.locked_dict)
self.assertNotIn('d', i.locked_dict)
self.assertRaises(KeyError, lambda: i.locked_dict['x'])

def test_mutable_list_properties(self):
i = CollectionProperties()
self.assertTrue(i.target)
Expand Down Expand Up @@ -239,3 +274,37 @@ def test_mutable_list_properties(self):

i.basic_set.difference(i.basic_set)
self.assertFalse(i._need_target)

def test_cached_dict(self):
w = WithCachedDict()
self.assertEqual(w.a, 7)
del w.b
self.assertEqual(w.b, 5)
del w.f
self.assertEqual(w.f[5], 25)
self.assertEqual(w.f[5], 25)
self.assertEqual(w.f[5], 25)
del w.f[5]
self.assertEqual(w.f[4], 16)
self.assertEqual(w.f[5], 25)
self.assertEqual(w.f[2], 4)
self.assertEqual(w.a, 7)
w.f[2] = 7
self.assertEqual(w.b, 8)
self.assertEqual(w.a, 10)
w.f.update({1: 0, 2: 0})
self.assertEqual(w.b, 0)
self.assertEqual(w.a, 2)

self.assertEqual(w.g[3], 9)
self.assertRaises(AttributeError, w.g.update, {3: 0})
try:
w.g[3] = 0
self.fail("Set value on immutable lazy dict")
except AttributeError:
pass

self.assertListEqual(w.calls, ['a', 'b', 'f(1)', 'f(2)', 'b', 'f(5)', 'f(4)', 'f(5)', 'f(2)', 'a', 'b', 'f(1)',
'b', 'a', 'b', 'a', 'g(3)'])

self.assertIn('G docstring', w.g.__doc__)