Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 1.9.2
current_version = 1.10.0
commit = False
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
Expand Down
2 changes: 1 addition & 1 deletion nwastdlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#
"""The NWA-stdlib module."""

__version__ = "1.9.2"
__version__ = "1.10.0"

from nwastdlib.f import const, identity

Expand Down
100 changes: 91 additions & 9 deletions nwastdlib/asyncio_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import hashlib
import hmac
import inspect
import pickle # noqa: S403
import sys
import types
import typing
import warnings
from collections.abc import Callable
from functools import wraps
from typing import Any, Protocol, runtime_checkable
from typing import Any, Protocol, get_args, get_origin, runtime_checkable
from uuid import UUID

import structlog
from redis.asyncio import Redis as AIORedis
Expand Down Expand Up @@ -111,6 +117,80 @@
return _deserialize(pickled_value, serializer)


def _generate_cache_key_suffix(*, skip_first: bool, args: tuple, kwargs: dict) -> str:
# Auto generate cache key suffix based on the arguments
# Note: this makes no attempt to handle non-hashable values like lists and sets or other complex objects
filtered_args = args[int(skip_first) :]
filtered_kwargs = frozenset(kwargs.items())
if not filtered_args and not filtered_kwargs:
raise ValueError("Cannot generate cache key without args/kwargs")
args_and_kwargs_string = (filtered_args, filtered_kwargs)
return str(args_and_kwargs_string)


# Parameter annotation types that _validate_signature accepts without emitting
# an "unsafe type" warning.
# NOTE(review): despite the name, this tuple is compared against parameter
# annotations of the decorated function, not against its result type.
SAFE_CACHED_RESULT_TYPES = (
int,
str,
float,
datetime.datetime,
UUID,
)


def _unwrap_type(type_: Any) -> Any:
origin, args = get_origin(type_), get_args(type_)
# 'str'
if not origin:
return type_

# 'str | None' or 'Optional[str]'
if origin in (types.UnionType, typing.Union) and types.NoneType in args:
return args[0]

# For more advanced type handling, see https://github.com/workfloworchestrator/nwa-stdlib/issues/45
return type_


def _format_warning(func: Callable, name: str, type_: Any) -> str:
    """Compose the warning text for a parameter whose annotation is not a safe cache key type.

    Args:
        func: The function the cache decorator is being applied to.
        name: Name of the offending parameter.
        type_: The parameter's annotation as found in the signature.

    Returns:
        The complete, human readable warning message.
    """
    # Join the names eagerly: interpolating a generator expression would only
    # print "<generator object ...>" instead of the list of types.
    safe_types = ", ".join(t.__name__ for t in SAFE_CACHED_RESULT_TYPES)
    # Not every annotation has __name__ (e.g. string annotations or 'int | str'),
    # so fall back to its string form rather than failing while building a warning.
    type_name = getattr(type_, "__name__", str(type_))
    return (
        f"{cached_result.__name__}() applied to function {func.__qualname__} which has parameter '{name}' "
        f"of unsafe type '{type_name}'. "
        f"This can lead to duplicate keys and thus cache misses. "
        f"To resolve this, either set a static keyname or only use parameters of the type {safe_types}. "
        f"If you understand the risks you can suppress/ignore this warning. "
        f"For background and feedback see https://github.com/workfloworchestrator/nwa-stdlib/issues/45"
    )


def _validate_signature(func: Callable) -> bool:
    """Inspect the parameters of ``func`` and report whether its first argument must be skipped.

    The first argument is skipped when it is named ``self`` and the function's
    qualified name contains a dot. Every other parameter whose (unwrapped)
    annotation is not one of the safe types triggers a warning.

    Returns:
        True when the first positional argument should be excluded from the cache key.
    """
    is_nested_function = "." in func.__qualname__
    parameters = list(inspect.signature(func).parameters.items())

    # This will falsely recognize a closure function with 'self'
    # as first arg as a method. Nothing we can do about that..
    skip_first_arg = bool(parameters) and parameters[0][0] == "self" and is_nested_function

    for param_name, parameter in parameters[int(skip_first_arg) :]:
        if _unwrap_type(parameter.annotation) in SAFE_CACHED_RESULT_TYPES:
            continue
        warnings.warn(_format_warning(func, param_name, parameter.annotation), stacklevel=2)
    return skip_first_arg


def _validate_coroutine(func: Callable) -> None:
"""Validate that the callable is a coroutine."""
if not inspect.iscoroutinefunction(func):
raise TypeError(f"Can't apply {cached_result.__name__}() to {func.__name__}: not a coroutine")


