Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Factor get_entry into common place and switch to built-in key-making … #105

Merged
merged 1 commit into from
Mar 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 13 additions & 6 deletions cachier/base_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,38 @@
# Copyright (c) 2016, Shay Palachy <shaypal5@gmail.com>

import abc # for the _BaseCore abstract base class
import functools


# pylint: disable-next=protected-access
_default_hash_params = functools.partial(functools._make_key, typed=False)


class _BaseCore():
__metaclass__ = abc.ABCMeta

def __init__(self, stale_after, next_time):
def __init__(self, stale_after, next_time, hash_params):
self.stale_after = stale_after
self.next_time = next_time
self.hash_func = hash_params if hash_params else _default_hash_params
self.func = None

def set_func(self, func):
"""Sets the function this core will use. This has to be set before
any method is called"""
self.func = func

def get_entry(self, args, kwds):
"""Returns the result mapped to the given arguments in this core's
cache, if such a mapping exists."""
key = self.hash_func(args, kwds)
return self.get_entry_by_key(key)

@abc.abstractmethod
def get_entry_by_key(self, key):
"""Returns the result mapped to the given key in this core's cache,
if such a mapping exists."""

@abc.abstractmethod
def get_entry(self, args, kwds, hash_params):
"""Returns the result mapped to the given arguments in this core's
cache, if such a mapping exists."""

@abc.abstractmethod
def set_entry(self, key, func_res):
"""Maps the given result to the given key in this core's cache."""
Expand Down
18 changes: 14 additions & 4 deletions cachier/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,19 +149,29 @@ def cachier(
core = _PickleCore( # pylint: disable=R0204
stale_after=stale_after,
next_time=next_time,
hash_params=hash_params,
reload=pickle_reload,
cache_dir=cache_dir,
separate_files=separate_files,
wait_for_calc_timeout=wait_for_calc_timeout
wait_for_calc_timeout=wait_for_calc_timeout,
)
elif backend == 'mongo':
if mongetter is None:
raise MissingMongetter(
'must specify ``mongetter`` when using the mongo core')
core = _MongoCore(
mongetter, stale_after, next_time, wait_for_calc_timeout)
mongetter=mongetter,
stale_after=stale_after,
next_time=next_time,
hash_params=hash_params,
wait_for_calc_timeout=wait_for_calc_timeout,
)
elif backend == 'memory':
core = _MemoryCore(stale_after=stale_after, next_time=next_time)
core = _MemoryCore(
stale_after=stale_after,
next_time=next_time,
hash_params=hash_params,
)
elif backend == 'redis':
raise NotImplementedError(
'A Redis backend has not yet been implemented. '
Expand All @@ -184,7 +194,7 @@ def func_wrapper(*args, **kwds): # pylint: disable=C0111,R0911
_print = print
if ignore_cache:
return func(*args, **kwds)
key, entry = core.get_entry(args, kwds, hash_params)
key, entry = core.get_entry(args, kwds)
if overwrite_cache:
return _calc_entry(core, key, func, args, kwds)
if entry is not None: # pylint: disable=R0101
Expand Down
9 changes: 2 additions & 7 deletions cachier/memory_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,15 @@ class _MemoryCore(_BaseCore):
See :class:`_BaseCore` documentation.
"""

def __init__(self, stale_after, next_time):
super().__init__(stale_after=stale_after, next_time=next_time)
def __init__(self, stale_after, next_time, hash_params):
super().__init__(stale_after, next_time, hash_params)
self.cache = {}
self.lock = threading.RLock()

def get_entry_by_key(self, key, reload=False): # pylint: disable=W0221
with self.lock:
return key, self.cache.get(key, None)

def get_entry(self, args, kwds, hash_params):
with self.lock:
key = args + tuple(sorted(kwds.items())) if hash_params is None else hash_params(args, kwds) # noqa: E501
return self.get_entry_by_key(key)

def set_entry(self, key, func_res):
with self.lock:
try:
Expand Down
13 changes: 3 additions & 10 deletions cachier/mongo_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@ class _MongoCore(_BaseCore):
_INDEX_NAME = 'func_1_key_1'

def __init__(
self, mongetter, stale_after, next_time, wait_for_calc_timeout):
self, mongetter, stale_after, next_time,
hash_params, wait_for_calc_timeout):
if 'pymongo' not in sys.modules:
warnings.warn((
"Cachier warning: pymongo was not found. "
"MongoDB cores will not function."))
_BaseCore.__init__(self, stale_after, next_time)
super().__init__(stale_after, next_time, hash_params)
self.mongetter = mongetter
self.mongo_collection = self.mongetter()
self.wait_for_calc_timeout = wait_for_calc_timeout
Expand Down Expand Up @@ -81,14 +82,6 @@ def get_entry_by_key(self, key):
return key, entry
return key, None

def get_entry(self, args, kwds, hash_params):
key = pickle.dumps(
args + tuple(
sorted(kwds.items())
) if hash_params is None else hash_params(args, kwds)
)
return self.get_entry_by_key(key)

def set_entry(self, key, func_res):
thebytes = pickle.dumps(func_res)
self.mongo_collection.update_one(
Expand Down
11 changes: 3 additions & 8 deletions cachier/pickle_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ def on_modified(self, event): # skipcq: PYL-W0613
self._check_calculation()

def __init__(
self, stale_after, next_time, reload, cache_dir, separate_files,
wait_for_calc_timeout,
self, stale_after, next_time, hash_params, reload,
cache_dir, separate_files, wait_for_calc_timeout,
):
_BaseCore.__init__(self, stale_after, next_time)
super().__init__(stale_after, next_time, hash_params)
self.cache = None
self.reload = reload
self.cache_dir = DEF_CACHIER_DIR
Expand Down Expand Up @@ -193,11 +193,6 @@ def get_entry_by_key(self, key, reload=False): # pylint: disable=W0221
self._reload_cache()
return key, self._get_cache().get(key, None)

def get_entry(self, args, kwds, hash_params):
key = args + tuple(sorted(kwds.items())) if hash_params is None else hash_params(args, kwds) # noqa: E501
# print('key type={}, key={}'.format(type(key), key))
return self.get_entry_by_key(key)

def set_entry(self, key, func_res):
key_data = {
'value': func_res,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_mongo_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def test_stalled_mongo_db_cache():
@cachier(mongetter=_test_mongetter)
def _stalled_func():
return 1
core = _MongoCore(_test_mongetter, None, False, 0)
core = _MongoCore(_test_mongetter, None, False, None, 0)
core.set_func(_stalled_func)
core.clear_cache()
with pytest.raises(RecalculationNeeded):
Expand All @@ -215,7 +215,7 @@ def _stalled_func():
@pytest.mark.mongo
def test_stalled_mong_db_core(monkeypatch):

def mock_get_entry(self, args, kwargs, hash_params): # skipcq: PYL-R0201, PYL-W0613 # noqa: E501
def mock_get_entry(self, args, kwargs): # skipcq: PYL-R0201, PYL-W0613 # noqa: E501
return "key", {'being_calculated': True}

def mock_get_entry_by_key(self, key): # skipcq: PYL-R0201, PYL-W0613
Expand All @@ -233,7 +233,7 @@ def _stalled_func():
res = _stalled_func()
assert res == 1

def mock_get_entry_2(self, args, kwargs, hash_params): # skipcq: PYL-W0613
def mock_get_entry_2(self, args, kwargs): # skipcq: PYL-W0613
entry = {
'being_calculated': True,
"value": 1,
Expand Down