def _generate_cache_key_suffix(*, skip_first: bool, args: tuple, kwargs: dict) -> str:
    """Build the argument-dependent part of an auto-generated cache key.

    The result has the shape ``"(<repr of positional args>, frozenset({...}))"``, i.e. the
    same text that ``str((args, frozenset(kwargs.items())))`` produces, except that the
    keyword items are always rendered in sorted order.

    Args:
        skip_first: When true, the first positional argument is left out of the key.
        args: Positional arguments of the call being cached.
        kwargs: Keyword arguments of the call being cached.

    Returns:
        A string that is identical for identical call arguments.

    Raises:
        ValueError: If no positional (after skipping) and no keyword arguments remain.

    Note:
        No attempt is made to normalise complex values (lists, sets, arbitrary objects);
        they are embedded through their ``repr()``.
    """
    filtered_args = args[int(skip_first) :]
    if not filtered_args and not kwargs:
        raise ValueError("Cannot generate cache key without args/kwargs")

    # Rendering a real frozenset would make the key depend on set iteration order, which
    # for string keys varies between interpreter processes (hash randomisation). Sorting
    # the rendered items keeps the key stable while preserving the textual format.
    rendered_items = sorted(repr(item) for item in kwargs.items())
    if rendered_items:
        kwargs_repr = "frozenset({" + ", ".join(rendered_items) + "})"
    else:
        kwargs_repr = "frozenset()"
    return f"({filtered_args!r}, {kwargs_repr})"