def cached_result(
pool: AIORedis,
prefix: str,
Expand Down Expand Up @@ -157,21 +237,23 @@
decorator function

"""
python_major, python_minor = sys.version_info[:2]
prefix_version = f"{prefix}:{python_major}.{python_minor}"
static_cache_key: str | None = f"{prefix_version}:{key_name}" if key_name else None

def cache_decorator(func: Callable) -> Callable:
_validate_coroutine(func)
skip_first = _validate_signature(func)

@wraps(func)
async def func_wrapper(*args: tuple[Any], **kwargs: dict[str, Any]) -> Any:
from_cache = (not revalidate_fn(*args, **kwargs)) if revalidate_fn else True

python_major, python_minor = sys.version_info[:2]
if key_name:
cache_key = f"{prefix}:{python_major}.{python_minor}:{key_name}"
if static_cache_key:
cache_key = static_cache_key
else:
# Auto generate a cache key name based on function_name and a hash of the arguments
# Note: this makes no attempt to handle non-hashable values like lists and sets or other complex objects
args_and_kwargs_string = (args, frozenset(kwargs.items()))
cache_key = f"{prefix}:{python_major}.{python_minor}:{func.__name__}{args_and_kwargs_string}"
logger.debug("Autogenerated a cache key", cache_key=cache_key)
suffix = _generate_cache_key_suffix(skip_first=skip_first, args=args, kwargs=kwargs)
cache_key = f"{prefix_version}:{func.__name__}:{suffix}"

if from_cache:
logger.debug("Cache called with wrapper func", func_name=func.__name__, cache_key=cache_key)
Expand Down
103 changes: 102 additions & 1 deletion tests/test_asyncio_cache.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import json
import sys
from copy import copy
from datetime import datetime
from typing import Any, Optional, Union
from uuid import UUID

import pytest
from fakeredis.aioredis import FakeRedis

from nwastdlib.asyncio_cache import cached_result
from nwastdlib.asyncio_cache import _generate_cache_key_suffix, cached_result


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -220,3 +223,101 @@ async def slow_function(revalidate_cache: bool):
# A new call should serve 1: as it is not cached now
result = await slow_function(revalidate_cache=True)
assert result == 1


# Test the validation


@pytest.mark.parametrize("type_", [Any, tuple, Union[str, int]])
def test_validate_signature_warn_unsafe(type_):
    """Decorating a coroutine with an unsafely annotated parameter emits a UserWarning."""

    async def foo(param: type_):
        return f"{param}-{param}"

    with pytest.warns(UserWarning, match="unsafe type"):
        cached_result(FakeRedis(), "test-suite", "SECRETNAME")(foo)


@pytest.mark.parametrize("type_", [int, int | None, Optional[int]])
def test_validate_signature_safe(recwarn, type_):
    """Decorating a coroutine with a safely annotated parameter emits no warning at all."""

    async def foo(param: type_):
        return f"{param}-{param}"

    cached_result(FakeRedis(), "test-suite", "SECRETNAME")(foo)

    recorded_messages = [warning.message for warning in recwarn]
    assert recorded_messages == []


def test_type_error_on_function():
    """A plain (synchronous) function cannot be decorated."""

    def foo(param: str):
        return f"{param}-{param}"

    with pytest.raises(TypeError, match="foo: not a coroutine"):
        cached_result(FakeRedis(), "test-suite", "SECRETNAME")(foo)


def test_type_error_on_generatorfunction():
    """A synchronous generator function cannot be decorated."""

    def foo(param: str):
        yield f"{param}-{param}"

    with pytest.raises(TypeError, match="foo: not a coroutine"):
        cached_result(FakeRedis(), "test-suite", "SECRETNAME")(foo)


def test_type_error_on_asyncgeneratorfunction():
    """An asynchronous generator function cannot be decorated."""

    async def foo(param: str):
        yield f"{param}-{param}"

    with pytest.raises(TypeError, match="foo: not a coroutine"):
        cached_result(FakeRedis(), "test-suite", "SECRETNAME")(foo)


# Test key generation

# "<major>.<minor>" of the interpreter running the tests
version = f"{sys.version_info.major}.{sys.version_info.minor}"
# Cache prefix string; presumably passed to cached_result() by tests outside this view — TODO confirm
cache_prefix = "test"
# Leading part of a cache key in the form "<prefix>:<major>.<minor>"
cache_key_start = f"{cache_prefix}:{version}"


# Each case: (skip_first, args, kwargs, expected suffix)
_SUFFIX_CASES = [
    (True, (1, 2), {}, "((2,), frozenset())"),
    (False, (1, 2), {}, "((1, 2), frozenset())"),
    (False, (1, "a"), {}, "((1, 'a'), frozenset())"),
    (False, (), {"foo": "bar"}, "((), frozenset({('foo', 'bar')}))"),
    (False, (1.234567,), {}, "((1.234567,), frozenset())"),
    (False, (datetime(year=2025, month=4, day=14),), {}, "((datetime.datetime(2025, 4, 14, 0, 0),), frozenset())"),
    (
        False,
        (UUID("12345678-0000-1111-2222-0123456789ab"),),
        {},
        "((UUID('12345678-0000-1111-2222-0123456789ab'),), frozenset())",
    ),
]


@pytest.mark.parametrize(("skip_first", "args", "kwargs", "expected_key"), _SUFFIX_CASES)
def test_generate_cache_key_suffix(skip_first, args, kwargs, expected_key):
    """The generated suffix matches the expected string for each kind of argument."""
    generated = _generate_cache_key_suffix(skip_first=skip_first, args=args, kwargs=kwargs)
    assert generated == expected_key


@pytest.mark.parametrize(
    ("skip_first", "args", "kwargs", "expected_exception"),
    [(False, (), {}, ValueError)],
)
def test_generate_cache_key_errors(skip_first, args, kwargs, expected_exception):
    """Generating a suffix from unusable input raises the expected exception."""
    call_arguments = {"skip_first": skip_first, "args": args, "kwargs": kwargs}
    with pytest.raises(expected_exception):
        _generate_cache_key_suffix(**call_arguments)