From 2400124bdf6864e573406c173cfb02e68c776e1e Mon Sep 17 00:00:00 2001 From: scnerd Date: Wed, 4 Apr 2018 13:23:11 -0400 Subject: [PATCH] Implemented explicit lazy dictionary property for caching lazily computed values at specific indices --- miniutils/caching.py | 108 ++++++++++++++++++++++++++-------- requirements.txt | 1 - tests/test_cached_property.py | 71 +++++++++++++++++++++- 3 files changed, 154 insertions(+), 26 deletions(-) diff --git a/miniutils/caching.py b/miniutils/caching.py index 14b01b3..503667c 100644 --- a/miniutils/caching.py +++ b/miniutils/caching.py @@ -1,11 +1,10 @@ import functools -#from contextlib import contextmanager from threading import RLock -#import inspect +from functools import partial class CachedCollection: - IGNORED_GETS = ['get', 'union', 'intersection', 'difference', 'copy'] + IGNORED_GETS = ['get', 'union', 'intersection', 'difference', 'copy', 'keys', 'values', 'items'] def __init__(self, value, on_update, container_self, allow_update): self.collection = value @@ -15,11 +14,11 @@ def __init__(self, value, on_update, container_self, allow_update): def __getitem__(self, item): return self.collection[item] - def __missing__(self, key): - if not self.allow_update: - raise AttributeError("Attempted to perform an action (probably add) an unknown key") - self.collection.__missing__(key) - self.on_update() + # def __missing__(self, key): # This isn't a dict subclass, it's a wrapper, so this method will never get called + # if not self.allow_update: + # raise AttributeError("Attempted to perform an action (probably add) an unknown key") + # self.collection.__missing__(key) + # self.on_update() def __setitem__(self, key, value): if not self.allow_update: @@ -104,9 +103,9 @@ def __init__(self, *affects, settable=False, threadsafe=True, is_collection=Fals self.f = None CachedProperty.caches.append(self) - def __call__(self, f): + def __call__(self, f, name=None): self.f = f - self.name = name = f.__name__ + self.name = name = name or f.__name__ flag_name = '_need_' + name cache_name = '_' + name @@ -144,7 +143,8 @@ def inner_getter(inner_self): return getattr(inner_self, cache_name) def inner_deleter(inner_self): - assert getattr(inner_self, flag_name, True) or hasattr(inner_self, cache_name) + # assert not getattr(inner_self, flag_name, True) or hasattr(inner_self, cache_name) + # raise AttributeError("{} does not have a value for attribute {}".format(inner_self, name)) setattr(inner_self, flag_name, True) if hasattr(inner_self, cache_name): delattr(inner_self, cache_name) @@ -168,16 +168,76 @@ def inner_setter(inner_self, value): return property(fget=inner_getter, fset=inner_setter, fdel=inner_deleter, doc=self.f.__doc__) -# def _get_class_that_defined_method(method): -# """https://stackoverflow.com/questions/3589311/get-defining-class-of-unbound-method-object-in-python-3""" -# if inspect.ismethod(method): -# for cls in inspect.getmro(method.__self__.__class__): -# if cls.__dict__.get(method.__name__) is method: -# return cls -# method = method.__func__ # fallback to __qualname__ parsing -# if inspect.isfunction(method): -# cls = getattr(inspect.getmodule(method), -# method.__qualname__.split('.', 1)[0].rsplit('.', 1)[0]) -# if isinstance(cls, type): -# return cls -# return getattr(method, '__objclass__', None) # handle special descriptor objects +class _LazyIndexable: + def __init__(self, getter_closure, on_modified, settable=False, values=None): + self._cache = dict(values or {}) + self._closure = getter_closure + self._on_modified = on_modified + self.settable = settable + + def __getitem__(self, item): + if item not in self._cache: + self._cache[item] = self._closure(item) + self._on_modified() + return self._cache[item] + + def __setitem__(self, key, value): + if not self.settable: + raise AttributeError("{} is not settable".format(self)) + self._cache[key] = value + self._on_modified() + + def __delitem__(self, key): + del self._cache[key] + self._on_modified() + + @property + def __doc__(self): + return self._closure.__doc__ + + def update(self, new_values): + if not self.settable: + raise AttributeError("{} is not settable".format(self)) + self._cache.update(new_values) + self._on_modified() + + +class LazyDictionary: + caches = [] + + def __init__(self, *affects, allow_collection_mutation=True): + """Marks this indexable property to be a cached dictionary. Delete this property to remove the cached value and force it to be rerun. + + :param affects: Strings that list the names of the other properties in this class that are directly invalidated + when this property's value is altered + :param allow_collection_mutation: Whether or not the returned collection should allow its values to be altered + """ + self.affected_properties = affects + self.allow_mutation = allow_collection_mutation + + def __call__(self, f, name=None): + self.f = f + self.name = name = name or f.__name__ + cache_name = '_' + name + + def reset_dependents(inner_self): + for affected in self.affected_properties: + delattr(inner_self, affected) + + @functools.wraps(f) + def inner_getter(inner_self): + if not hasattr(inner_self, cache_name): + new_indexable = _LazyIndexable(functools.wraps(f)(partial(f, inner_self)), + partial(reset_dependents, inner_self), + self.allow_mutation) + setattr(inner_self, cache_name, new_indexable) + return getattr(inner_self, cache_name) + + def inner_deleter(inner_self): + if hasattr(inner_self, cache_name): + delattr(inner_self, cache_name) + # If we make this recursion conditional on the cache existing, we prevent dependency cycles from + # breaking the code + reset_dependents(inner_self) + + return property(fget=inner_getter, fdel=inner_deleter, doc=self.f.__doc__) diff --git a/requirements.txt b/requirements.txt index 162ffc4..9d2ffc7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ tqdm pycontracts coloredlogs -astor coveralls coverage numpy diff --git a/tests/test_cached_property.py b/tests/test_cached_property.py index f15993b..a474af3 100644 --- a/tests/test_cached_property.py +++ b/tests/test_cached_property.py @@ -3,7 +3,7 @@ import numpy as np -from miniutils.caching import CachedProperty +from miniutils.caching import CachedProperty, LazyDictionary from miniutils.capture_output import captured_output @@ -105,6 +105,32 @@ def target(self): return True +class WithCachedDict: + def __init__(self): + self.calls = [] + + @CachedProperty() + def a(self): + self.calls.append('a') + return self.b + 2 + + @CachedProperty('a') + def b(self): + self.calls.append('b') + return self.f[1] + self.f[2] + + @LazyDictionary('b', allow_collection_mutation=True) + def f(self, x): + self.calls.append('f({})'.format(x)) + return x ** 2 + + @LazyDictionary(allow_collection_mutation=False) + def g(self, x): + """G docstring""" + self.calls.append('g({})'.format(x)) + return x ** 2 + + class TestCachedProperty(TestCase): def test_matrix(self): np.random.seed(0) @@ -208,6 +234,15 @@ def test_dict_update(self): self.assertEqual(i.basic_dict['new_key'], -1) self.assertTrue(i._need_target) + def test_dict_contains(self): + i = CollectionProperties() + self.assertIn('a', i.basic_dict) + self.assertNotIn('d', i.basic_dict) + self.assertRaises(KeyError, lambda: i.basic_dict['x']) + self.assertIn('a', i.locked_dict) + self.assertNotIn('d', i.locked_dict) + self.assertRaises(KeyError, lambda: i.locked_dict['x']) + def test_mutable_list_properties(self): i = CollectionProperties() self.assertTrue(i.target) @@ -239,3 +274,37 @@ def test_mutable_list_properties(self): i.basic_set.difference(i.basic_set) self.assertFalse(i._need_target) + + def test_cached_dict(self): + w = WithCachedDict() + self.assertEqual(w.a, 7) + del w.b + self.assertEqual(w.b, 5) + del w.f + self.assertEqual(w.f[5], 25) + self.assertEqual(w.f[5], 25) + self.assertEqual(w.f[5], 25) + del w.f[5] + self.assertEqual(w.f[4], 16) + self.assertEqual(w.f[5], 25) + self.assertEqual(w.f[2], 4) + self.assertEqual(w.a, 7) + w.f[2] = 7 + self.assertEqual(w.b, 8) + self.assertEqual(w.a, 10) + w.f.update({1: 0, 2: 0}) + self.assertEqual(w.b, 0) + self.assertEqual(w.a, 2) + + self.assertEqual(w.g[3], 9) + self.assertRaises(AttributeError, w.g.update, {3: 0}) + try: + w.g[3] = 0 + self.fail("Set value on immutable lazy dict") + except AttributeError: + pass + + self.assertListEqual(w.calls, ['a', 'b', 'f(1)', 'f(2)', 'b', 'f(5)', 'f(4)', 'f(5)', 'f(2)', 'a', 'b', 'f(1)', + 'b', 'a', 'b', 'a', 'g(3)']) + + self.assertIn('G docstring', w.g.__doc__)