diff --git a/cachier/cores/memory.py b/cachier/cores/memory.py
index de87da00..f7fe83dc 100644
--- a/cachier/cores/memory.py
+++ b/cachier/cores/memory.py
@@ -4,7 +4,7 @@
 from datetime import datetime
 
 from .._types import HashFunc
-from .base import _BaseCore
+from .base import _BaseCore, _get_func_str
 
 
 class _MemoryCore(_BaseCore):
@@ -14,9 +14,12 @@ def __init__(self, hash_func: HashFunc, wait_for_calc_timeout: int):
         super().__init__(hash_func, wait_for_calc_timeout)
         self.cache = {}
 
+    def _hash_func_key(self, key):
+        return f"{_get_func_str(self.func)}:{key}"
+
     def get_entry_by_key(self, key, reload=False):
         with self.lock:
-            return key, self.cache.get(key, None)
+            return key, self.cache.get(self._hash_func_key(key), None)
 
     def set_entry(self, key, func_res):
         with self.lock:
@@ -24,10 +27,10 @@ def set_entry(self, key, func_res):
                 # we need to retain the existing condition so that
                 # mark_entry_not_calculated can notify all possibly-waiting
                 # threads about it
-                cond = self.cache[key]["condition"]
+                cond = self.cache[self._hash_func_key(key)]["condition"]
             except KeyError:  # pragma: no cover
                 cond = None
