Skip to content
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.1
hooks:
# use black formatting
- id: ruff-format
name: Black by Ruff
# basic check
- id: ruff
name: Ruff check
args: ["--fix"]
# use black formatting
- id: ruff-format
name: Black by Ruff
4 changes: 2 additions & 2 deletions cachier/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from ._version import * # noqa: F403
from .core import (
cachier,
from .config import (
disable_caching,
enable_caching,
get_default_params,
set_default_params,
)
from .core import cachier

__all__ = [
"cachier",
Expand Down
9 changes: 9 additions & 0 deletions cachier/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import TYPE_CHECKING, Callable, Literal

if TYPE_CHECKING:
import pymongo.collection


# Signature of user-supplied hash functions: called with (args, kwds) and
# must return a string cache key.
HashFunc = Callable[..., str]
# Zero-argument callable that returns the pymongo Collection used by the
# mongo backend (import deferred via TYPE_CHECKING to avoid a hard dependency).
Mongetter = Callable[[], "pymongo.collection.Collection"]
# Names of the supported caching backends.
Backend = Literal["pickle", "mongo", "memory"]
95 changes: 95 additions & 0 deletions cachier/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import datetime
import hashlib
import os
import pickle
from typing import Optional, TypedDict, Union

from ._types import Backend, HashFunc, Mongetter


def _default_hash_func(args, kwds):
# Sort the kwargs to ensure consistent ordering
sorted_kwargs = sorted(kwds.items())
# Serialize args and sorted_kwargs using pickle or similar
serialized = pickle.dumps((args, sorted_kwargs))
# Create a hash of the serialized data
return hashlib.sha256(serialized).hexdigest()


class Params(TypedDict):
    """Global default parameters shared by all cachier-decorated functions."""

    # Master switch: when False, decorated functions bypass the cache.
    caching_enabled: bool
    # Function mapping (args, kwds) to a string cache key.
    hash_func: HashFunc
    # Which storage core to use: "pickle", "mongo", or "memory".
    backend: Backend
    # Callable returning a pymongo collection; only used by the mongo backend.
    mongetter: Optional[Mongetter]
    # Age beyond which a cached entry is considered stale.
    stale_after: datetime.timedelta
    # If True, return a stale value immediately and recalculate in background.
    next_time: bool
    # Directory where the pickle backend stores cache files.
    cache_dir: Union[str, os.PathLike]
    # If True, reload the pickle cache file on each read.
    pickle_reload: bool
    # If True, store each cached entry in its own file.
    separate_files: bool
    # Seconds to wait for an in-progress calculation before recalculating;
    # 0 means wait indefinitely.
    wait_for_calc_timeout: int
    # If True, None results are cached; otherwise treated as a cache miss.
    allow_none: bool


# Module-wide defaults applied to every cachier decorator when the
# corresponding argument is not given explicitly. Mutated in place by
# set_default_params(), enable_caching(), and disable_caching().
_default_params: Params = {
    "caching_enabled": True,
    "hash_func": _default_hash_func,
    "backend": "pickle",
    "mongetter": None,
    "stale_after": datetime.timedelta.max,  # never stale by default
    "next_time": False,
    "cache_dir": "~/.cachier/",
    "pickle_reload": True,
    "separate_files": False,
    "wait_for_calc_timeout": 0,
    "allow_none": False,
}


def _update_with_defaults(param, name: str):
    """Resolve a decorator argument against the global defaults.

    Returns ``param`` unchanged when it was given explicitly (not None);
    otherwise returns the current global default registered under ``name``.
    The import is deferred so the live module-level dict is consulted at
    call time.
    """
    import cachier

    if param is not None:
        return param
    return cachier.config._default_params[name]


def set_default_params(**params):
    """Configure global parameters applicable to all memoized functions.

    This function takes the same keyword parameters as the ones defined in the
    decorator, which can be passed all at once or with multiple calls.
    Parameters given directly to a decorator take precedence over any values
    set by this function.

    Only 'stale_after', 'next_time', and 'wait_for_calc_timeout' can be changed
    after the memoization decorator has been applied. Other parameters will
    only have an effect on decorators applied after this function is run.

    """
    import cachier

    # Unknown keys are silently dropped so a typo cannot pollute the
    # params dict. Filter and update against the same live
    # ``cachier.config._default_params`` attribute — consistent with
    # get_default_params/enable_caching/disable_caching — rather than the
    # bare module-level name, so a rebound module attribute stays in sync.
    valid_params = (
        p for p in params.items() if p[0] in cachier.config._default_params
    )
    cachier.config._default_params.update(valid_params)


def get_default_params():
    """Get current set of default parameters.

    Note: the returned dict is the live global mapping, not a copy.
    """
    import cachier

    current_defaults = cachier.config._default_params
    return current_defaults


def enable_caching():
    """Enable caching globally."""
    import cachier

    # Flip the global switch on the live params dict.
    cachier.config._default_params.update(caching_enabled=True)


def disable_caching():
    """Disable caching globally."""
    import cachier

    # Flip the global switch off on the live params dict.
    cachier.config._default_params.update(caching_enabled=False)
126 changes: 17 additions & 109 deletions cachier/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,27 @@
# http://www.opensource.org/licenses/MIT-license
# Copyright (c) 2016, Shay Palachy <shaypal5@gmail.com>

# python 2 compatibility

import datetime
import hashlib
import inspect
import os
import pickle
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from typing import TYPE_CHECKING, Callable, Literal, Optional, TypedDict, Union
from typing import Optional, Union
from warnings import warn

from .config import (
Backend,
HashFunc,
Mongetter,
_default_params,
_update_with_defaults,
)
from .cores.base import RecalculationNeeded, _BaseCore
from .cores.memory import _MemoryCore
from .cores.mongo import _MongoCore
from .cores.pickle import _PickleCore

if TYPE_CHECKING:
import pymongo.collection


MAX_WORKERS_ENVAR_NAME = "CACHIER_MAX_WORKERS"
DEFAULT_MAX_WORKERS = 8

Expand Down Expand Up @@ -68,15 +67,6 @@ def _calc_entry(core, key, func, args, kwds):
core.mark_entry_not_calculated(key)


def _default_hash_func(args, kwds):
# Sort the kwargs to ensure consistent ordering
sorted_kwargs = sorted(kwds.items())
# Serialize args and sorted_kwargs using pickle or similar
serialized = pickle.dumps((args, sorted_kwargs))
# Create a hash of the serialized data
return hashlib.sha256(serialized).hexdigest()


def _convert_args_kwargs(
func, _is_method: bool, args: tuple, kwds: dict
) -> dict:
Expand All @@ -103,44 +93,6 @@ def _convert_args_kwargs(
return OrderedDict(sorted(kwargs.items()))


class MissingMongetter(ValueError):
"""Thrown when the mongetter keyword argument is missing."""


HashFunc = Callable[..., str]
Mongetter = Callable[[], "pymongo.collection.Collection"]
Backend = Literal["pickle", "mongo", "memory"]


class Params(TypedDict):
caching_enabled: bool
hash_func: HashFunc
backend: Backend
mongetter: Optional[Mongetter]
stale_after: datetime.timedelta
next_time: bool
cache_dir: Union[str, os.PathLike]
pickle_reload: bool
separate_files: bool
wait_for_calc_timeout: int
allow_none: bool


_default_params: Params = {
"caching_enabled": True,
"hash_func": _default_hash_func,
"backend": "pickle",
"mongetter": None,
"stale_after": datetime.timedelta.max,
"next_time": False,
"cache_dir": "~/.cachier/",
"pickle_reload": True,
"separate_files": False,
"wait_for_calc_timeout": 0,
"allow_none": False,
}


def cachier(
hash_func: Optional[HashFunc] = None,
hash_params: Optional[HashFunc] = None,
Expand Down Expand Up @@ -219,13 +171,12 @@ def cachier(
)
warn(message, DeprecationWarning, stacklevel=2)
hash_func = hash_params
# Update parameters with defaults if input is None
backend = _update_with_defaults(backend, "backend")
mongetter = _update_with_defaults(mongetter, "mongetter")
# Override the backend parameter if a mongetter is provided.
if mongetter is None:
mongetter = _default_params["mongetter"]
if callable(mongetter):
backend = "mongo"
if backend is None:
backend = _default_params["backend"]
core: _BaseCore
if backend == "pickle":
core = _PickleCore(
Expand All @@ -234,23 +185,16 @@ def cachier(
cache_dir=cache_dir,
separate_files=separate_files,
wait_for_calc_timeout=wait_for_calc_timeout,
default_params=_default_params,
)
elif backend == "mongo":
if mongetter is None:
raise MissingMongetter(
"must specify ``mongetter`` when using the mongo core"
)
core = _MongoCore(
mongetter=mongetter,
hash_func=hash_func,
mongetter=mongetter,
wait_for_calc_timeout=wait_for_calc_timeout,
default_params=_default_params,
)
elif backend == "memory":
core = _MemoryCore(
hash_func=hash_func,
default_params=_default_params,
hash_func=hash_func, wait_for_calc_timeout=wait_for_calc_timeout
)
else:
raise ValueError("specified an invalid core: %s" % backend)
Expand All @@ -261,11 +205,7 @@ def _cachier_decorator(func):
@wraps(func)
def func_wrapper(*args, **kwds):
nonlocal allow_none
_allow_none = (
allow_none
if allow_none is not None
else _default_params["allow_none"]
)
_allow_none = _update_with_defaults(allow_none, "allow_none")
# print('Inside general wrapper for {}.'.format(func.__name__))
ignore_cache = kwds.pop("ignore_cache", False)
overwrite_cache = kwds.pop("overwrite_cache", False)
Expand All @@ -289,10 +229,10 @@ def func_wrapper(*args, **kwds):
_print("Entry found.")
if _allow_none or entry.get("value", None) is not None:
_print("Cached result found.")
local_stale_after = (
stale_after or _default_params["stale_after"]
local_stale_after = _update_with_defaults(
stale_after, "stale_after"
)
local_next_time = next_time or _default_params["next_time"] # noqa: E501
local_next_time = _update_with_defaults(next_time, "next_time")
now = datetime.datetime.now()
if now - entry["time"] <= local_stale_after:
_print("And it is fresh!")
Expand Down Expand Up @@ -362,35 +302,3 @@ def precache_value(*args, value_to_cache, **kwds):
return func_wrapper

return _cachier_decorator


def set_default_params(**params):
"""Configure global parameters applicable to all memoized functions.

This function takes the same keyword parameters as the ones defined in the
decorator, which can be passed all at once or with multiple calls.
Parameters given directly to a decorator take precedence over any values
set by this function.

Only 'stale_after', 'next_time', and 'wait_for_calc_timeout' can be changed
after the memoization decorator has been applied. Other parameters will
only have an effect on decorators applied after this function is run.

"""
valid_params = (p for p in params.items() if p[0] in _default_params)
_default_params.update(valid_params)


def get_default_params():
"""Get current set of default parameters."""
return _default_params


def enable_caching():
"""Enable caching globally."""
_default_params["caching_enabled"] = True


def disable_caching():
"""Disable caching globally."""
_default_params["caching_enabled"] = False
23 changes: 12 additions & 11 deletions cachier/cores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

import abc # for the _BaseCore abstract base class
import inspect
import threading

from .._types import HashFunc
from ..config import _update_with_defaults


class RecalculationNeeded(Exception):
Expand All @@ -17,9 +21,10 @@ class RecalculationNeeded(Exception):
class _BaseCore:
__metaclass__ = abc.ABCMeta

def __init__(self, hash_func, default_params):
self.default_params = default_params
self.hash_func = hash_func
def __init__(self, hash_func: HashFunc, wait_for_calc_timeout: int):
self.hash_func = _update_with_defaults(hash_func, "hash_func")
self.wait_for_calc_timeout = wait_for_calc_timeout
self.lock = threading.RLock()

def set_func(self, func):
"""Sets the function this core will use.
Expand All @@ -37,10 +42,7 @@ def set_func(self, func):

def get_key(self, args, kwds):
"""Returns a unique key based on the arguments provided."""
if self.hash_func is not None:
return self.hash_func(args, kwds)
else:
return self.default_params["hash_func"](args, kwds)
return self.hash_func(args, kwds)

def get_entry(self, args, kwds):
"""Returns the result mapped to the given arguments in this core's
Expand All @@ -56,10 +58,9 @@ def precache_value(self, args, kwds, value_to_cache):

def check_calc_timeout(self, time_spent):
"""Raise an exception if a recalculation is needed."""
if self.wait_for_calc_timeout is not None:
calc_timeout = self.wait_for_calc_timeout
else:
calc_timeout = self.default_params["wait_for_calc_timeout"]
calc_timeout = _update_with_defaults(
self.wait_for_calc_timeout, "wait_for_calc_timeout"
)
if calc_timeout > 0 and (time_spent >= calc_timeout):
raise RecalculationNeeded()

Expand Down
Loading