diff --git a/CHANGES.txt b/CHANGES.txt index 46fac96..eb49deb 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,6 +1,8 @@ v1.2 - unreleased ================= +- In CachedUpdatePersister don't load model if we already have active + - Allow for use of multiple CachedUpdatePersisters Instances of CachedUpdatePersister now use their own name inside the diff --git a/palladium/persistence.py b/palladium/persistence.py index 4e21fe8..cb46fa8 100644 --- a/palladium/persistence.py +++ b/palladium/persistence.py @@ -569,10 +569,12 @@ class CachedUpdatePersister(ModelPersister): cache = process_store __pld_config_key__ = 'cachedupdatepersister_default' + _loaded_version = None def __init__(self, impl, - update_cache_rrule=None + update_cache_rrule=None, + check_version=True, ): """ :param ModelPersister impl: @@ -583,9 +585,15 @@ def __init__(self, :class:`dateutil.rrule.rrule` that determines when the cache will be updated. See :class:`~palladium.util.RruleThread` for details. + + :param bool check_version: + If set to `True`, I will perform a check and only load a new + model from the storage if my cached version differs from + what's the current active version. """ self.impl = impl self.update_cache_rrule = update_cache_rrule + self.check_version = check_version def initialize_component(self, config): self.use_cache = config.get('__mode__') != 'fit' @@ -611,9 +619,20 @@ def read(self, *args, **kwargs): @PluggableDecorator('update_model_decorators') def update_cache(self, *args, **kwargs): + active_version = None + + if self.check_version: + active_version = self.list_properties()['active-model'] + if self._loaded_version == (active_version, args, kwargs): + return + model = self.impl.read(*args, **kwargs) if model is not None: self.cache[self.__pld_config_key__] = model + + if self.check_version: + self._loaded_version = (active_version, args, kwargs) + return model @PluggableDecorator('write_model_decorators') diff --git a/palladium/tests/test_persistence.py b/palladium/tests/test_persistence.py index 8f7523b..631cd6c 100644 --- a/palladium/tests/test_persistence.py +++ b/palladium/tests/test_persistence.py @@ -769,9 +769,23 @@ def test_write(self, persister): persister.impl.write.assert_called_with('mymodel') def test_update_cache(self, persister): + persister.update_cache() + assert persister.read() is persister.impl.read.return_value + persister.impl.read.assert_called_with() + assert len(persister.impl.read.mock_calls) == 1 + + def test_update_cache_no_check_version(self, persister): + persister.check_version = False + persister.update_cache() + assert persister.read() is persister.impl.read.return_value + persister.impl.read.assert_called_with() + assert len(persister.impl.read.mock_calls) == 2 + + def test_update_cache_specific_version(self, persister): persister.update_cache(version=123) assert persister.read() is persister.impl.read.return_value persister.impl.read.assert_called_with(version=123) + assert len(persister.impl.read.mock_calls) == 2 def test_update_cache_rrule(self, process_store, CachedUpdatePersister, config):