-            self.cache[key] = {
+            self.cache[self._hash_func_key(key)] = {
                 "value": func_res,
                 "time": datetime.now(),
                 "stale": False,
@@ -40,10 +43,10 @@ def mark_entry_being_calculated(self, key):
             condition = threading.Condition()
             # condition.acquire()
             try:
-                self.cache[key]["being_calculated"] = True
-                self.cache[key]["condition"] = condition
+                self.cache[self._hash_func_key(key)]["being_calculated"] = True
+                self.cache[self._hash_func_key(key)]["condition"] = condition
             except KeyError:
-                self.cache[key] = {
+                self.cache[self._hash_func_key(key)] = {
                     "value": None,
                     "time": datetime.now(),
                     "stale": False,
@@ -54,7 +57,7 @@
     def mark_entry_not_calculated(self, key):
         with self.lock:
             try:
-                entry = self.cache[key]
+                entry = self.cache[self._hash_func_key(key)]
             except KeyError:  # pragma: no cover
                 return  # that's ok, we don't need an entry in that case
             entry["being_calculated"] = False
@@ -67,13 +70,13 @@
 
     def wait_on_entry_calc(self, key):
         with self.lock:  # pragma: no cover
-            entry = self.cache[key]
+            entry = self.cache[self._hash_func_key(key)]
             if not entry["being_calculated"]:
                 return entry["value"]
         entry["condition"].acquire()
         entry["condition"].wait()
         entry["condition"].release()
-        return self.cache[key]["value"]
+        return self.cache[self._hash_func_key(key)]["value"]
 
     def clear_cache(self):
         with self.lock:
diff --git a/cachier/cores/pickle.py b/cachier/cores/pickle.py
index 9479b54a..9b4a9fd2 100644
--- a/cachier/cores/pickle.py
+++ b/cachier/cores/pickle.py
@@ -86,26 +86,18 @@ def __init__(
         self.separate_files = _update_with_defaults(
             separate_files, "separate_files"
         )
-        self._cache_fname = None
-        self._cache_fpath = None
 
     @property
     def cache_fname(self) -> str:
-        if self._cache_fname is None:
-            fname = f".{self.func.__module__}.{self.func.__qualname__}"
-            self._cache_fname = fname.replace("<", "_").replace(">", "_")
-        return self._cache_fname
+        fname = f".{self.func.__module__}.{self.func.__qualname__}"
+        return fname.replace("<", "_").replace(">", "_")
 
     @property
     def cache_fpath(self) -> str:
-        if self._cache_fpath is None:
-            os.makedirs(self.cache_dir, exist_ok=True)
-            self._cache_fpath = os.path.abspath(
-                os.path.join(
-                    os.path.realpath(self.cache_dir), self.cache_fname
-                )
-            )
-        return self._cache_fpath
+        os.makedirs(self.cache_dir, exist_ok=True)
+        return os.path.abspath(
+            os.path.join(os.path.realpath(self.cache_dir), self.cache_fname)
+        )
 
     def _reload_cache(self):
         with self.lock:
@@ -113,11 +105,8 @@
                 with portalocker.Lock(
                     self.cache_fpath, mode="rb"
                 ) as cache_file:
-                    try:
-                        self.cache = pickle.load(cache_file)  # noqa: S301
-                    except EOFError:
-                        self.cache = {}
-            except FileNotFoundError:
+                    self.cache = pickle.load(cache_file)  # noqa: S301
+            except (FileNotFoundError, EOFError):
                 self.cache = {}
 
     def _get_cache(self):
@@ -180,11 +169,12 @@ def set_entry(self, key, func_res):
         }
         if self.separate_files:
            self._save_cache(key_data, key)
-        else:
-            with self.lock:
-                cache = self._get_cache()
-                cache[key] = key_data
-                self._save_cache(cache)
+            return  # pragma: no cover
+
+        with self.lock:
+            cache = self._get_cache()
+            cache[key] = key_data
+            self._save_cache(cache)
 
     def mark_entry_being_calculated_separate_files(self, key):
         self._save_cache(
@@ -205,19 +195,20 @@ def mark_entry_not_calculated_separate_files(self, key):
     def mark_entry_being_calculated(self, key):
         if self.separate_files:
             self.mark_entry_being_calculated_separate_files(key)
-        else:
-            with self.lock:
-                cache = self._get_cache()
-                try:
-                    cache[key]["being_calculated"] = True
-                except KeyError:
-                    cache[key] = {
-                        "value": None,
-                        "time": datetime.now(),
-                        "stale": False,
-                        "being_calculated": True,
-                    }
-                self._save_cache(cache)
+            return  # pragma: no cover
+
+        with self.lock:
+            cache = self._get_cache()
+            try:
+                cache[key]["being_calculated"] = True
+            except KeyError:
+                cache[key] = {
+                    "value": None,
+                    "time": datetime.now(),
+                    "stale": False,
+                    "being_calculated": True,
+                }
+            self._save_cache(cache)
 
     def mark_entry_not_calculated(self, key):
         if self.separate_files:
@@ -263,9 +254,10 @@ def clear_cache(self):
     def clear_being_calculated(self):
         if self.separate_files:
             self._clear_being_calculated_all_cache_files()
-        else:
-            with self.lock:
-                cache = self._get_cache()
-                for key in cache:
-                    cache[key]["being_calculated"] = False
-                self._save_cache(cache)
+            return  # pragma: no cover
+
+        with self.lock:
+            cache = self._get_cache()
+            for key in cache:
+                cache[key]["being_calculated"] = False
+            self._save_cache(cache)
diff --git a/tests/test_general.py b/tests/test_general.py
index 6eba04b0..c3a3621d 100644
--- a/tests/test_general.py
+++ b/tests/test_general.py
@@ -344,44 +344,89 @@ def dummy_func(a, b=2):
     assert count == 1
 
 
-def test_runtime_handling(tmpdir):
-    count = 0
+@pytest.mark.parametrize("backend", ["memory", "pickle"])
+def test_diff_functions_same_args(tmpdir, backend: str):
+    count_p = count_m = 0
+
+    @cachier.cachier(cache_dir=tmpdir, backend=backend)
+    def fn_plus(a, b=2):
+        nonlocal count_p
+        count_p += 1
+        return a + b
 
-    def dummy_func(a, b):
-        nonlocal count
-        count += 1
+    @cachier.cachier(cache_dir=tmpdir, backend=backend)
+    def fn_minus(a, b=2):
+        nonlocal count_m
+        count_m += 1
+        return a - b
+
+    assert count_p == count_m == 0
+
+    for fn, expected in [(fn_plus, 3), (fn_minus, -1)]:
+        assert fn(1) == expected
+        assert fn(a=1, b=2) == expected
+    assert count_p == 1
+    assert count_m == 1
+
+
+@pytest.mark.parametrize("backend", ["memory", "pickle"])
+def test_runtime_handling(tmpdir, backend):
+    count_p = count_m = 0
+
+    def fn_plus(a, b=2):
+        nonlocal count_p
+        count_p += 1
         return a + b
 
-    cachier_ = cachier.cachier(cache_dir=tmpdir)
-    assert count == 0
-    cachier_(dummy_func)(a=1, b=2)
-    cachier_(dummy_func)(a=1, b=2)
-    assert count == 1
+    def fn_minus(a, b=2):
+        nonlocal count_m
+        count_m += 1
+        return a - b
+
+    cachier_ = cachier.cachier(cache_dir=tmpdir, backend=backend)
+    assert count_p == count_m == 0
+
+    for fn, expected in [(fn_plus, 3), (fn_minus, -1)]:
+        assert cachier_(fn)(1, 2) == expected
+        assert cachier_(fn)(a=1, b=2) == expected
+    assert count_p == 1
+    assert count_m == 1
+
+    for fn, expected in [(fn_plus, 5), (fn_minus, 1)]:
+        assert cachier_(fn)(3, 2) == expected
+        assert cachier_(fn)(a=3, b=2) == expected
+    assert count_p == 2
+    assert count_m == 2
 
 
 def test_partial_handling(tmpdir):
-    count = 0
+    count_p = count_m = 0
 
-    def dummy_func(a, b=2):
-        nonlocal count
-        count += 1
+    def fn_plus(a, b=2):
+        nonlocal count_p
+        count_p += 1
         return a + b
 
-    cachier_ = cachier.cachier(cache_dir=tmpdir)
-    assert count == 0
+    def fn_minus(a, b=2):
+        nonlocal count_m
+        count_m += 1
+        return a - b
 
-    dummy_ = functools.partial(dummy_func, 1)
-    cachier_(dummy_)()
+    cachier_ = cachier.cachier(cache_dir=tmpdir)
+    assert count_p == count_m == 0
 
-    dummy_ = functools.partial(dummy_func, a=1)
-    cachier_(dummy_)()
+    for fn, expected in [(fn_plus, 3), (fn_minus, -1)]:
+        dummy_ = functools.partial(fn, 1)
+        assert cachier_(dummy_)() == expected
 
-    dummy_ = functools.partial(dummy_func, b=2)
-    cachier_(dummy_)(1)
+        dummy_ = functools.partial(fn, a=1)
+        assert cachier_(dummy_)() == expected
 
-    assert count == 1
+        dummy_ = functools.partial(fn, b=2)
+        assert cachier_(dummy_)(1) == expected
 
-    cachier_(dummy_func)(1, 2)
-    cachier_(dummy_func)(a=1, b=2)
+        assert cachier_(fn)(1, 2) == expected
+        assert cachier_(fn)(a=1, b=2) == expected
 
-    assert count == 1
+    assert count_p == 1
+    assert count_m == 1