+SAFE_CACHED_RESULT_TYPES = ( + int, + str, + float, + datetime.datetime, + UUID, +) + + +def _unwrap_type(type_: Any) -> Any: + origin, args = get_origin(type_), get_args(type_) + # 'str' + if not origin: + return type_ + + # 'str | None' or 'Optional[str]' + if origin in (types.UnionType, typing.Union) and types.NoneType in args: + return args[0] + + # For more advanced type handling, see https://github.com/workfloworchestrator/nwa-stdlib/issues/45 + return type_ + + +def _format_warning(func: Callable, name: str, type_: Any) -> str: + safe_types = (t.__name__ for t in SAFE_CACHED_RESULT_TYPES) + return ( + f"{cached_result.__name__}() applied to function {func.__qualname__} which has parameter '{name}' " + f"of unsafe type '{type_.__name__}'. " + f"This can lead to duplicate keys and thus cache misses. " + f"To resolve this, either set a static keyname or only use parameters of the type {safe_types}. " + f"If you understand the risks you can suppress/ignore this warning. " + f"For background and feedback see https://github.com/workfloworchestrator/nwa-stdlib/issues/45" + ) + + +def _validate_signature(func: Callable) -> bool: + """Validate the function's signature and return a bool whether to skip the first argument. + + Raises warnings for potentially unsafe cache key arguments. + """ + func_params = inspect.signature(func).parameters + is_nested_function = "." in func.__qualname__ + + skip_first_arg = False + for idx, (name, param) in enumerate(func_params.items()): + if idx == 0 and name == "self" and is_nested_function: + # This will falsely recognize a closure function with 'self' + # as first arg as a method. Nothing we can do about that.. 
+ skip_first_arg = True + continue + + param_type = _unwrap_type(param.annotation) + if param_type not in SAFE_CACHED_RESULT_TYPES: + warnings.warn(_format_warning(func, name, param.annotation), stacklevel=2) + return skip_first_arg + + +def _validate_coroutine(func: Callable) -> None: + """Validate that the callable is a coroutine.""" + if not inspect.iscoroutinefunction(func): + raise TypeError(f"Can't apply {cached_result.__name__}() to {func.__name__}: not a coroutine") + + def cached_result( pool: AIORedis, prefix: str, @@ -157,21 +237,23 @@ def my_other_function... decorator function """ + python_major, python_minor = sys.version_info[:2] + prefix_version = f"{prefix}:{python_major}.{python_minor}" + static_cache_key: str | None = f"{prefix_version}:{key_name}" if key_name else None def cache_decorator(func: Callable) -> Callable: + _validate_coroutine(func) + skip_first = _validate_signature(func) + @wraps(func) async def func_wrapper(*args: tuple[Any], **kwargs: dict[str, Any]) -> Any: from_cache = (not revalidate_fn(*args, **kwargs)) if revalidate_fn else True - python_major, python_minor = sys.version_info[:2] - if key_name: - cache_key = f"{prefix}:{python_major}.{python_minor}:{key_name}" + if static_cache_key: + cache_key = static_cache_key else: - # Auto generate a cache key name based on function_name and a hash of the arguments - # Note: this makes no attempt to handle non-hashable values like lists and sets or other complex objects - args_and_kwargs_string = (args, frozenset(kwargs.items())) - cache_key = f"{prefix}:{python_major}.{python_minor}:{func.__name__}{args_and_kwargs_string}" - logger.debug("Autogenerated a cache key", cache_key=cache_key) + suffix = _generate_cache_key_suffix(skip_first=skip_first, args=args, kwargs=kwargs) + cache_key = f"{prefix_version}:{func.__name__}:{suffix}" if from_cache: logger.debug("Cache called with wrapper func", func_name=func.__name__, cache_key=cache_key) diff --git a/tests/test_asyncio_cache.py 
# Test the validation

_UNSAFE_PARAM_TYPES = [
    Any,
    tuple,
    Union[str, int],
]

_SAFE_PARAM_TYPES = [
    int,
    int | None,
    Optional[int],
]


def _rejects_foo():
    """Context manager expecting the decorator to refuse a callable named 'foo'."""
    return pytest.raises(TypeError, match="foo: not a coroutine")


@pytest.mark.parametrize("type_", _UNSAFE_PARAM_TYPES)
def test_validate_signature_warn_unsafe(type_):
    """Decorating a coroutine whose parameter has an unsafe annotation warns."""
    with pytest.warns(UserWarning, match="unsafe type"):

        @cached_result(FakeRedis(), "test-suite", "SECRETNAME")
        async def foo(param: type_):
            return f"{param}-{param}"


@pytest.mark.parametrize("type_", _SAFE_PARAM_TYPES)
def test_validate_signature_safe(recwarn, type_):
    """Decorating a coroutine whose parameter has a safe annotation stays silent."""

    @cached_result(FakeRedis(), "test-suite", "SECRETNAME")
    async def foo(param: type_):
        return f"{param}-{param}"

    recorded = [w.message for w in recwarn]
    assert recorded == []


def test_type_error_on_function():
    """A plain function is rejected at decoration time."""
    with _rejects_foo():

        @cached_result(FakeRedis(), "test-suite", "SECRETNAME")
        def foo(param: str):
            return f"{param}-{param}"


def test_type_error_on_generatorfunction():
    """A generator function is rejected at decoration time."""
    with _rejects_foo():

        @cached_result(FakeRedis(), "test-suite", "SECRETNAME")
        def foo(param: str):
            yield f"{param}-{param}"


def test_type_error_on_asyncgeneratorfunction():
    """An async generator function is rejected at decoration time."""
    with _rejects_foo():

        @cached_result(FakeRedis(), "test-suite", "SECRETNAME")
        async def foo(param: str):
            yield f"{param}-{param}"


# Test key generation

version = f"{sys.version_info.major}.{sys.version_info.minor}"
cache_prefix = "test"
cache_key_start = f"{cache_prefix}:{version}"

# (skip_first, args, kwargs, expected suffix)
_SUFFIX_CASES = [
    (True, (1, 2), {}, "((2,), frozenset())"),
    (False, (1, 2), {}, "((1, 2), frozenset())"),
    (False, (1, "a"), {}, "((1, 'a'), frozenset())"),
    (False, (), {"foo": "bar"}, "((), frozenset({('foo', 'bar')}))"),
    (False, (1.234567,), {}, "((1.234567,), frozenset())"),
    (False, (datetime(year=2025, month=4, day=14),), {}, "((datetime.datetime(2025, 4, 14, 0, 0),), frozenset())"),
    (
        False,
        (UUID("12345678-0000-1111-2222-0123456789ab"),),
        {},
        "((UUID('12345678-0000-1111-2222-0123456789ab'),), frozenset())",
    ),
]

# (skip_first, args, kwargs, expected exception type)
_SUFFIX_ERROR_CASES = [
    (False, (), {}, ValueError),
]


@pytest.mark.parametrize(("skip_first", "args", "kwargs", "expected_key"), _SUFFIX_CASES)
def test_generate_cache_key_suffix(skip_first, args, kwargs, expected_key):
    """The generated suffix matches the expected textual form for each argument mix."""
    actual_key = _generate_cache_key_suffix(skip_first=skip_first, args=args, kwargs=kwargs)
    assert actual_key == expected_key


@pytest.mark.parametrize(("skip_first", "args", "kwargs", "expected_exception"), _SUFFIX_ERROR_CASES)
def test_generate_cache_key_errors(skip_first, args, kwargs, expected_exception):
    """Suffix generation fails with the expected exception for invalid input."""
    with pytest.raises(expected_exception):
        _generate_cache_key_suffix(skip_first=skip_first, args=args, kwargs=kwargs)