23 changes: 13 additions & 10 deletions cachier/cores/memory.py
@@ -4,7 +4,7 @@
from datetime import datetime

from .._types import HashFunc
from .base import _BaseCore
from .base import _BaseCore, _get_func_str


class _MemoryCore(_BaseCore):
@@ -14,20 +14,23 @@ def __init__(self, hash_func: HashFunc, wait_for_calc_timeout: int):
super().__init__(hash_func, wait_for_calc_timeout)
self.cache = {}

def _hash_func_key(self, key):
return f"{_get_func_str(self.func)}:{key}"

def get_entry_by_key(self, key, reload=False):
with self.lock:
return key, self.cache.get(key, None)
return key, self.cache.get(self._hash_func_key(key), None)

def set_entry(self, key, func_res):
with self.lock:
try:
# we need to retain the existing condition so that
# mark_entry_not_calculated can notify all possibly-waiting
# threads about it
cond = self.cache[key]["condition"]
cond = self.cache[self._hash_func_key(key)]["condition"]
except KeyError: # pragma: no cover
cond = None
self.cache[key] = {
self.cache[self._hash_func_key(key)] = {
"value": func_res,
"time": datetime.now(),
"stale": False,
@@ -40,10 +43,10 @@ def mark_entry_being_calculated(self, key):
condition = threading.Condition()
# condition.acquire()
try:
self.cache[key]["being_calculated"] = True
self.cache[key]["condition"] = condition
self.cache[self._hash_func_key(key)]["being_calculated"] = True
self.cache[self._hash_func_key(key)]["condition"] = condition
except KeyError:
self.cache[key] = {
self.cache[self._hash_func_key(key)] = {
"value": None,
"time": datetime.now(),
"stale": False,
@@ -54,7 +57,7 @@ def mark_entry_not_calculated(self, key):
def mark_entry_not_calculated(self, key):
with self.lock:
try:
entry = self.cache[key]
entry = self.cache[self._hash_func_key(key)]
except KeyError: # pragma: no cover
return # that's ok, we don't need an entry in that case
entry["being_calculated"] = False
@@ -67,13 +70,13 @@ def mark_entry_not_calculated(self, key):

def wait_on_entry_calc(self, key):
with self.lock: # pragma: no cover
entry = self.cache[key]
entry = self.cache[self._hash_func_key(key)]
if not entry["being_calculated"]:
return entry["value"]
entry["condition"].acquire()
entry["condition"].wait()
entry["condition"].release()
return self.cache[key]["value"]
return self.cache[self._hash_func_key(key)]["value"]

def clear_cache(self):
with self.lock:
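The memory-core change above routes every cache access through the new _hash_func_key helper, so entries are keyed per function rather than only per argument hash. The sketch below is illustrative only, not cachier's actual code: _get_func_str is imported from .base and not shown in this diff, so the sketch assumes it yields something like a module-qualified function name.

# Illustrative sketch: one dict shared by two functions, with keys prefixed
# by a per-function identifier, mimicking _hash_func_key.
shared_cache = {}

def _func_str(func):
    # Assumption: _get_func_str produces something along these lines.
    return f"{func.__module__}.{func.__qualname__}"

def put(func, arg_hash, value):
    shared_cache[f"{_func_str(func)}:{arg_hash}"] = value

def get(func, arg_hash):
    return shared_cache.get(f"{_func_str(func)}:{arg_hash}")

def fn_plus(a, b=2):
    return a + b

def fn_minus(a, b=2):
    return a - b

# Same argument hash, different functions: without the prefix these two
# writes would clobber each other; with it, both results are retained.
put(fn_plus, "hash-of-(1, 2)", fn_plus(1))
put(fn_minus, "hash-of-(1, 2)", fn_minus(1))
assert get(fn_plus, "hash-of-(1, 2)") == 3
assert get(fn_minus, "hash-of-(1, 2)") == -1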
78 changes: 35 additions & 43 deletions cachier/cores/pickle.py
@@ -86,38 +86,27 @@ def __init__(
self.separate_files = _update_with_defaults(
separate_files, "separate_files"
)
self._cache_fname = None
self._cache_fpath = None

@property
def cache_fname(self) -> str:
if self._cache_fname is None:
fname = f".{self.func.__module__}.{self.func.__qualname__}"
self._cache_fname = fname.replace("<", "_").replace(">", "_")
return self._cache_fname
fname = f".{self.func.__module__}.{self.func.__qualname__}"
return fname.replace("<", "_").replace(">", "_")

@property
def cache_fpath(self) -> str:
if self._cache_fpath is None:
os.makedirs(self.cache_dir, exist_ok=True)
self._cache_fpath = os.path.abspath(
os.path.join(
os.path.realpath(self.cache_dir), self.cache_fname
)
)
return self._cache_fpath
os.makedirs(self.cache_dir, exist_ok=True)
return os.path.abspath(
os.path.join(os.path.realpath(self.cache_dir), self.cache_fname)
)

def _reload_cache(self):
with self.lock:
try:
with portalocker.Lock(
self.cache_fpath, mode="rb"
) as cache_file:
try:
self.cache = pickle.load(cache_file) # noqa: S301
except EOFError:
self.cache = {}
except FileNotFoundError:
self.cache = pickle.load(cache_file) # noqa: S301
except (FileNotFoundError, EOFError):
self.cache = {}

def _get_cache(self):
@@ -180,11 +169,12 @@ def set_entry(self, key, func_res):
}
if self.separate_files:
self._save_cache(key_data, key)
else:
with self.lock:
cache = self._get_cache()
cache[key] = key_data
self._save_cache(cache)
return # pragma: no cover

with self.lock:
cache = self._get_cache()
cache[key] = key_data
self._save_cache(cache)

def mark_entry_being_calculated_separate_files(self, key):
self._save_cache(
@@ -205,19 +195,20 @@ def mark_entry_not_calculated_separate_files(self, key):
def mark_entry_being_calculated(self, key):
if self.separate_files:
self.mark_entry_being_calculated_separate_files(key)
else:
with self.lock:
cache = self._get_cache()
try:
cache[key]["being_calculated"] = True
except KeyError:
cache[key] = {
"value": None,
"time": datetime.now(),
"stale": False,
"being_calculated": True,
}
self._save_cache(cache)
return # pragma: no cover

with self.lock:
cache = self._get_cache()
try:
cache[key]["being_calculated"] = True
except KeyError:
cache[key] = {
"value": None,
"time": datetime.now(),
"stale": False,
"being_calculated": True,
}
self._save_cache(cache)

def mark_entry_not_calculated(self, key):
if self.separate_files:
@@ -263,9 +254,10 @@ def clear_cache(self):
def clear_being_calculated(self):
if self.separate_files:
self._clear_being_calculated_all_cache_files()
else:
with self.lock:
cache = self._get_cache()
for key in cache:
cache[key]["being_calculated"] = False
self._save_cache(cache)
return # pragma: no cover

with self.lock:
cache = self._get_cache()
for key in cache:
cache[key]["being_calculated"] = False
self._save_cache(cache)
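For context, the pickle-core changes above do three things: cache_fname and cache_fpath become plain computed properties instead of memoized attributes, _reload_cache treats a missing file and an empty file the same way (both yield an empty cache), and the else branches after the separate_files checks are replaced with early returns. Below is a minimal sketch of the consolidated reload behaviour; the portalocker file lock is omitted for brevity and the helper name load_cache is illustrative, not cachier's API.

import os
import pickle

def load_cache(cache_dir: str, fname: str) -> dict:
    """Load a pickled cache dict, falling back to an empty dict if the
    cache file is missing or empty (FileNotFoundError or EOFError)."""
    os.makedirs(cache_dir, exist_ok=True)  # mirrors the cache_fpath property
    fpath = os.path.abspath(
        os.path.join(os.path.realpath(cache_dir), fname)
    )
    try:
        with open(fpath, "rb") as cache_file:
            return pickle.load(cache_file)  # noqa: S301
    except (FileNotFoundError, EOFError):
        return {}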
97 changes: 71 additions & 26 deletions tests/test_general.py
@@ -344,44 +344,89 @@ def dummy_func(a, b=2):
assert count == 1


def test_runtime_handling(tmpdir):
count = 0
@pytest.mark.parametrize("backend", ["memory", "pickle"])
def test_diff_functions_same_args(tmpdir, backend: str):
count_p = count_m = 0

@cachier.cachier(cache_dir=tmpdir, backend=backend)
def fn_plus(a, b=2):
nonlocal count_p
count_p += 1
return a + b

def dummy_func(a, b):
nonlocal count
count += 1
@cachier.cachier(cache_dir=tmpdir, backend=backend)
def fn_minus(a, b=2):
nonlocal count_m
count_m += 1
return a - b

assert count_p == count_m == 0

for fn, expected in [(fn_plus, 3), (fn_minus, -1)]:
assert fn(1) == expected
assert fn(a=1, b=2) == expected
assert count_p == 1
assert count_m == 1


@pytest.mark.parametrize("backend", ["memory", "pickle"])
def test_runtime_handling(tmpdir, backend):
count_p = count_m = 0

def fn_plus(a, b=2):
nonlocal count_p
count_p += 1
return a + b

cachier_ = cachier.cachier(cache_dir=tmpdir)
assert count == 0
cachier_(dummy_func)(a=1, b=2)
cachier_(dummy_func)(a=1, b=2)
assert count == 1
def fn_minus(a, b=2):
nonlocal count_m
count_m += 1
return a - b

cachier_ = cachier.cachier(cache_dir=tmpdir, backend=backend)
assert count_p == count_m == 0

for fn, expected in [(fn_plus, 3), (fn_minus, -1)]:
assert cachier_(fn)(1, 2) == expected
assert cachier_(fn)(a=1, b=2) == expected
assert count_p == 1
assert count_m == 1

for fn, expected in [(fn_plus, 5), (fn_minus, 1)]:
assert cachier_(fn)(3, 2) == expected
assert cachier_(fn)(a=3, b=2) == expected
assert count_p == 2
assert count_m == 2


def test_partial_handling(tmpdir):
count = 0
count_p = count_m = 0

def dummy_func(a, b=2):
nonlocal count
count += 1
def fn_plus(a, b=2):
nonlocal count_p
count_p += 1
return a + b

cachier_ = cachier.cachier(cache_dir=tmpdir)
assert count == 0
def fn_minus(a, b=2):
nonlocal count_m
count_m += 1
return a - b

dummy_ = functools.partial(dummy_func, 1)
cachier_(dummy_)()
cachier_ = cachier.cachier(cache_dir=tmpdir)
assert count_p == count_m == 0

dummy_ = functools.partial(dummy_func, a=1)
cachier_(dummy_)()
for fn, expected in [(fn_plus, 3), (fn_minus, -1)]:
dummy_ = functools.partial(fn, 1)
assert cachier_(dummy_)() == expected

dummy_ = functools.partial(dummy_func, b=2)
cachier_(dummy_)(1)
dummy_ = functools.partial(fn, a=1)
assert cachier_(dummy_)() == expected

assert count == 1
dummy_ = functools.partial(fn, b=2)
assert cachier_(dummy_)(1) == expected

cachier_(dummy_func)(1, 2)
cachier_(dummy_func)(a=1, b=2)
assert cachier_(fn)(1, 2) == expected
assert cachier_(fn)(a=1, b=2) == expected

assert count == 1
assert count_p == 1
assert count_m == 1
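Taken together, the new parametrized tests pin down the behaviour the memory-core change enables: two differently decorated functions called with identical arguments each keep their own cached result on both the memory and pickle backends. A short usage sketch of that guarantee, assuming cachier is importable; it mirrors what test_diff_functions_same_args checks with the memory backend.

import cachier

@cachier.cachier(backend="memory")
def fn_plus(a, b=2):
    return a + b

@cachier.cachier(backend="memory")
def fn_minus(a, b=2):
    return a - b

# Each function is computed once and cached under its own per-function key.
assert fn_plus(1) == 3
assert fn_minus(1) == -1  # previously could collide with fn_plus's entry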