diff --git a/key-value/key-value-aio/src/key_value/aio/adapters/pydantic/adapter.py b/key-value/key-value-aio/src/key_value/aio/adapters/pydantic/adapter.py index 31de2849..06209361 100644 --- a/key-value/key-value-aio/src/key_value/aio/adapters/pydantic/adapter.py +++ b/key-value/key-value-aio/src/key_value/aio/adapters/pydantic/adapter.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Any, Generic, SupportsFloat, TypeVar, get_origin +from typing import Any, Generic, SupportsFloat, TypeVar, get_origin, overload from key_value.shared.errors import DeserializationError, SerializationError from key_value.shared.type_checking.bear_spray import bear_spray @@ -71,11 +71,22 @@ def _serialize_model(self, value: T) -> dict[str, Any]: msg = f"Invalid Pydantic model: {e}" raise SerializationError(msg) from e - async def get(self, key: str, *, collection: str | None = None) -> T | None: + @overload + async def get(self, key: str, *, collection: str | None = None, default: T) -> T: ... + + @overload + async def get(self, key: str, *, collection: str | None = None, default: None = None) -> T | None: ... + + async def get(self, key: str, *, collection: str | None = None, default: T | None = None) -> T | None: """Get and validate a model by key. + Args: + key: The key to retrieve. + collection: The collection to use. If not provided, uses the default collection. + default: The default value to return if the key doesn't exist or validation fails. + Returns: - The parsed model instance, or None if not present. + The parsed model instance if found and valid, or the default value if key doesn't exist or validation fails. Raises: DeserializationError if the stored data cannot be validated as the model and the PydanticAdapter is configured to @@ -84,15 +95,28 @@ async def get(self, key: str, *, collection: str | None = None) -> T | None: collection = collection or self._default_collection if value := await self._key_value.get(key=key, collection=collection): - return self._validate_model(value=value) + validated = self._validate_model(value=value) + if validated is not None: + return validated - return None + return default - async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[T | None]: + @overload + async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T) -> list[T]: ... + + @overload + async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: None = None) -> list[T | None]: ... + + async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T | None = None) -> list[T] | list[T | None]: """Batch get and validate models by keys, preserving order. + Args: + keys: The list of keys to retrieve. + collection: The collection to use. If not provided, uses the default collection. + default: The default value to return for keys that don't exist or fail validation. + Returns: - A list of parsed model instances, or None if missing. + A list of parsed model instances, with default values for missing keys or validation failures. Raises: DeserializationError if the stored data cannot be validated as the model and the PydanticAdapter is configured to @@ -102,7 +126,14 @@ async def get_many(self, keys: Sequence[str], *, collection: str | None = None) values: list[dict[str, Any] | None] = await self._key_value.get_many(keys=keys, collection=collection) - return [self._validate_model(value=value) if value else None for value in values] + result: list[T | None] = [] + for value in values: + if value is None: + result.append(default) + else: + validated = self._validate_model(value=value) + result.append(validated if validated is not None else default) + return result async def put(self, key: str, value: T, *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None: """Serialize and store a model. diff --git a/key-value/key-value-aio/src/key_value/aio/stores/base.py b/key-value/key-value-aio/src/key_value/aio/stores/base.py index 9d75a3f3..5d0cf218 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/base.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/base.py @@ -6,7 +6,7 @@ from asyncio.locks import Lock from collections import defaultdict from collections.abc import Mapping, Sequence -from types import TracebackType +from types import MappingProxyType, TracebackType from typing import Any, SupportsFloat from key_value.shared.constants import DEFAULT_COLLECTION_NAME @@ -24,6 +24,16 @@ AsyncKeyValueProtocol, ) +SEED_DATA_TYPE = Mapping[str, Mapping[str, Mapping[str, Any]]] +FROZEN_SEED_DATA_TYPE = MappingProxyType[str, MappingProxyType[str, MappingProxyType[str, Any]]] +DEFAULT_SEED_DATA: FROZEN_SEED_DATA_TYPE = MappingProxyType({}) + + +def _seed_to_frozen_seed_data(seed: SEED_DATA_TYPE) -> FROZEN_SEED_DATA_TYPE: + return MappingProxyType( + {collection: MappingProxyType({key: MappingProxyType(value) for key, value in items.items()}) for collection, items in seed.items()} + ) + class BaseStore(AsyncKeyValueProtocol, ABC): """An opinionated Abstract base class for managed key-value stores using ManagedEntry objects. @@ -43,14 +53,19 @@ class BaseStore(AsyncKeyValueProtocol, ABC): _setup_collection_locks: defaultdict[str, Lock] _setup_collection_complete: defaultdict[str, bool] + _seed: FROZEN_SEED_DATA_TYPE + default_collection: str - def __init__(self, *, default_collection: str | None = None) -> None: + def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYPE | None = None) -> None: """Initialize the managed key-value store. Args: default_collection: The default collection to use if no collection is provided. Defaults to "default_collection". + seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}. + Seeding occurs once during store initialization (when the store is first entered or when the + first operation is performed on the store). """ self._setup_complete = False @@ -58,6 +73,8 @@ def __init__(self, *, default_collection: str | None = None) -> None: self._setup_collection_locks = defaultdict(Lock) self._setup_collection_complete = defaultdict(bool) + self._seed = _seed_to_frozen_seed_data(seed=seed or {}) + self.default_collection = default_collection or DEFAULT_COLLECTION_NAME if not hasattr(self, "_stable_api"): @@ -74,6 +91,13 @@ async def _setup(self) -> None: async def _setup_collection(self, *, collection: str) -> None: """Initialize the collection (called once before first use of the collection).""" + async def _seed_store(self) -> None: + """Seed the store with the data from the seed.""" + for collection, items in self._seed.items(): + await self.setup_collection(collection=collection) + for key, value in items.items(): + await self.put(key=key, value=dict(value), collection=collection) + async def setup(self) -> None: if not self._setup_complete: async with self._setup_lock: @@ -84,8 +108,11 @@ async def setup(self) -> None: raise StoreSetupError( message=f"Failed to setup key value store: {e}", extra_info={"store": self.__class__.__name__} ) from e + self._setup_complete = True + await self._seed_store() + async def setup_collection(self, *, collection: str) -> None: await self.setup() diff --git a/key-value/key-value-aio/src/key_value/aio/stores/memory/store.py b/key-value/key-value-aio/src/key_value/aio/stores/memory/store.py index c7cdbe99..06e7dc1f 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/memory/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/memory/store.py @@ -8,6 +8,7 @@ from typing_extensions import Self, override from key_value.aio.stores.base import ( + SEED_DATA_TYPE, BaseDestroyCollectionStore, BaseDestroyStore, BaseEnumerateCollectionsStore, @@ -109,12 +110,21 @@ class MemoryStore(BaseDestroyStore, BaseDestroyCollectionStore, BaseEnumerateCol _cache: dict[str, MemoryCollection] - def __init__(self, *, max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION, default_collection: str | None = None): + def __init__( + self, + *, + max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION, + default_collection: str | None = None, + seed: SEED_DATA_TYPE | None = None, + ): """Initialize a fixed-size in-memory store. Args: max_entries_per_collection: The maximum number of entries per collection. Defaults to 10000. default_collection: The default collection to use if no collection is provided. + seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}. + Each value must be a mapping (dict) that will be stored as the entry's value. + Seeding occurs lazily when each collection is first accessed. """ self.max_entries_per_collection = max_entries_per_collection @@ -123,11 +133,25 @@ def __init__(self, *, max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_ self._stable_api = True - super().__init__(default_collection=default_collection) + super().__init__(default_collection=default_collection, seed=seed) + + @override + async def _setup(self) -> None: + for collection in self._seed: + await self._setup_collection(collection=collection) @override async def _setup_collection(self, *, collection: str) -> None: - self._cache[collection] = MemoryCollection(max_entries=self.max_entries_per_collection) + """Set up a collection, creating it and seeding it if seed data is available. + + Args: + collection: The collection name. + """ + if collection in self._cache: + return + + collection_cache = MemoryCollection(max_entries=self.max_entries_per_collection) + self._cache[collection] = collection_cache @override async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None: diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/default_value/__init__.py b/key-value/key-value-aio/src/key_value/aio/wrappers/default_value/__init__.py new file mode 100644 index 00000000..e467fe5f --- /dev/null +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/default_value/__init__.py @@ -0,0 +1,5 @@ +"""Default value wrapper for returning fallback values when keys are not found.""" + +from key_value.aio.wrappers.default_value.wrapper import DefaultValueWrapper + +__all__ = ["DefaultValueWrapper"] diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/default_value/wrapper.py b/key-value/key-value-aio/src/key_value/aio/wrappers/default_value/wrapper.py new file mode 100644 index 00000000..81536d28 --- /dev/null +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/default_value/wrapper.py @@ -0,0 +1,68 @@ +from collections.abc import Mapping, Sequence +from typing import Any, SupportsFloat + +from key_value.shared.utils.managed_entry import dump_to_json, load_from_json +from typing_extensions import override + +from key_value.aio.protocols.key_value import AsyncKeyValue +from key_value.aio.wrappers.base import BaseWrapper + + +class DefaultValueWrapper(BaseWrapper): + """A wrapper that returns a default value when a key is not found. + + This wrapper provides dict.get(key, default) behavior for the key-value store, + allowing you to specify a default value to return instead of None when a key doesn't exist. + + It does not store the default value in the underlying key-value store and the TTL returned with the default + value is hard-coded based on the default_ttl parameter. Picking a default_ttl requires careful consideration + of how the value will be used and if any other wrappers will be used that may rely on the TTL. + """ + + key_value: AsyncKeyValue # Alias for BaseWrapper compatibility + _default_ttl: float | None + _default_value_json: str + + def __init__( + self, + key_value: AsyncKeyValue, + default_value: Mapping[str, Any], + default_ttl: SupportsFloat | None = None, + ) -> None: + """Initialize the DefaultValueWrapper. + + Args: + key_value: The underlying key-value store to wrap. + default_value: The default value to return when a key is not found. + default_ttl: The TTL to return to the caller for default values. Defaults to None. + """ + self.key_value = key_value + self._default_value_json = dump_to_json(obj=dict(default_value)) + self._default_ttl = None if default_ttl is None else float(default_ttl) + + def _new_default_value(self) -> dict[str, Any]: + return load_from_json(json_str=self._default_value_json) + + @override + async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None: + result = await self.key_value.get(key=key, collection=collection) + return result if result is not None else self._new_default_value() + + @override + async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]: + results = await self.key_value.get_many(keys=keys, collection=collection) + return [result if result is not None else self._new_default_value() for result in results] + + @override + async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]: + result, ttl_value = await self.key_value.ttl(key=key, collection=collection) + if result is None: + return (self._new_default_value(), self._default_ttl) + return (result, ttl_value) + + @override + async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]: + results = await self.key_value.ttl_many(keys=keys, collection=collection) + return [ + (result, ttl_value) if result is not None else (self._new_default_value(), self._default_ttl) for result, ttl_value in results + ] diff --git a/key-value/key-value-aio/tests/adapters/test_pydantic.py b/key-value/key-value-aio/tests/adapters/test_pydantic.py index d2886a61..6d7cdaef 100644 --- a/key-value/key-value-aio/tests/adapters/test_pydantic.py +++ b/key-value/key-value-aio/tests/adapters/test_pydantic.py @@ -38,11 +38,13 @@ class Order(BaseModel): FIXED_UPDATED_AT: datetime = datetime(year=2021, month=1, day=1, hour=15, minute=0, second=0, tzinfo=timezone.utc) SAMPLE_USER: User = User(name="John Doe", email="john.doe@example.com", age=30) +SAMPLE_USER_2: User = User(name="Jane Doe", email="jane.doe@example.com", age=25) SAMPLE_PRODUCT: Product = Product(name="Widget", price=29.99, quantity=10, url=AnyHttpUrl(url="https://example.com")) SAMPLE_ORDER: Order = Order(created_at=datetime.now(), updated_at=datetime.now(), user=SAMPLE_USER, product=SAMPLE_PRODUCT, paid=False) TEST_COLLECTION: str = "test_collection" TEST_KEY: str = "test_key" +TEST_KEY_2: str = "test_key_2" class TestPydanticAdapter: @@ -77,8 +79,17 @@ async def test_simple_adapter(self, user_adapter: PydanticAdapter[User]): assert await user_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) - cached_user = await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) - assert cached_user is None + assert await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + + async def test_simple_adapter_with_default(self, user_adapter: PydanticAdapter[User]): + assert await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER + + await user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER_2) + assert await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER_2 + + assert await user_adapter.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2], default=SAMPLE_USER) == snapshot( + [SAMPLE_USER_2, SAMPLE_USER] + ) async def test_simple_adapter_with_validation_error_ignore( self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser] @@ -98,12 +109,10 @@ async def test_simple_adapter_with_validation_error_raise( async def test_complex_adapter(self, order_adapter: PydanticAdapter[Order]): await order_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_ORDER, ttl=10) - cached_order: Order | None = await order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) - assert cached_order == SAMPLE_ORDER + assert await order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) == SAMPLE_ORDER assert await order_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) - cached_order = await order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) - assert cached_order is None + assert await order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None async def test_complex_adapter_with_list(self, product_list_adapter: PydanticAdapter[list[Product]], store: MemoryStore): await product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[SAMPLE_PRODUCT, SAMPLE_PRODUCT], ttl=10) @@ -127,5 +136,4 @@ async def test_complex_adapter_with_list(self, product_list_adapter: PydanticAda ) assert await product_list_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) - cached_products = await product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) - assert cached_products is None + assert await product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None diff --git a/key-value/key-value-aio/tests/stores/memory/test_memory.py b/key-value/key-value-aio/tests/stores/memory/test_memory.py index ffd21235..a98a2a31 100644 --- a/key-value/key-value-aio/tests/stores/memory/test_memory.py +++ b/key-value/key-value-aio/tests/stores/memory/test_memory.py @@ -10,3 +10,7 @@ class TestMemoryStore(BaseStoreTests): @pytest.fixture async def store(self) -> MemoryStore: return MemoryStore(max_entries_per_collection=500) + + async def test_seed(self): + store = MemoryStore(max_entries_per_collection=500, seed={"test_collection": {"test_key": {"obj_key": "obj_value"}}}) + assert await store.get(key="test_key", collection="test_collection") == {"obj_key": "obj_value"} diff --git a/key-value/key-value-aio/tests/stores/wrappers/test_default_value.py b/key-value/key-value-aio/tests/stores/wrappers/test_default_value.py new file mode 100644 index 00000000..8f2b92cf --- /dev/null +++ b/key-value/key-value-aio/tests/stores/wrappers/test_default_value.py @@ -0,0 +1,91 @@ +import pytest +from dirty_equals import IsFloat +from typing_extensions import override + +from key_value.aio.stores.base import BaseStore +from key_value.aio.stores.memory.store import MemoryStore +from key_value.aio.wrappers.default_value import DefaultValueWrapper +from tests.stores.base import BaseStoreTests + +TEST_KEY_1 = "test_key_1" +TEST_KEY_2 = "test_key_2" +TEST_COLLECTION = "test_collection" +DEFAULT_VALUE = {"obj_key": "obj_value"} +DEFAULT_TTL = 100 + + +class TestDefaultValueWrapper(BaseStoreTests): + @override + @pytest.fixture + async def store(self, memory_store: MemoryStore) -> DefaultValueWrapper: + return DefaultValueWrapper(key_value=memory_store, default_value=DEFAULT_VALUE, default_ttl=DEFAULT_TTL) + + async def test_default_value(self, store: BaseStore): + assert await store.get(collection=TEST_COLLECTION, key=TEST_KEY_1) == DEFAULT_VALUE + assert await store.ttl(collection=TEST_COLLECTION, key=TEST_KEY_1) == (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)) + assert await store.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [DEFAULT_VALUE, DEFAULT_VALUE] + assert await store.ttl_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [ + (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)), + (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)), + ] + + await store.put(collection=TEST_COLLECTION, key=TEST_KEY_2, value={"key_2": "value_2"}, ttl=200) + assert await store.get(collection=TEST_COLLECTION, key=TEST_KEY_2) == {"key_2": "value_2"} + assert await store.ttl(collection=TEST_COLLECTION, key=TEST_KEY_2) == ({"key_2": "value_2"}, IsFloat(approx=200)) + assert await store.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [DEFAULT_VALUE, {"key_2": "value_2"}] + assert await store.ttl_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [ + (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)), + ({"key_2": "value_2"}, IsFloat(approx=200)), + ] + + @override + @pytest.mark.skip + async def test_empty_get(self, store: BaseStore): ... + + @override + @pytest.mark.skip + async def test_put_put_get_many_missing_one(self, store: BaseStore): ... + + @override + @pytest.mark.skip + async def test_empty_ttl(self, store: BaseStore): ... + + @override + @pytest.mark.skip + async def test_get_put_get(self, store: BaseStore): ... + + @override + @pytest.mark.skip + async def test_get_put_get_delete_get(self, store: BaseStore): ... + + @override + @pytest.mark.skip + async def test_put_get_delete_get(self, store: BaseStore): ... + + @override + @pytest.mark.skip + async def test_put_many_get_get_delete_many_get_many(self, store: BaseStore): ... + + @override + @pytest.mark.skip + async def test_put_many_get_many_delete_many_get_many(self, store: BaseStore): ... + + @override + @pytest.mark.skip + async def test_get_put_get_put_delete_get(self, store: BaseStore): ... + + @override + @pytest.mark.skip + async def test_put_many_delete_delete_get_many(self, store: BaseStore): ... + + @override + @pytest.mark.skip + async def test_put_expired_get_none(self, store: BaseStore): ... + + @override + @pytest.mark.skip + async def test_not_unbounded(self, store: BaseStore): ... + + @override + @pytest.mark.skip + async def test_concurrent_operations(self, store: BaseStore): ... diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/pydantic/adapter.py b/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/pydantic/adapter.py index 125cd1af..d904a4db 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/pydantic/adapter.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/adapters/pydantic/adapter.py @@ -2,7 +2,7 @@ # from the original file 'adapter.py' # DO NOT CHANGE! Change the original file instead. from collections.abc import Sequence -from typing import Any, Generic, SupportsFloat, TypeVar, get_origin +from typing import Any, Generic, SupportsFloat, TypeVar, get_origin, overload from key_value.shared.errors import DeserializationError, SerializationError from key_value.shared.type_checking.bear_spray import bear_spray @@ -71,11 +71,22 @@ def _serialize_model(self, value: T) -> dict[str, Any]: msg = f"Invalid Pydantic model: {e}" raise SerializationError(msg) from e - def get(self, key: str, *, collection: str | None = None) -> T | None: + @overload + def get(self, key: str, *, collection: str | None = None, default: T) -> T: ... + + @overload + def get(self, key: str, *, collection: str | None = None, default: None = None) -> T | None: ... + + def get(self, key: str, *, collection: str | None = None, default: T | None = None) -> T | None: """Get and validate a model by key. + Args: + key: The key to retrieve. + collection: The collection to use. If not provided, uses the default collection. + default: The default value to return if the key doesn't exist or validation fails. + Returns: - The parsed model instance, or None if not present. + The parsed model instance if found and valid, or the default value if key doesn't exist or validation fails. Raises: DeserializationError if the stored data cannot be validated as the model and the PydanticAdapter is configured to @@ -84,15 +95,28 @@ def get(self, key: str, *, collection: str | None = None) -> T | None: collection = collection or self._default_collection if value := self._key_value.get(key=key, collection=collection): - return self._validate_model(value=value) + validated = self._validate_model(value=value) + if validated is not None: + return validated - return None + return default - def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[T | None]: + @overload + def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T) -> list[T]: ... + + @overload + def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: None = None) -> list[T | None]: ... + + def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T | None = None) -> list[T] | list[T | None]: """Batch get and validate models by keys, preserving order. + Args: + keys: The list of keys to retrieve. + collection: The collection to use. If not provided, uses the default collection. + default: The default value to return for keys that don't exist or fail validation. + Returns: - A list of parsed model instances, or None if missing. + A list of parsed model instances, with default values for missing keys or validation failures. Raises: DeserializationError if the stored data cannot be validated as the model and the PydanticAdapter is configured to @@ -102,7 +126,14 @@ def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> lis values: list[dict[str, Any] | None] = self._key_value.get_many(keys=keys, collection=collection) - return [self._validate_model(value=value) if value else None for value in values] + result: list[T | None] = [] + for value in values: + if value is None: + result.append(default) + else: + validated = self._validate_model(value=value) + result.append(validated if validated is not None else default) + return result def put(self, key: str, value: T, *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None: """Serialize and store a model. diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py index 362e84af..0086afda 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py @@ -9,7 +9,7 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from threading import Lock -from types import TracebackType +from types import MappingProxyType, TracebackType from typing import Any, SupportsFloat from key_value.shared.constants import DEFAULT_COLLECTION_NAME @@ -27,6 +27,19 @@ KeyValueProtocol, ) +SEED_DATA_TYPE = Mapping[str, Mapping[str, Mapping[str, Any]]] +FROZEN_SEED_DATA_TYPE = MappingProxyType[str, MappingProxyType[str, MappingProxyType[str, Any]]] +DEFAULT_SEED_DATA: FROZEN_SEED_DATA_TYPE = MappingProxyType({}) + + +def _seed_to_frozen_seed_data(seed: SEED_DATA_TYPE) -> FROZEN_SEED_DATA_TYPE: + return MappingProxyType( + { + collection: MappingProxyType({key: MappingProxyType(value) for (key, value) in items.items()}) + for (collection, items) in seed.items() + } + ) + class BaseStore(KeyValueProtocol, ABC): """An opinionated Abstract base class for managed key-value stores using ManagedEntry objects. @@ -46,14 +59,19 @@ class BaseStore(KeyValueProtocol, ABC): _setup_collection_locks: defaultdict[str, Lock] _setup_collection_complete: defaultdict[str, bool] + _seed: FROZEN_SEED_DATA_TYPE + default_collection: str - def __init__(self, *, default_collection: str | None = None) -> None: + def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYPE | None = None) -> None: """Initialize the managed key-value store. Args: default_collection: The default collection to use if no collection is provided. Defaults to "default_collection". + seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}. + Seeding occurs once during store initialization (when the store is first entered or when the + first operation is performed on the store). """ self._setup_complete = False @@ -61,6 +79,8 @@ def __init__(self, *, default_collection: str | None = None) -> None: self._setup_collection_locks = defaultdict(Lock) self._setup_collection_complete = defaultdict(bool) + self._seed = _seed_to_frozen_seed_data(seed=seed or {}) + self.default_collection = default_collection or DEFAULT_COLLECTION_NAME if not hasattr(self, "_stable_api"): @@ -77,6 +97,13 @@ def _setup(self) -> None: def _setup_collection(self, *, collection: str) -> None: """Initialize the collection (called once before first use of the collection).""" + def _seed_store(self) -> None: + """Seed the store with the data from the seed.""" + for collection, items in self._seed.items(): + self.setup_collection(collection=collection) + for key, value in items.items(): + self.put(key=key, value=dict(value), collection=collection) + def setup(self) -> None: if not self._setup_complete: with self._setup_lock: @@ -87,8 +114,11 @@ def setup(self) -> None: raise StoreSetupError( message=f"Failed to setup key value store: {e}", extra_info={"store": self.__class__.__name__} ) from e + self._setup_complete = True + self._seed_store() + def setup_collection(self, *, collection: str) -> None: self.setup() diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/memory/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/memory/store.py index d72280a7..dcb6febb 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/memory/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/memory/store.py @@ -11,6 +11,7 @@ from typing_extensions import Self, override from key_value.sync.code_gen.stores.base import ( + SEED_DATA_TYPE, BaseDestroyCollectionStore, BaseDestroyStore, BaseEnumerateCollectionsStore, @@ -104,12 +105,21 @@ class MemoryStore(BaseDestroyStore, BaseDestroyCollectionStore, BaseEnumerateCol _cache: dict[str, MemoryCollection] - def __init__(self, *, max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION, default_collection: str | None = None): + def __init__( + self, + *, + max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION, + default_collection: str | None = None, + seed: SEED_DATA_TYPE | None = None, + ): """Initialize a fixed-size in-memory store. Args: max_entries_per_collection: The maximum number of entries per collection. Defaults to 10000. default_collection: The default collection to use if no collection is provided. + seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}. + Each value must be a mapping (dict) that will be stored as the entry's value. + Seeding occurs lazily when each collection is first accessed. """ self.max_entries_per_collection = max_entries_per_collection @@ -118,11 +128,25 @@ def __init__(self, *, max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_ self._stable_api = True - super().__init__(default_collection=default_collection) + super().__init__(default_collection=default_collection, seed=seed) + + @override + def _setup(self) -> None: + for collection in self._seed: + self._setup_collection(collection=collection) @override def _setup_collection(self, *, collection: str) -> None: - self._cache[collection] = MemoryCollection(max_entries=self.max_entries_per_collection) + """Set up a collection, creating it and seeding it if seed data is available. + + Args: + collection: The collection name. + """ + if collection in self._cache: + return + + collection_cache = MemoryCollection(max_entries=self.max_entries_per_collection) + self._cache[collection] = collection_cache @override def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None: diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/default_value/__init__.py b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/default_value/__init__.py new file mode 100644 index 00000000..2246cc63 --- /dev/null +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/default_value/__init__.py @@ -0,0 +1,8 @@ +# WARNING: this file is auto-generated by 'build_sync_library.py' +# from the original file '__init__.py' +# DO NOT CHANGE! Change the original file instead. +"""Default value wrapper for returning fallback values when keys are not found.""" + +from key_value.sync.code_gen.wrappers.default_value.wrapper import DefaultValueWrapper + +__all__ = ["DefaultValueWrapper"] diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/default_value/wrapper.py b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/default_value/wrapper.py new file mode 100644 index 00000000..d1a8f891 --- /dev/null +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/default_value/wrapper.py @@ -0,0 +1,66 @@ +# WARNING: this file is auto-generated by 'build_sync_library.py' +# from the original file 'wrapper.py' +# DO NOT CHANGE! Change the original file instead. +from collections.abc import Mapping, Sequence +from typing import Any, SupportsFloat + +from key_value.shared.utils.managed_entry import dump_to_json, load_from_json +from typing_extensions import override + +from key_value.sync.code_gen.protocols.key_value import KeyValue +from key_value.sync.code_gen.wrappers.base import BaseWrapper + + +class DefaultValueWrapper(BaseWrapper): + """A wrapper that returns a default value when a key is not found. + + This wrapper provides dict.get(key, default) behavior for the key-value store, + allowing you to specify a default value to return instead of None when a key doesn't exist. + + It does not store the default value in the underlying key-value store and the TTL returned with the default + value is hard-coded based on the default_ttl parameter. Picking a default_ttl requires careful consideration + of how the value will be used and if any other wrappers will be used that may rely on the TTL. + """ + + key_value: KeyValue # Alias for BaseWrapper compatibility + _default_ttl: float | None + _default_value_json: str + + def __init__(self, key_value: KeyValue, default_value: Mapping[str, Any], default_ttl: SupportsFloat | None = None) -> None: + """Initialize the DefaultValueWrapper. + + Args: + key_value: The underlying key-value store to wrap. + default_value: The default value to return when a key is not found. + default_ttl: The TTL to return to the caller for default values. Defaults to None. + """ + self.key_value = key_value + self._default_value_json = dump_to_json(obj=dict(default_value)) + self._default_ttl = None if default_ttl is None else float(default_ttl) + + def _new_default_value(self) -> dict[str, Any]: + return load_from_json(json_str=self._default_value_json) + + @override + def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None: + result = self.key_value.get(key=key, collection=collection) + return result if result is not None else self._new_default_value() + + @override + def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]: + results = self.key_value.get_many(keys=keys, collection=collection) + return [result if result is not None else self._new_default_value() for result in results] + + @override + def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]: + (result, ttl_value) = self.key_value.ttl(key=key, collection=collection) + if result is None: + return (self._new_default_value(), self._default_ttl) + return (result, ttl_value) + + @override + def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]: + results = self.key_value.ttl_many(keys=keys, collection=collection) + return [ + (result, ttl_value) if result is not None else (self._new_default_value(), self._default_ttl) for (result, ttl_value) in results + ] diff --git a/key-value/key-value-sync/src/key_value/sync/wrappers/default_value/__init__.py b/key-value/key-value-sync/src/key_value/sync/wrappers/default_value/__init__.py new file mode 100644 index 00000000..2246cc63 --- /dev/null +++ b/key-value/key-value-sync/src/key_value/sync/wrappers/default_value/__init__.py @@ -0,0 +1,8 @@ +# WARNING: this file is auto-generated by 'build_sync_library.py' +# from the original file '__init__.py' +# DO NOT CHANGE! Change the original file instead. +"""Default value wrapper for returning fallback values when keys are not found.""" + +from key_value.sync.code_gen.wrappers.default_value.wrapper import DefaultValueWrapper + +__all__ = ["DefaultValueWrapper"] diff --git a/key-value/key-value-sync/tests/code_gen/adapters/test_pydantic.py b/key-value/key-value-sync/tests/code_gen/adapters/test_pydantic.py index 50899127..3cef2742 100644 --- a/key-value/key-value-sync/tests/code_gen/adapters/test_pydantic.py +++ b/key-value/key-value-sync/tests/code_gen/adapters/test_pydantic.py @@ -41,11 +41,13 @@ class Order(BaseModel): FIXED_UPDATED_AT: datetime = datetime(year=2021, month=1, day=1, hour=15, minute=0, second=0, tzinfo=timezone.utc) SAMPLE_USER: User = User(name="John Doe", email="john.doe@example.com", age=30) +SAMPLE_USER_2: User = User(name="Jane Doe", email="jane.doe@example.com", age=25) SAMPLE_PRODUCT: Product = Product(name="Widget", price=29.99, quantity=10, url=AnyHttpUrl(url="https://example.com")) SAMPLE_ORDER: Order = Order(created_at=datetime.now(), updated_at=datetime.now(), user=SAMPLE_USER, product=SAMPLE_PRODUCT, paid=False) TEST_COLLECTION: str = "test_collection" TEST_KEY: str = "test_key" +TEST_KEY_2: str = "test_key_2" class TestPydanticAdapter: @@ -80,8 +82,17 @@ def test_simple_adapter(self, user_adapter: PydanticAdapter[User]): assert user_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) - cached_user = user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) - assert cached_user is None + assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + + def test_simple_adapter_with_default(self, user_adapter: PydanticAdapter[User]): + assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER + + user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER_2) + assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER_2 + + assert user_adapter.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2], default=SAMPLE_USER) == snapshot( + [SAMPLE_USER_2, SAMPLE_USER] + ) def test_simple_adapter_with_validation_error_ignore( self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser] @@ -101,12 +112,10 @@ def test_simple_adapter_with_validation_error_raise( def test_complex_adapter(self, order_adapter: PydanticAdapter[Order]): order_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_ORDER, ttl=10) - cached_order: Order | None = order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) - assert cached_order == SAMPLE_ORDER + assert order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) == SAMPLE_ORDER assert order_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) - cached_order = order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) - assert cached_order is None + assert order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None def test_complex_adapter_with_list(self, product_list_adapter: PydanticAdapter[list[Product]], store: MemoryStore): product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[SAMPLE_PRODUCT, SAMPLE_PRODUCT], ttl=10) @@ -130,5 +139,4 @@ def test_complex_adapter_with_list(self, product_list_adapter: PydanticAdapter[l ) assert product_list_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) - cached_products = product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) - assert cached_products is None + assert product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None diff --git a/key-value/key-value-sync/tests/code_gen/stores/memory/test_memory.py b/key-value/key-value-sync/tests/code_gen/stores/memory/test_memory.py index 3b75dec3..631cf255 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/memory/test_memory.py +++ b/key-value/key-value-sync/tests/code_gen/stores/memory/test_memory.py @@ -13,3 +13,7 @@ class TestMemoryStore(BaseStoreTests): @pytest.fixture def store(self) -> MemoryStore: return MemoryStore(max_entries_per_collection=500) + + def test_seed(self): + store = MemoryStore(max_entries_per_collection=500, seed={"test_collection": {"test_key": {"obj_key": "obj_value"}}}) + assert store.get(key="test_key", collection="test_collection") == {"obj_key": "obj_value"} diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_default_value.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_default_value.py new file mode 100644 index 00000000..40323f9f --- /dev/null +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_default_value.py @@ -0,0 +1,94 @@ +# WARNING: this file is auto-generated by 'build_sync_library.py' +# from the original file 'test_default_value.py' +# DO NOT CHANGE! Change the original file instead. +import pytest +from dirty_equals import IsFloat +from typing_extensions import override + +from key_value.sync.code_gen.stores.base import BaseStore +from key_value.sync.code_gen.stores.memory.store import MemoryStore +from key_value.sync.code_gen.wrappers.default_value import DefaultValueWrapper +from tests.code_gen.stores.base import BaseStoreTests + +TEST_KEY_1 = "test_key_1" +TEST_KEY_2 = "test_key_2" +TEST_COLLECTION = "test_collection" +DEFAULT_VALUE = {"obj_key": "obj_value"} +DEFAULT_TTL = 100 + + +class TestDefaultValueWrapper(BaseStoreTests): + @override + @pytest.fixture + def store(self, memory_store: MemoryStore) -> DefaultValueWrapper: + return DefaultValueWrapper(key_value=memory_store, default_value=DEFAULT_VALUE, default_ttl=DEFAULT_TTL) + + def test_default_value(self, store: BaseStore): + assert store.get(collection=TEST_COLLECTION, key=TEST_KEY_1) == DEFAULT_VALUE + assert store.ttl(collection=TEST_COLLECTION, key=TEST_KEY_1) == (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)) + assert store.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [DEFAULT_VALUE, DEFAULT_VALUE] + assert store.ttl_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [ + (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)), + (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)), + ] + + store.put(collection=TEST_COLLECTION, key=TEST_KEY_2, value={"key_2": "value_2"}, ttl=200) + assert store.get(collection=TEST_COLLECTION, key=TEST_KEY_2) == {"key_2": "value_2"} + assert store.ttl(collection=TEST_COLLECTION, key=TEST_KEY_2) == ({"key_2": "value_2"}, IsFloat(approx=200)) + assert store.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [DEFAULT_VALUE, {"key_2": "value_2"}] + assert store.ttl_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [ + (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)), + ({"key_2": "value_2"}, IsFloat(approx=200)), + ] + + @override + @pytest.mark.skip + def test_empty_get(self, store: BaseStore): ... + + @override + @pytest.mark.skip + def test_put_put_get_many_missing_one(self, store: BaseStore): ... + + @override + @pytest.mark.skip + def test_empty_ttl(self, store: BaseStore): ... + + @override + @pytest.mark.skip + def test_get_put_get(self, store: BaseStore): ... + + @override + @pytest.mark.skip + def test_get_put_get_delete_get(self, store: BaseStore): ... + + @override + @pytest.mark.skip + def test_put_get_delete_get(self, store: BaseStore): ... + + @override + @pytest.mark.skip + def test_put_many_get_get_delete_many_get_many(self, store: BaseStore): ... + + @override + @pytest.mark.skip + def test_put_many_get_many_delete_many_get_many(self, store: BaseStore): ... + + @override + @pytest.mark.skip + def test_get_put_get_put_delete_get(self, store: BaseStore): ... + + @override + @pytest.mark.skip + def test_put_many_delete_delete_get_many(self, store: BaseStore): ... + + @override + @pytest.mark.skip + def test_put_expired_get_none(self, store: BaseStore): ... + + @override + @pytest.mark.skip + def test_not_unbounded(self, store: BaseStore): ... + + @override + @pytest.mark.skip + def test_concurrent_operations(self, store: BaseStore): ...