diff --git a/key-value/key-value-aio/src/key_value/aio/stores/base.py b/key-value/key-value-aio/src/key_value/aio/stores/base.py index 074f7850..94752d75 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/base.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/base.py @@ -14,6 +14,7 @@ from key_value.shared.errors import StoreSetupError from key_value.shared.type_checking.bear_spray import bear_enforce from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.serialization import BasicSerializationAdapter, SerializationAdapter from key_value.shared.utils.time_to_live import prepare_entry_timestamps from typing_extensions import Self, override @@ -67,14 +68,23 @@ class BaseStore(AsyncKeyValueProtocol, ABC): _setup_collection_locks: defaultdict[str, Lock] _setup_collection_complete: defaultdict[str, bool] + _serialization_adapter: SerializationAdapter + _seed: FROZEN_SEED_DATA_TYPE default_collection: str - def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYPE | None = None) -> None: + def __init__( + self, + *, + serialization_adapter: SerializationAdapter | None = None, + default_collection: str | None = None, + seed: SEED_DATA_TYPE | None = None, + ) -> None: """Initialize the managed key-value store. Args: + serialization_adapter: The serialization adapter to use for the store. default_collection: The default collection to use if no collection is provided. Defaults to "default_collection". seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}. @@ -91,6 +101,8 @@ def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYP self.default_collection = default_collection or DEFAULT_COLLECTION_NAME + self._serialization_adapter = serialization_adapter or BasicSerializationAdapter() + if not hasattr(self, "_stable_api"): self._stable_api = False @@ -286,9 +298,9 @@ async def put(self, key: str, value: Mapping[str, Any], *, collection: str | Non collection = collection or self.default_collection await self.setup_collection(collection=collection) - created_at, ttl_seconds, expires_at = prepare_entry_timestamps(ttl=ttl) + created_at, _, expires_at = prepare_entry_timestamps(ttl=ttl) - managed_entry: ManagedEntry = ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at) + managed_entry: ManagedEntry = ManagedEntry(value=value, created_at=created_at, expires_at=expires_at) await self._put_managed_entry( collection=collection, @@ -316,9 +328,7 @@ async def put_many( created_at, ttl_seconds, expires_at = prepare_entry_timestamps(ttl=ttl) - managed_entries: list[ManagedEntry] = [ - ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at) for value in values - ] + managed_entries: list[ManagedEntry] = [ManagedEntry(value=value, created_at=created_at, expires_at=expires_at) for value in values] await self._put_managed_entries( collection=collection, diff --git a/key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py b/key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py index 5ff14b9d..4c3ee1ed 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py @@ -1,10 +1,10 @@ -import time from collections.abc import Callable +from datetime import timezone from pathlib import Path from typing import overload -from key_value.shared.utils.compound import compound_key -from 
key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.managed_entry import ManagedEntry, datetime +from key_value.shared.utils.serialization import BasicSerializationAdapter from typing_extensions import override from key_value.aio.stores.base import BaseContextManagerStore, BaseStore @@ -100,6 +100,7 @@ def default_disk_cache_factory(collection: str) -> Cache: self._cache = {} self._stable_api = True + self._serialization_adapter = BasicSerializationAdapter() super().__init__(default_collection=default_collection) @@ -109,18 +110,17 @@ async def _setup_collection(self, *, collection: str) -> None: @override async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None: - combo_key: str = compound_key(collection=collection, key=key) - expire_epoch: float - managed_entry_str, expire_epoch = self._cache[collection].get(key=combo_key, expire_time=True) # pyright: ignore[reportAny] + managed_entry_str, expire_epoch = self._cache[collection].get(key=key, expire_time=True) # pyright: ignore[reportAny] if not isinstance(managed_entry_str, str): return None - ttl = (expire_epoch - time.time()) if expire_epoch else None + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str) - managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str, ttl=ttl) + if expire_epoch: + managed_entry.expires_at = datetime.fromtimestamp(expire_epoch, tz=timezone.utc) return managed_entry @@ -132,15 +132,11 @@ async def _put_managed_entry( collection: str, managed_entry: ManagedEntry, ) -> None: - combo_key: str = compound_key(collection=collection, key=key) - - _ = self._cache[collection].set(key=combo_key, value=managed_entry.to_json(include_expiration=False), expire=managed_entry.ttl) + _ = self._cache[collection].set(key=key, value=self._serialization_adapter.dump_json(entry=managed_entry), expire=managed_entry.ttl) @override async def _delete_managed_entry(self, *, key: str, collection: str) -> bool: - combo_key: str = compound_key(collection=collection, key=key) - - return self._cache[collection].delete(key=combo_key, retry=True) + return self._cache[collection].delete(key=key, retry=True) def _sync_close(self) -> None: for cache in self._cache.values(): diff --git a/key-value/key-value-aio/src/key_value/aio/stores/disk/store.py b/key-value/key-value-aio/src/key_value/aio/stores/disk/store.py index 8ae36cf6..f9a00f0b 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/disk/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/disk/store.py @@ -1,4 +1,4 @@ -import time +from datetime import datetime, timezone from pathlib import Path from typing import overload @@ -90,9 +90,10 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry if not isinstance(managed_entry_str, str): return None - ttl = (expire_epoch - time.time()) if expire_epoch else None + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str) - managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str, ttl=ttl) + if expire_epoch: + managed_entry.expires_at = datetime.fromtimestamp(expire_epoch, tz=timezone.utc) return managed_entry @@ -106,7 +107,7 @@ async def _put_managed_entry( ) -> None: combo_key: str = compound_key(collection=collection, key=key) - _ = self._cache.set(key=combo_key, value=managed_entry.to_json(include_expiration=False), expire=managed_entry.ttl) + _ = self._cache.set(key=combo_key, 
value=self._serialization_adapter.dump_json(entry=managed_entry), expire=managed_entry.ttl) @override async def _delete_managed_entry(self, *, key: str, collection: str) -> bool: diff --git a/key-value/key-value-aio/src/key_value/aio/stores/dynamodb/store.py b/key-value/key-value-aio/src/key_value/aio/stores/dynamodb/store.py index e2c87fb9..3175eb5a 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/dynamodb/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/dynamodb/store.py @@ -1,3 +1,4 @@ +from datetime import datetime, timezone from types import TracebackType from typing import TYPE_CHECKING, Any, overload @@ -183,7 +184,7 @@ async def _setup(self) -> None: @override async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None: """Retrieve a managed entry from DynamoDB.""" - response = await self._connected_client.get_item( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + response = await self._connected_client.get_item( TableName=self._table_name, Key={ "collection": {"S": collection}, @@ -191,15 +192,23 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry }, ) - item = response.get("Item") # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + item = response.get("Item") if not item: return None - json_value = item.get("value", {}).get("S") # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + json_value = item.get("value", {}).get("S") if not json_value: return None - return ManagedEntry.from_json(json_str=json_value) # pyright: ignore[reportUnknownArgumentType] + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=json_value) + + expires_at_epoch = item.get("ttl", {}).get("N") + + # Our managed entry may carry a TTL, but the TTL in DynamoDB takes precedence. 
+ if expires_at_epoch: + managed_entry.expires_at = datetime.fromtimestamp(int(expires_at_epoch), tz=timezone.utc) + + return managed_entry @override async def _put_managed_entry( @@ -210,7 +219,7 @@ async def _put_managed_entry( managed_entry: ManagedEntry, ) -> None: """Store a managed entry in DynamoDB.""" - json_value = managed_entry.to_json() + json_value = self._serialization_adapter.dump_json(entry=managed_entry) item: dict[str, Any] = { "collection": {"S": collection}, @@ -219,9 +228,9 @@ async def _put_managed_entry( } # Add TTL if present - if managed_entry.ttl is not None and managed_entry.created_at is not None: + if managed_entry.expires_at is not None: # DynamoDB TTL expects a Unix timestamp - ttl_timestamp = int(managed_entry.created_at.timestamp() + managed_entry.ttl) + ttl_timestamp = int(managed_entry.expires_at.timestamp()) item["ttl"] = {"N": str(ttl_timestamp)} await self._connected_client.put_item( # pyright: ignore[reportUnknownMemberType] diff --git a/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py b/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py index ad58aca0..fbe7bfc2 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py @@ -6,14 +6,15 @@ from elastic_transport import ObjectApiResponse from elastic_transport import SerializationError as ElasticsearchSerializationError from key_value.shared.errors import DeserializationError, SerializationError -from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict +from key_value.shared.utils.managed_entry import ManagedEntry from key_value.shared.utils.sanitize import ( ALPHANUMERIC_CHARACTERS, LOWERCASE_ALPHABET, NUMBERS, sanitize_string, ) -from key_value.shared.utils.time_to_live import now_as_epoch, try_parse_datetime_str +from key_value.shared.utils.serialization import SerializationAdapter +from key_value.shared.utils.time_to_live import now_as_epoch from typing_extensions import override from key_value.aio.stores.base import ( @@ -84,52 +85,50 @@ ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "." -def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry, *, native_storage: bool = False) -> dict[str, Any]: - document: dict[str, Any] = {"collection": collection, "key": key, "value": {}} +class ElasticsearchSerializationAdapter(SerializationAdapter): + """Adapter for Elasticsearch with support for native and string storage modes.""" - # Store in appropriate field based on mode - if native_storage: - document["value"]["flattened"] = managed_entry.value_as_dict - else: - document["value"]["string"] = managed_entry.value_as_json + _native_storage: bool - if managed_entry.created_at: - document["created_at"] = managed_entry.created_at.isoformat() - if managed_entry.expires_at: - document["expires_at"] = managed_entry.expires_at.isoformat() + def __init__(self, *, native_storage: bool = True) -> None: + """Initialize the Elasticsearch adapter. - return document + Args: + native_storage: If True (default), store values as flattened dicts. + If False, store values as JSON strings. 
+ """ + super().__init__() + self._native_storage = native_storage + self._date_format = "isoformat" + self._value_format = "dict" if native_storage else "string" -def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry: - value: dict[str, Any] = {} + @override + def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: + value = data.pop("value") + + data["value"] = {} - raw_value = source.get("value") + if self._native_storage: + data["value"]["flattened"] = value + else: + data["value"]["string"] = value - # Try flattened field first, fall back to string field - if not raw_value or not isinstance(raw_value, dict): - msg = "Value field not found or invalid type" - raise DeserializationError(msg) + return data - if value_flattened := raw_value.get("flattened"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - value = verify_dict(obj=value_flattened) - elif value_str := raw_value.get("string"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - if not isinstance(value_str, str): - msg = "Value in `value` field is not a string" - raise DeserializationError(msg) - value = load_from_json(value_str) - else: - msg = "Value field not found or invalid type" - raise DeserializationError(msg) + @override + def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: + value = data.pop("value") - created_at: datetime | None = try_parse_datetime_str(value=source.get("created_at")) - expires_at: datetime | None = try_parse_datetime_str(value=source.get("expires_at")) + if "flattened" in value: + data["value"] = value["flattened"] + elif "string" in value: + data["value"] = value["string"] + else: + msg = "Value field not found in Elasticsearch document" + raise DeserializationError(message=msg) - return ManagedEntry( - value=value, - created_at=created_at, - expires_at=expires_at, - ) + return data class ElasticsearchStore( @@ -145,6 +144,8 @@ class ElasticsearchStore( _native_storage: bool + _adapter: SerializationAdapter + @overload def __init__( self, @@ -208,6 +209,7 @@ def __init__( self._index_prefix = index_prefix self._native_storage = native_storage self._is_serverless = False + self._adapter = ElasticsearchSerializationAdapter(native_storage=native_storage) super().__init__(default_collection=default_collection) @@ -260,7 +262,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry return None try: - return source_to_managed_entry(source=source) + return self._adapter.load_dict(data=source) except DeserializationError: return None @@ -293,7 +295,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> continue try: - entries_by_id[doc_id] = source_to_managed_entry(source=source) + entries_by_id[doc_id] = self._adapter.load_dict(data=source) except DeserializationError as e: logger.error( "Failed to deserialize Elasticsearch document in batch operation", @@ -324,9 +326,7 @@ async def _put_managed_entry( index_name: str = self._sanitize_index_name(collection=collection) document_id: str = self._sanitize_document_id(key=key) - document: dict[str, Any] = managed_entry_to_document( - collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage - ) + document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry) try: _ = await self._client.index( @@ -364,11 +364,10 @@ async def _put_managed_entries( index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id) - document: dict[str, Any] = 
managed_entry_to_document( - collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage - ) + document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry) operations.extend([index_action, document]) + try: _ = await self._client.bulk(operations=operations, refresh=self._should_refresh_on_put) # pyright: ignore[reportUnknownMemberType] except ElasticsearchSerializationError as e: diff --git a/key-value/key-value-aio/src/key_value/aio/stores/keyring/store.py b/key-value/key-value-aio/src/key_value/aio/stores/keyring/store.py index cec5b268..3967d3d5 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/keyring/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/keyring/store.py @@ -79,7 +79,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry if json_str is None: return None - return ManagedEntry.from_json(json_str=json_str) + return self._serialization_adapter.load_json(json_str=json_str) @override async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: @@ -88,7 +88,7 @@ async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: combo_key: str = compound_key(collection=sanitized_collection, key=sanitized_key) - json_str: str = managed_entry.to_json() + json_str: str = self._serialization_adapter.dump_json(entry=managed_entry) keyring.set_password(service_name=self._service_name, username=combo_key, password=json_str) diff --git a/key-value/key-value-aio/src/key_value/aio/stores/memcached/store.py b/key-value/key-value-aio/src/key_value/aio/stores/memcached/store.py index 5197c145..55069d96 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/memcached/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/memcached/store.py @@ -65,7 +65,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry json_str: str = raw_value.decode(encoding="utf-8") - return ManagedEntry.from_json(json_str=json_str) + return self._serialization_adapter.load_json(json_str=json_str) @override async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]: @@ -82,7 +82,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> for raw_value in raw_values: if isinstance(raw_value, (bytes, bytearray)): json_str: str = raw_value.decode(encoding="utf-8") - entries.append(ManagedEntry.from_json(json_str=json_str)) + entries.append(self._serialization_adapter.load_json(json_str=json_str)) else: entries.append(None) @@ -106,7 +106,7 @@ async def _put_managed_entry( else: exptime = max(int(managed_entry.ttl), 1) - json_value: str = managed_entry.to_json() + json_value: str = self._serialization_adapter.dump_json(entry=managed_entry) _ = await self._client.set( key=combo_key.encode(encoding="utf-8"), diff --git a/key-value/key-value-aio/src/key_value/aio/stores/memory/store.py b/key-value/key-value-aio/src/key_value/aio/stores/memory/store.py index 07caf85a..05f7ad0c 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/memory/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/memory/store.py @@ -4,7 +4,8 @@ from typing import Any from key_value.shared.utils.managed_entry import ManagedEntry -from typing_extensions import Self, override +from key_value.shared.utils.serialization import BasicSerializationAdapter +from typing_extensions import override from key_value.aio.stores.base import ( SEED_DATA_TYPE, @@ -29,16 +30,6 
@@ class MemoryCacheEntry: expires_at: datetime | None - @classmethod - def from_managed_entry(cls, managed_entry: ManagedEntry) -> Self: - return cls( - json_str=managed_entry.to_json(), - expires_at=managed_entry.expires_at, - ) - - def to_managed_entry(self) -> ManagedEntry: - return ManagedEntry.from_json(json_str=self.json_str) - def _memory_cache_ttu(_key: Any, value: MemoryCacheEntry, _now: float) -> float: """Calculate time-to-use for cache entries based on their expiration time.""" @@ -53,8 +44,6 @@ def _memory_cache_getsizeof(value: MemoryCacheEntry) -> int: # noqa: ARG001 return 1 -DEFAULT_MAX_ENTRIES_PER_COLLECTION = 10000 - DEFAULT_PAGE_SIZE = 10000 PAGE_LIMIT = 10000 @@ -62,30 +51,33 @@ def _memory_cache_getsizeof(value: MemoryCacheEntry) -> int: # noqa: ARG001 class MemoryCollection: _cache: TLRUCache[str, MemoryCacheEntry] - def __init__(self, max_entries: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION): + def __init__(self, max_entries: int | None = None): """Initialize a fixed-size in-memory collection. Args: - max_entries: The maximum number of entries per collection. Defaults to 10,000 entries. + max_entries: The maximum number of entries per collection. Defaults to no limit. """ self._cache = TLRUCache[str, MemoryCacheEntry]( - maxsize=max_entries, + maxsize=max_entries if max_entries is not None else sys.maxsize, ttu=_memory_cache_ttu, getsizeof=_memory_cache_getsizeof, ) + self._serialization_adapter = BasicSerializationAdapter() + def get(self, key: str) -> ManagedEntry | None: managed_entry_str: MemoryCacheEntry | None = self._cache.get(key) if managed_entry_str is None: return None - managed_entry: ManagedEntry = managed_entry_str.to_managed_entry() + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str.json_str) return managed_entry def put(self, key: str, value: ManagedEntry) -> None: - self._cache[key] = MemoryCacheEntry.from_managed_entry(managed_entry=value) + json_str: str = self._serialization_adapter.dump_json(entry=value) + self._cache[key] = MemoryCacheEntry(json_str=json_str, expires_at=value.expires_at) def delete(self, key: str) -> bool: return self._cache.pop(key, None) is not None @@ -105,21 +97,21 @@ class MemoryStore(BaseDestroyStore, BaseDestroyCollectionStore, BaseEnumerateCol def __init__( self, *, - max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION, + max_entries_per_collection: int | None = None, default_collection: str | None = None, seed: SEED_DATA_TYPE | None = None, ): """Initialize a fixed-size in-memory store. Args: - max_entries_per_collection: The maximum number of entries per collection. Defaults to 10000. + max_entries_per_collection: The maximum number of entries per collection. Defaults to no limit. default_collection: The default collection to use if no collection is provided. seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}. Each value must be a mapping (dict) that will be stored as the entry's value. Seeding occurs lazily when each collection is first accessed. 
""" - self.max_entries_per_collection = max_entries_per_collection + self.max_entries_per_collection = max_entries_per_collection if max_entries_per_collection is not None else sys.maxsize self._cache = {} diff --git a/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py b/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py index 6db89992..d06d19e1 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py @@ -1,17 +1,18 @@ from collections.abc import Sequence -from datetime import datetime +from datetime import datetime, timezone from typing import Any, overload -from key_value.shared.errors import DeserializationError -from key_value.shared.utils.managed_entry import ManagedEntry, verify_dict +from bson.errors import InvalidDocument +from key_value.shared.errors import DeserializationError, SerializationError +from key_value.shared.utils.managed_entry import ManagedEntry from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, sanitize_string -from key_value.shared.utils.time_to_live import timezone +from key_value.shared.utils.serialization import SerializationAdapter from typing_extensions import Self, override from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyCollectionStore, BaseEnumerateCollectionsStore, BaseStore try: - from pymongo import AsyncMongoClient + from pymongo import AsyncMongoClient, UpdateOne from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.database import AsyncDatabase from pymongo.results import DeleteResult # noqa: TC002 @@ -35,95 +36,56 @@ COLLECTION_ALLOWED_CHARACTERS = ALPHANUMERIC_CHARACTERS + "_" -def document_to_managed_entry(document: dict[str, Any]) -> ManagedEntry: - """Convert a MongoDB document back to a ManagedEntry. +class MongoDBSerializationAdapter(SerializationAdapter): + """Adapter for MongoDB with support for native and string storage modes.""" - This function deserializes a MongoDB document (created by `managed_entry_to_document`) back to a - ManagedEntry object. It supports both native BSON storage (dict in value.object field) and legacy - JSON string storage (string in value.string field) for migration support. - - Args: - document: The MongoDB document to convert. - - Returns: - A ManagedEntry object reconstructed from the document. 
- """ - if not (value_field := document.get("value")): - msg = "Value field not found" - raise DeserializationError(msg) - - if not isinstance(value_field, dict): - msg = "Expected `value` field to be an object" - raise DeserializationError(msg) - - value_holder: dict[str, Any] = verify_dict(obj=value_field) - - data: dict[str, Any] = {} - - # The Value field is an object with two possible fields: `object` and `string` - # - `object`: The value is a native BSON dict - # - `string`: The value is a JSON string - # Mongo stores datetimes without timezones as UTC so we mark them as UTC - - if created_at_datetime := document.get("created_at"): - if not isinstance(created_at_datetime, datetime): - msg = "Expected `created_at` field to be a datetime" - raise DeserializationError(msg) - data["created_at"] = created_at_datetime.replace(tzinfo=timezone.utc) - - if expires_at_datetime := document.get("expires_at"): - if not isinstance(expires_at_datetime, datetime): - msg = "Expected `expires_at` field to be a datetime" - raise DeserializationError(msg) - data["expires_at"] = expires_at_datetime.replace(tzinfo=timezone.utc) - - if value_object := value_holder.get("object"): - return ManagedEntry.from_dict(data={"value": value_object, **data}) - - if value_string := value_holder.get("string"): - return ManagedEntry.from_dict(data={"value": value_string, **data}, stringified_value=True) + _native_storage: bool - msg = "Expected `value` field to be an object with `object` or `string` subfield" - raise DeserializationError(msg) + def __init__(self, *, native_storage: bool = True) -> None: + """Initialize the MongoDB adapter.""" + super().__init__() + self._native_storage = native_storage + self._date_format = "datetime" + self._value_format = "dict" if native_storage else "string" -def managed_entry_to_document(key: str, managed_entry: ManagedEntry, *, native_storage: bool = True) -> dict[str, Any]: - """Convert a ManagedEntry to a MongoDB document for storage. + @override + def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: + value = data.pop("value") - This function serializes a ManagedEntry to a MongoDB document format, including the key and all - metadata (TTL, creation, and expiration timestamps). The value storage format depends on the - native_storage parameter. + data["value"] = {} - Args: - key: The key associated with this entry. - managed_entry: The ManagedEntry to serialize. - native_storage: If True (default), store value as native BSON dict in value.object field. - If False, store as JSON string in value.string field for backward compatibility. + if self._native_storage: + data["value"]["object"] = value + else: + data["value"]["string"] = value - Returns: - A MongoDB document dict containing the key, value, and all metadata. - """ - document: dict[str, Any] = {"key": key, "value": {}} + return data - # We convert to JSON even if we don't need to, this ensures that the value we were provided - # can be serialized to JSON which helps ensure compatibility across stores. For example, - # Mongo can natively handle datetime objects which other stores cannot, if we don't convert to JSON, - # then using py-key-value with Mongo will return different values than if we used another store. 
- json_str = managed_entry.value_as_json + @override + def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: + value = data.pop("value") - # Store in appropriate field based on mode - if native_storage: - document["value"]["object"] = managed_entry.value_as_dict - else: - document["value"]["string"] = json_str + if "object" in value: + data["value"] = value["object"] + elif "string" in value: + data["value"] = value["string"] + else: + msg = "Value field not found in MongoDB document" + raise DeserializationError(message=msg) - # Add metadata fields - if managed_entry.created_at: - document["created_at"] = managed_entry.created_at - if managed_entry.expires_at: - document["expires_at"] = managed_entry.expires_at + if date_created := data.get("created_at"): + if not isinstance(date_created, datetime): + msg = "Expected `created_at` field to be a datetime" + raise DeserializationError(message=msg) + data["created_at"] = date_created.replace(tzinfo=timezone.utc) + if date_expires := data.get("expires_at"): + if not isinstance(date_expires, datetime): + msg = "Expected `expires_at` field to be a datetime" + raise DeserializationError(message=msg) + data["expires_at"] = date_expires.replace(tzinfo=timezone.utc) - return document + return data class MongoDBStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, BaseContextManagerStore, BaseStore): @@ -132,7 +94,7 @@ class MongoDBStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, Ba _client: AsyncMongoClient[dict[str, Any]] _db: AsyncDatabase[dict[str, Any]] _collections_by_name: dict[str, AsyncCollection[dict[str, Any]]] - _native_storage: bool + _adapter: SerializationAdapter @overload def __init__( @@ -210,7 +172,7 @@ def __init__( self._db = self._client[db_name] self._collections_by_name = {} - self._native_storage = native_storage + self._adapter = MongoDBSerializationAdapter(native_storage=native_storage) super().__init__(default_collection=default_collection) @@ -267,7 +229,10 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry sanitized_collection = self._sanitize_collection_name(collection=collection) if doc := await self._collections_by_name[sanitized_collection].find_one(filter={"key": key}): - return document_to_managed_entry(document=doc) + try: + return self._adapter.load_dict(data=doc) + except DeserializationError: + return None return None @@ -285,7 +250,10 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> async for doc in cursor: if key := doc.get("key"): - managed_entries_by_key[key] = document_to_managed_entry(document=doc) + try: + managed_entries_by_key[key] = self._adapter.load_dict(data=doc) + except DeserializationError: + managed_entries_by_key[key] = None return [managed_entries_by_key[key] for key in keys] @@ -297,15 +265,26 @@ async def _put_managed_entry( collection: str, managed_entry: ManagedEntry, ) -> None: - mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry, native_storage=self._native_storage) + mongo_doc = self._adapter.dump_dict(entry=managed_entry) sanitized_collection = self._sanitize_collection_name(collection=collection) - _ = await self._collections_by_name[sanitized_collection].update_one( - filter={"key": key}, - update={"$set": mongo_doc}, - upsert=True, - ) + try: + # Ensure that the value is serializable to JSON + _ = managed_entry.value_as_json + _ = await self._collections_by_name[sanitized_collection].update_one( + filter={"key": key}, + update={ + "$set": { 
+ "key": key, + **mongo_doc, + } + }, + upsert=True, + ) + except InvalidDocument as e: + msg = f"Failed to update MongoDB document: {e}" + raise SerializationError(message=msg) from e @override async def _put_managed_entries( @@ -323,17 +302,20 @@ async def _put_managed_entries( sanitized_collection = self._sanitize_collection_name(collection=collection) - # Use bulk_write for efficient batch operations - from pymongo import UpdateOne - operations: list[UpdateOne] = [] for key, managed_entry in zip(keys, managed_entries, strict=True): - mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry, native_storage=self._native_storage) + mongo_doc = self._adapter.dump_dict(entry=managed_entry) operations.append( UpdateOne( filter={"key": key}, - update={"$set": mongo_doc}, + update={ + "$set": { + "collection": collection, + "key": key, + **mongo_doc, + } + }, upsert=True, ) ) diff --git a/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py b/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py index c2a21efc..13e60655 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py @@ -3,9 +3,11 @@ from typing import Any, overload from urllib.parse import urlparse +from key_value.shared.errors import DeserializationError from key_value.shared.type_checking.bear_spray import bear_spray from key_value.shared.utils.compound import compound_key, get_keys_from_compound_keys from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.serialization import BasicSerializationAdapter, SerializationAdapter from typing_extensions import override from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyStore, BaseEnumerateKeysStore, BaseStore @@ -20,41 +22,11 @@ PAGE_LIMIT = 10000 -def managed_entry_to_json(managed_entry: ManagedEntry) -> str: - """Convert a ManagedEntry to a JSON string for Redis storage. - - This function serializes a ManagedEntry to JSON format including all metadata (TTL, creation, - and expiration timestamps). The serialization is designed to preserve all entry information - for round-trip conversion back to a ManagedEntry. - - Args: - managed_entry: The ManagedEntry to serialize. - - Returns: - A JSON string representation of the ManagedEntry with full metadata. - """ - return managed_entry.to_json(include_metadata=True, include_expiration=True, include_creation=True) - - -def json_to_managed_entry(json_str: str) -> ManagedEntry: - """Convert a JSON string from Redis storage back to a ManagedEntry. - - This function deserializes a JSON string (created by `managed_entry_to_json`) back to a - ManagedEntry object, preserving all metadata including TTL, creation, and expiration timestamps. - - Args: - json_str: The JSON string to deserialize. - - Returns: - A ManagedEntry object reconstructed from the JSON string. - """ - return ManagedEntry.from_json(json_str=json_str, includes_metadata=True) - - class RedisStore(BaseDestroyStore, BaseEnumerateKeysStore, BaseContextManagerStore, BaseStore): """Redis-based key-value store.""" _client: Redis + _adapter: SerializationAdapter @overload def __init__(self, *, client: Redis, default_collection: str | None = None) -> None: ... 
@@ -111,6 +83,7 @@ def __init__( ) self._stable_api = True + self._adapter = BasicSerializationAdapter(date_format="isoformat", value_format="dict") super().__init__(default_collection=default_collection) @@ -123,9 +96,10 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry if not isinstance(redis_response, str): return None - managed_entry: ManagedEntry = json_to_managed_entry(json_str=redis_response) - - return managed_entry + try: + return self._adapter.load_json(json_str=redis_response) + except DeserializationError: + return None @override async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]: @@ -139,7 +113,10 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> entries: list[ManagedEntry | None] = [] for redis_response in redis_responses: if isinstance(redis_response, str): - entries.append(json_to_managed_entry(json_str=redis_response)) + try: + entries.append(self._adapter.load_json(json_str=redis_response)) + except DeserializationError: + entries.append(None) else: entries.append(None) @@ -155,7 +132,7 @@ async def _put_managed_entry( ) -> None: combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry_to_json(managed_entry=managed_entry) + json_value: str = self._adapter.dump_json(entry=managed_entry) if managed_entry.ttl is not None: # Redis does not support <= 0 TTLs @@ -181,10 +158,10 @@ async def _put_managed_entries( if ttl is None: # If there is no TTL, we can just do a simple mset - mapping: dict[str, str] = { - compound_key(collection=collection, key=key): managed_entry_to_json(managed_entry=managed_entry) - for key, managed_entry in zip(keys, managed_entries, strict=True) - } + mapping: dict[str, str] = {} + for key, managed_entry in zip(keys, managed_entries, strict=True): + json_value = self._adapter.dump_json(entry=managed_entry) + mapping[compound_key(collection=collection, key=key)] = json_value await self._client.mset(mapping=mapping) @@ -198,7 +175,7 @@ async def _put_managed_entries( for key, managed_entry in zip(keys, managed_entries, strict=True): combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry_to_json(managed_entry=managed_entry) + json_value = self._adapter.dump_json(entry=managed_entry) pipeline.setex(name=combo_key, time=ttl_seconds, value=json_value) diff --git a/key-value/key-value-aio/src/key_value/aio/stores/rocksdb/store.py b/key-value/key-value-aio/src/key_value/aio/stores/rocksdb/store.py index 61b98f68..2c0829d5 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/rocksdb/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/rocksdb/store.py @@ -109,7 +109,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry return None managed_entry_str: str = value.decode("utf-8") - managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str) + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str) return managed_entry @@ -124,7 +124,7 @@ async def _put_managed_entry( self._fail_on_closed_store() combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry.to_json() + json_value: str = self._serialization_adapter.dump_json(entry=managed_entry) self._db[combo_key] = json_value.encode("utf-8") @@ -147,7 +147,7 @@ async def _put_managed_entries( batch = WriteBatch() for key, managed_entry in zip(keys, managed_entries, 
strict=True): combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry.to_json() + json_value: str = self._serialization_adapter.dump_json(entry=managed_entry) batch.put(combo_key, json_value.encode("utf-8")) self._db.write(batch) diff --git a/key-value/key-value-aio/src/key_value/aio/stores/simple/store.py b/key-value/key-value-aio/src/key_value/aio/stores/simple/store.py index 98ce2df9..b0b155a5 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/simple/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/simple/store.py @@ -1,10 +1,11 @@ +import sys from collections import defaultdict from dataclasses import dataclass from datetime import datetime from key_value.shared.utils.compound import compound_key, get_collections_from_compound_keys, get_keys_from_compound_keys -from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json -from key_value.shared.utils.time_to_live import seconds_to +from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.serialization import BasicSerializationAdapter from typing_extensions import override from key_value.aio.stores.base import ( @@ -24,22 +25,6 @@ class SimpleStoreEntry: created_at: datetime | None expires_at: datetime | None - @property - def current_ttl(self) -> float | None: - if self.expires_at is None: - return None - - return seconds_to(datetime=self.expires_at) - - def to_managed_entry(self) -> ManagedEntry: - managed_entry: ManagedEntry = ManagedEntry( - value=load_from_json(json_str=self.json_str), - expires_at=self.expires_at, - created_at=self.created_at, - ) - - return managed_entry - class SimpleStore(BaseEnumerateCollectionsStore, BaseEnumerateKeysStore, BaseDestroyStore, BaseStore): """Simple managed dictionary-based key-value store for testing and development.""" @@ -48,18 +33,20 @@ class SimpleStore(BaseEnumerateCollectionsStore, BaseEnumerateKeysStore, BaseDes _data: dict[str, SimpleStoreEntry] - def __init__(self, max_entries: int = DEFAULT_SIMPLE_STORE_MAX_ENTRIES, default_collection: str | None = None): + def __init__(self, max_entries: int | None = None, default_collection: str | None = None): """Initialize the simple store. Args: - max_entries: The maximum number of entries to store. Defaults to 10000. + max_entries: The maximum number of entries to store. Defaults to no limit. default_collection: The default collection to use if no collection is provided. 
""" - self.max_entries = max_entries + self.max_entries = max_entries if max_entries is not None else sys.maxsize self._data = defaultdict[str, SimpleStoreEntry]() + self._serialization_adapter = BasicSerializationAdapter(date_format=None) + super().__init__(default_collection=default_collection) @override @@ -71,7 +58,10 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry if store_entry is None: return None - return store_entry.to_managed_entry() + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=store_entry.json_str) + managed_entry.expires_at = store_entry.expires_at + managed_entry.created_at = store_entry.created_at + return managed_entry @override async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: @@ -81,7 +71,9 @@ async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: _ = self._data.pop(next(iter(self._data))) self._data[combo_key] = SimpleStoreEntry( - json_str=managed_entry.to_json(include_metadata=False), expires_at=managed_entry.expires_at, created_at=managed_entry.created_at + json_str=self._serialization_adapter.dump_json(entry=managed_entry), + expires_at=managed_entry.expires_at, + created_at=managed_entry.created_at, ) @override diff --git a/key-value/key-value-aio/src/key_value/aio/stores/valkey/store.py b/key-value/key-value-aio/src/key_value/aio/stores/valkey/store.py index aa2530a7..71e02cfc 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/valkey/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/valkey/store.py @@ -95,7 +95,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry decoded_response: str = response.decode("utf-8") - return ManagedEntry.from_json(json_str=decoded_response) + return self._serialization_adapter.load_json(json_str=decoded_response) @override async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]: @@ -110,7 +110,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> for response in responses: if isinstance(response, bytes): decoded_response: str = response.decode("utf-8") - entries.append(ManagedEntry.from_json(json_str=decoded_response)) + entries.append(self._serialization_adapter.load_json(json_str=decoded_response)) else: entries.append(None) @@ -126,7 +126,7 @@ async def _put_managed_entry( ) -> None: combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry.to_json() + json_value: str = self._serialization_adapter.dump_json(entry=managed_entry) expiry: ExpirySet | None = ExpirySet(expiry_type=ExpiryType.SEC, value=int(managed_entry.ttl)) if managed_entry.ttl else None diff --git a/key-value/key-value-aio/src/key_value/aio/stores/vault/store.py b/key-value/key-value-aio/src/key_value/aio/stores/vault/store.py index cfc86afe..009e91b2 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/vault/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/vault/store.py @@ -86,7 +86,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry combo_key: str = compound_key(collection=collection, key=key) try: - response = self._kv_v2.read_secret(path=combo_key, mount_point=self._mount_point) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] + response = self._kv_v2.read_secret(path=combo_key, mount_point=self._mount_point, raise_on_deleted_version=True) # pyright: 
ignore[reportUnknownMemberType,reportUnknownVariableType] except InvalidPath: return None except Exception: @@ -102,13 +102,13 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry return None json_str: str = secret_data["value"] # pyright: ignore[reportUnknownVariableType] - return ManagedEntry.from_json(json_str=json_str) # pyright: ignore[reportUnknownArgumentType] + return self._serialization_adapter.load_json(json_str=json_str) # pyright: ignore[reportUnknownArgumentType] @override async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: combo_key: str = compound_key(collection=collection, key=key) - json_str: str = managed_entry.to_json() + json_str: str = self._serialization_adapter.dump_json(entry=managed_entry) # Store the JSON string in a 'value' field secret_data = {"value": json_str} diff --git a/key-value/key-value-aio/src/key_value/aio/stores/windows_registry/store.py b/key-value/key-value-aio/src/key_value/aio/stores/windows_registry/store.py index 1ce3dcc5..1a3c4add 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/windows_registry/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/windows_registry/store.py @@ -88,14 +88,14 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry if not (json_str := get_reg_sz_value(hive=self._hive, sub_key=registry_path, value_name=sanitized_key)): return None - return ManagedEntry.from_json(json_str=json_str) + return self._serialization_adapter.load_json(json_str=json_str) @override async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: sanitized_key = self._sanitize_key(key=key) registry_path = self._get_registry_path(collection=collection) - json_str: str = managed_entry.to_json() + json_str: str = self._serialization_adapter.dump_json(entry=managed_entry) set_reg_sz_value(hive=self._hive, sub_key=registry_path, value_name=sanitized_key, value=json_str) diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/compression/wrapper.py b/key-value/key-value-aio/src/key_value/aio/wrappers/compression/wrapper.py index 8db9f799..83b58ff8 100644 --- a/key-value/key-value-aio/src/key_value/aio/wrappers/compression/wrapper.py +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/compression/wrapper.py @@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence from typing import Any, SupportsFloat -from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.managed_entry import estimate_serialized_size from typing_extensions import override from key_value.aio.protocols.key_value import AsyncKeyValue @@ -56,7 +56,7 @@ def _should_compress(self, value: dict[str, Any]) -> bool: return False # Check size - item_size: int = len(ManagedEntry(value=value).to_json()) + item_size: int = estimate_serialized_size(value=value) return item_size >= self.min_size_to_compress def _compress_value(self, value: dict[str, Any]) -> dict[str, Any]: diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/limit_size/wrapper.py b/key-value/key-value-aio/src/key_value/aio/wrappers/limit_size/wrapper.py index c537f949..9cea5d0c 100644 --- a/key-value/key-value-aio/src/key_value/aio/wrappers/limit_size/wrapper.py +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/limit_size/wrapper.py @@ -2,7 +2,7 @@ from typing import Any, SupportsFloat from key_value.shared.errors.wrappers.limit_size import EntryTooLargeError, EntryTooSmallError -from 
key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.managed_entry import estimate_serialized_size from typing_extensions import override from key_value.aio.protocols.key_value import AsyncKeyValue @@ -65,7 +65,7 @@ def _within_size_limit(self, value: dict[str, Any], *, collection: str | None = EntryTooLargeError: If raise_on_too_large is True and the value exceeds max_size. """ - item_size: int = len(ManagedEntry(value=value).to_json()) + item_size: int = estimate_serialized_size(value=value) if self.min_size is not None and item_size < self.min_size: if self.raise_on_too_small: diff --git a/key-value/key-value-aio/tests/stores/base.py b/key-value/key-value-aio/tests/stores/base.py index 2ef67ff4..43177ee2 100644 --- a/key-value/key-value-aio/tests/stores/base.py +++ b/key-value/key-value-aio/tests/stores/base.py @@ -174,7 +174,7 @@ async def test_put_ttl_get_ttl(self, store: BaseStore): assert value == {"test": "test"} assert ttl is not None - assert ttl == IsFloat(approx=100) + assert ttl == IsFloat(approx=100, delta=2), f"TTL should be ~100, but is {ttl}" async def test_negative_ttl(self, store: BaseStore): """Tests that a negative ttl will return None when getting the key.""" diff --git a/key-value/key-value-aio/tests/stores/disk/test_disk.py b/key-value/key-value-aio/tests/stores/disk/test_disk.py index 2aaf7ceb..e2c3453c 100644 --- a/key-value/key-value-aio/tests/stores/disk/test_disk.py +++ b/key-value/key-value-aio/tests/stores/disk/test_disk.py @@ -1,7 +1,11 @@ +import json import tempfile from collections.abc import AsyncGenerator import pytest +from dirty_equals import IsDatetime +from diskcache.core import Cache +from inline_snapshot import snapshot from typing_extensions import override from key_value.aio.stores.disk import DiskStore @@ -22,3 +26,22 @@ async def store(self, disk_store: DiskStore) -> DiskStore: disk_store._cache.clear() # pyright: ignore[reportPrivateUsage] return disk_store + + @pytest.fixture + async def disk_cache(self, disk_store: DiskStore) -> Cache: + return disk_store._cache # pyright: ignore[reportPrivateUsage] + + async def test_value_stored(self, store: DiskStore, disk_cache: Cache): + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + value = disk_cache.get(key="test::test_key") + value_as_dict = json.loads(value) + assert value_as_dict == snapshot({"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}}) + + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = disk_cache.get(key="test::test_key") + value_as_dict = json.loads(value) + assert value_as_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-aio/tests/stores/disk/test_multi_disk.py b/key-value/key-value-aio/tests/stores/disk/test_multi_disk.py index 09d226e3..832eeccb 100644 --- a/key-value/key-value-aio/tests/stores/disk/test_multi_disk.py +++ b/key-value/key-value-aio/tests/stores/disk/test_multi_disk.py @@ -1,13 +1,20 @@ +import json import tempfile from collections.abc import AsyncGenerator from pathlib import Path +from typing import TYPE_CHECKING import pytest +from dirty_equals import IsDatetime +from inline_snapshot import snapshot from typing_extensions import override from key_value.aio.stores.disk.multi_store import MultiDiskStore from tests.stores.base import BaseStoreTests, 
ContextManagerStoreTestMixin +if TYPE_CHECKING: + from diskcache.core import Cache + TEST_SIZE_LIMIT = 100 * 1024 # 100KB @@ -24,3 +31,24 @@ async def store(self, multi_disk_store: MultiDiskStore) -> MultiDiskStore: multi_disk_store._cache[collection].clear() # pyright: ignore[reportPrivateUsage] return multi_disk_store + + async def test_value_stored(self, store: MultiDiskStore): + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + disk_cache: Cache = store._cache["test"] # pyright: ignore[reportPrivateUsage] + + value = disk_cache.get(key="test_key") + value_as_dict = json.loads(value) + assert value_as_dict == snapshot( + { + "value": {"name": "Alice", "age": 30}, + "created_at": IsDatetime(iso_string=True), + } + ) + + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = disk_cache.get(key="test_key") + value_as_dict = json.loads(value) + assert value_as_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-aio/tests/stores/dynamodb/test_dynamodb.py b/key-value/key-value-aio/tests/stores/dynamodb/test_dynamodb.py index e0af2a47..d29733e2 100644 --- a/key-value/key-value-aio/tests/stores/dynamodb/test_dynamodb.py +++ b/key-value/key-value-aio/tests/stores/dynamodb/test_dynamodb.py @@ -1,8 +1,15 @@ import contextlib +import json from collections.abc import AsyncGenerator +from datetime import datetime, timezone +from typing import Any import pytest +from dirty_equals import IsDatetime +from inline_snapshot import snapshot from key_value.shared.stores.wait import async_wait_for_true +from types_aiobotocore_dynamodb.client import DynamoDBClient +from types_aiobotocore_dynamodb.type_defs import GetItemOutputTypeDef from typing_extensions import override from key_value.aio.stores.base import BaseStore @@ -48,7 +55,16 @@ class DynamoDBFailedToStartError(Exception): pass +def get_value_from_response(response: GetItemOutputTypeDef) -> dict[str, Any]: + return json.loads(response.get("Item", {}).get("value", {}).get("S", {})) # pyright: ignore[reportArgumentType] + + +def get_dynamo_client_from_store(store: DynamoDBStore) -> DynamoDBClient: + return store._connected_client # pyright: ignore[reportPrivateUsage] + + @pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not available") +@pytest.mark.filterwarnings("ignore:A configured store is unstable and may change in a backwards incompatible way. Use at your own risk.") class TestDynamoDBStore(ContextManagerStoreTestMixin, BaseStoreTests): @pytest.fixture(autouse=True, scope="session", params=DYNAMODB_VERSIONS_TO_TEST) async def setup_dynamodb(self, request: pytest.FixtureRequest) -> AsyncGenerator[None, None]: @@ -101,3 +117,30 @@ async def dynamodb_store(self, store: DynamoDBStore) -> DynamoDBStore: @pytest.mark.skip(reason="Distributed Caches are unbounded") @override async def test_not_unbounded(self, store: BaseStore): ... 
+
+    async def test_value_stored(self, store: DynamoDBStore):
+        await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30})
+
+        response = await get_dynamo_client_from_store(store=store).get_item(
+            TableName=DYNAMODB_TEST_TABLE, Key={"collection": {"S": "test"}, "key": {"S": "test_key"}}
+        )
+        assert get_value_from_response(response=response) == snapshot(
+            {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}}
+        )
+
+        assert "ttl" not in response.get("Item", {})
+
+        await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10)
+
+        response = await get_dynamo_client_from_store(store=store).get_item(
+            TableName=DYNAMODB_TEST_TABLE, Key={"collection": {"S": "test"}, "key": {"S": "test_key"}}
+        )
+        assert get_value_from_response(response=response) == snapshot(
+            {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)}
+        )
+        # Verify DynamoDB TTL attribute is set for automatic expiration
+        assert "ttl" in response.get("Item", {}), "DynamoDB TTL attribute should be set when ttl parameter is provided"
+        ttl_value = int(response["Item"]["ttl"]["N"])  # pyright: ignore[reportTypedDictNotRequiredAccess]
+        now = datetime.now(timezone.utc)
+        assert ttl_value > now.timestamp(), "TTL timestamp should be in the future"
+        assert ttl_value < now.timestamp() + 10, "TTL timestamp should be no more than 10 seconds in the future"
diff --git a/key-value/key-value-aio/tests/stores/elasticsearch/test_elasticsearch.py b/key-value/key-value-aio/tests/stores/elasticsearch/test_elasticsearch.py
index 079d8a23..ec217a5d 100644
--- a/key-value/key-value-aio/tests/stores/elasticsearch/test_elasticsearch.py
+++ b/key-value/key-value-aio/tests/stores/elasticsearch/test_elasticsearch.py
@@ -11,7 +11,7 @@ from key_value.aio.stores.base import BaseStore
 from key_value.aio.stores.elasticsearch import ElasticsearchStore
-from key_value.aio.stores.elasticsearch.store import managed_entry_to_document, source_to_managed_entry
+from key_value.aio.stores.elasticsearch.store import ElasticsearchSerializationAdapter
 from tests.conftest import docker_container, should_skip_docker_tests
 from tests.stores.base import BaseStoreTests, ContextManagerStoreTestMixin
@@ -55,19 +55,18 @@ def test_managed_entry_document_conversion():
     expires_at = created_at + timedelta(seconds=10)
     managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at)
 
-    document = managed_entry_to_document(collection="test_collection", key="test_key", managed_entry=managed_entry)
+    adapter = ElasticsearchSerializationAdapter(native_storage=False)
+    document = adapter.dump_dict(entry=managed_entry)
 
     assert document == snapshot(
         {
-            "collection": "test_collection",
-            "key": "test_key",
             "value": {"string": '{"test": "test"}'},
             "created_at": "2025-01-01T00:00:00+00:00",
             "expires_at": "2025-01-01T00:00:10+00:00",
         }
     )
 
-    round_trip_managed_entry = source_to_managed_entry(source=document)
+    round_trip_managed_entry = adapter.load_dict(data=document)
 
     assert round_trip_managed_entry.value == managed_entry.value
     assert round_trip_managed_entry.created_at == created_at
@@ -80,19 +79,18 @@ def test_managed_entry_document_conversion_native_storage():
     expires_at = created_at + timedelta(seconds=10)
     managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at)
 
-    document = managed_entry_to_document(collection="test_collection", key="test_key",
managed_entry=managed_entry, native_storage=True) + adapter = ElasticsearchSerializationAdapter(native_storage=True) + document = adapter.dump_dict(entry=managed_entry) assert document == snapshot( { - "collection": "test_collection", - "key": "test_key", "value": {"flattened": {"test": "test"}}, "created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00", } ) - round_trip_managed_entry = source_to_managed_entry(source=document) + round_trip_managed_entry = adapter.load_dict(data=document) assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at @@ -170,8 +168,6 @@ async def test_value_stored_as_flattened_object(self, store: ElasticsearchStore, response = await es_client.get(index=index_name, id=doc_id) assert response.body["_source"] == snapshot( { - "collection": "test", - "key": "test_key", "value": {"flattened": {"name": "Alice", "age": 30}}, "created_at": IsStr(min_length=20, max_length=40), } @@ -182,8 +178,6 @@ async def test_value_stored_as_flattened_object(self, store: ElasticsearchStore, response = await es_client.get(index=index_name, id=doc_id) assert response.body["_source"] == snapshot( { - "collection": "test", - "key": "test_key", "value": {"flattened": {"name": "Bob", "age": 25}}, "created_at": IsStr(min_length=20, max_length=40), "expires_at": IsStr(min_length=20, max_length=40), @@ -231,8 +225,6 @@ async def test_value_stored_as_json_string(self, store: ElasticsearchStore, es_c response = await es_client.get(index=index_name, id=doc_id) assert response.body["_source"] == snapshot( { - "collection": "test", - "key": "test_key", "value": {"string": '{"age": 30, "name": "Alice"}'}, "created_at": IsStr(min_length=20, max_length=40), } @@ -243,8 +235,6 @@ async def test_value_stored_as_json_string(self, store: ElasticsearchStore, es_c response = await es_client.get(index=index_name, id=doc_id) assert response.body["_source"] == snapshot( { - "collection": "test", - "key": "test_key", "value": {"string": '{"age": 25, "name": "Bob"}'}, "created_at": IsStr(min_length=20, max_length=40), "expires_at": IsStr(min_length=20, max_length=40), diff --git a/key-value/key-value-aio/tests/stores/memcached/test_memcached.py b/key-value/key-value-aio/tests/stores/memcached/test_memcached.py index d8089848..661dfa80 100644 --- a/key-value/key-value-aio/tests/stores/memcached/test_memcached.py +++ b/key-value/key-value-aio/tests/stores/memcached/test_memcached.py @@ -1,8 +1,11 @@ import contextlib +import json from collections.abc import AsyncGenerator import pytest from aiomcache import Client +from dirty_equals import IsDatetime +from inline_snapshot import snapshot from key_value.shared.stores.wait import async_wait_for_true from typing_extensions import override @@ -42,6 +45,7 @@ class MemcachedFailedToStartError(Exception): @pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not available") +@pytest.mark.filterwarnings("ignore:A configured store is unstable and may change in a backwards incompatible way. Use at your own risk.") class TestMemcachedStore(ContextManagerStoreTestMixin, BaseStoreTests): @pytest.fixture(autouse=True, scope="session", params=MEMCACHED_VERSIONS_TO_TEST) async def setup_memcached(self, request: pytest.FixtureRequest) -> AsyncGenerator[None, None]: @@ -64,3 +68,28 @@ async def store(self, setup_memcached: None) -> MemcachedStore: @pytest.mark.skip(reason="Distributed Caches are unbounded") @override async def test_not_unbounded(self, store: BaseStore): ... 
+ + @pytest.fixture + async def memcached_client(self, store: MemcachedStore) -> Client: + return store._client # pyright: ignore[reportPrivateUsage] + + async def test_value_stored(self, store: MemcachedStore, memcached_client: Client): + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + value = await memcached_client.get(key=b"test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot({"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}}) + + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = await memcached_client.get(key=b"test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot( + { + "created_at": IsDatetime(iso_string=True), + "expires_at": IsDatetime(iso_string=True), + "value": {"age": 30, "name": "Alice"}, + } + ) diff --git a/key-value/key-value-aio/tests/stores/mongodb/test_mongodb.py b/key-value/key-value-aio/tests/stores/mongodb/test_mongodb.py index a638a152..19046e4c 100644 --- a/key-value/key-value-aio/tests/stores/mongodb/test_mongodb.py +++ b/key-value/key-value-aio/tests/stores/mongodb/test_mongodb.py @@ -14,7 +14,7 @@ from key_value.aio.stores.base import BaseStore from key_value.aio.stores.mongodb import MongoDBStore -from key_value.aio.stores.mongodb.store import document_to_managed_entry, managed_entry_to_document +from key_value.aio.stores.mongodb.store import MongoDBSerializationAdapter from tests.conftest import docker_container, should_skip_docker_tests from tests.stores.base import BaseStoreTests, ContextManagerStoreTestMixin @@ -50,18 +50,19 @@ def test_managed_entry_document_conversion_native_mode(): expires_at = created_at + timedelta(seconds=10) managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(key="test", managed_entry=managed_entry, native_storage=True) + + adapter = MongoDBSerializationAdapter(native_storage=True) + document = adapter.dump_dict(entry=managed_entry) assert document == snapshot( { - "key": "test", "value": {"object": {"test": "test"}}, "created_at": datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), "expires_at": datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc), } ) - round_trip_managed_entry = document_to_managed_entry(document=document) + round_trip_managed_entry = adapter.load_dict(data=document) assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at @@ -74,18 +75,18 @@ def test_managed_entry_document_conversion_legacy_mode(): expires_at = created_at + timedelta(seconds=10) managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(key="test", managed_entry=managed_entry, native_storage=False) + adapter = MongoDBSerializationAdapter(native_storage=False) + document = adapter.dump_dict(entry=managed_entry) assert document == snapshot( { - "key": "test", "value": {"string": '{"test": "test"}'}, "created_at": datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), "expires_at": datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc), } ) - round_trip_managed_entry = document_to_managed_entry(document=document) + round_trip_managed_entry = adapter.load_dict(data=document) assert round_trip_managed_entry.value == managed_entry.value assert 
round_trip_managed_entry.created_at == created_at @@ -98,6 +99,7 @@ async def clean_mongodb_database(store: MongoDBStore) -> None: _ = await store._client.drop_database(name_or_database=store._db.name) # pyright: ignore[reportPrivateUsage] +@pytest.mark.filterwarnings("ignore:A configured store is unstable and may change in a backwards incompatible way. Use at your own risk.") class BaseMongoDBStoreTests(ContextManagerStoreTestMixin, BaseStoreTests): """Base class for MongoDB store tests.""" diff --git a/key-value/key-value-aio/tests/stores/redis/test_redis.py b/key-value/key-value-aio/tests/stores/redis/test_redis.py index 407fc799..f75e228a 100644 --- a/key-value/key-value-aio/tests/stores/redis/test_redis.py +++ b/key-value/key-value-aio/tests/stores/redis/test_redis.py @@ -1,17 +1,16 @@ +import json from collections.abc import AsyncGenerator -from datetime import datetime, timedelta, timezone +from typing import Any import pytest -from dirty_equals import IsFloat +from dirty_equals import IsDatetime from inline_snapshot import snapshot from key_value.shared.stores.wait import async_wait_for_true -from key_value.shared.utils.managed_entry import ManagedEntry from redis.asyncio.client import Redis from typing_extensions import override from key_value.aio.stores.base import BaseStore from key_value.aio.stores.redis import RedisStore -from key_value.aio.stores.redis.store import json_to_managed_entry, managed_entry_to_json from tests.conftest import docker_container, should_skip_docker_tests from tests.stores.base import BaseStoreTests, ContextManagerStoreTestMixin @@ -28,25 +27,6 @@ ] -def test_managed_entry_document_conversion(): - created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) - expires_at = created_at + timedelta(seconds=10) - - managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_json(managed_entry=managed_entry) - - assert document == snapshot( - '{"created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00", "value": {"test": "test"}}' - ) - - round_trip_managed_entry = json_to_managed_entry(json_str=document) - - assert round_trip_managed_entry.value == managed_entry.value - assert round_trip_managed_entry.created_at == created_at - assert round_trip_managed_entry.ttl == IsFloat(lt=0) - assert round_trip_managed_entry.expires_at == expires_at - - async def ping_redis() -> bool: client: Redis = Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) try: @@ -59,6 +39,10 @@ class RedisFailedToStartError(Exception): pass +def get_client_from_store(store: RedisStore) -> Redis: + return store._client # pyright: ignore[reportPrivateUsage] + + @pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running") class TestRedisStore(ContextManagerStoreTestMixin, BaseStoreTests): @pytest.fixture(autouse=True, scope="session", params=REDIS_VERSIONS_TO_TEST) @@ -78,14 +62,18 @@ async def store(self, setup_redis: RedisStore) -> RedisStore: """Create a Redis store for testing.""" # Create the store with test database redis_store = RedisStore(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB) - _ = await redis_store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] + _ = await get_client_from_store(store=redis_store).flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] return redis_store + @pytest.fixture + def redis_client(self, store: RedisStore) 
-> Redis: + return get_client_from_store(store=store) + async def test_redis_url_connection(self): """Test Redis store creation with URL.""" redis_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}" store = RedisStore(url=redis_url) - _ = await store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] + _ = await get_client_from_store(store=store).flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] await store.put(collection="test", key="url_test", value={"test": "value"}) result = await store.get(collection="test", key="url_test") assert result == {"test": "value"} @@ -97,11 +85,62 @@ async def test_redis_client_connection(self): client = Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) store = RedisStore(client=client) - _ = await store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] + _ = await get_client_from_store(store=store).flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] await store.put(collection="test", key="client_test", value={"test": "value"}) result = await store.get(collection="test", key="client_test") assert result == {"test": "value"} + async def test_redis_document_format(self, store: RedisStore, redis_client: Redis): + """Test Redis store document format.""" + await store.put(collection="test", key="document_format_test_1", value={"test_1": "value_1"}) + await store.put(collection="test", key="document_format_test_2", value={"test_2": "value_2"}, ttl=10) + + raw_documents: Any = await redis_client.mget(keys=["test::document_format_test_1", "test::document_format_test_2"]) + raw_documents_dicts: list[dict[str, Any]] = [json.loads(raw_document) for raw_document in raw_documents] + assert raw_documents_dicts == snapshot( + [ + { + "created_at": IsDatetime(iso_string=True), + "value": {"test_1": "value_1"}, + }, + { + "created_at": IsDatetime(iso_string=True), + "expires_at": IsDatetime(iso_string=True), + "value": {"test_2": "value_2"}, + }, + ] + ) + + await store.put_many( + collection="test", + keys=["document_format_test_3", "document_format_test_4"], + values=[{"test_3": "value_3"}, {"test_4": "value_4"}], + ttl=10, + ) + raw_documents = await redis_client.mget(keys=["test::document_format_test_3", "test::document_format_test_4"]) + raw_documents_dicts = [json.loads(raw_document) for raw_document in raw_documents] + assert raw_documents_dicts == snapshot( + [ + { + "created_at": IsDatetime(iso_string=True), + "expires_at": IsDatetime(iso_string=True), + "value": {"test_3": "value_3"}, + }, + { + "created_at": IsDatetime(iso_string=True), + "expires_at": IsDatetime(iso_string=True), + "value": {"test_4": "value_4"}, + }, + ] + ) + + await store.put(collection="test", key="document_format_test", value={"test": "value"}, ttl=10) + raw_document: Any = await redis_client.get(name="test::document_format_test") + raw_document_dict = json.loads(raw_document) + assert raw_document_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "expires_at": IsDatetime(iso_string=True), "value": {"test": "value"}} + ) + @pytest.mark.skip(reason="Distributed Caches are unbounded") @override async def test_not_unbounded(self, store: BaseStore): ... 
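The Redis, Memcached, DynamoDB, and RocksDB tests above all snapshot the same serialized entry shape. A minimal sketch of how that JSON is produced, assuming only the `ManagedEntry` and `BasicSerializationAdapter` APIs introduced elsewhere in this change:

```python
# Sketch only: reproduces the document shape the store tests snapshot, assuming the
# BasicSerializationAdapter added in key_value/shared/utils/serialization.py.
from datetime import datetime, timedelta, timezone

from key_value.shared.utils.managed_entry import ManagedEntry
from key_value.shared.utils.serialization import BasicSerializationAdapter

created_at = datetime(2025, 1, 1, tzinfo=timezone.utc)
entry = ManagedEntry(
    value={"name": "Alice", "age": 30},
    created_at=created_at,
    expires_at=created_at + timedelta(seconds=10),
)

adapter = BasicSerializationAdapter()
print(adapter.dump_json(entry=entry))
# {"created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00",
#  "value": {"age": 30, "name": "Alice"}}
```

Keys are sorted by `dump_to_json`, which is why the nested value appears as `{"age": ..., "name": ...}` in every snapshot.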
diff --git a/key-value/key-value-aio/tests/stores/rocksdb/test_rocksdb.py b/key-value/key-value-aio/tests/stores/rocksdb/test_rocksdb.py index e3f37694..b5c60336 100644 --- a/key-value/key-value-aio/tests/stores/rocksdb/test_rocksdb.py +++ b/key-value/key-value-aio/tests/stores/rocksdb/test_rocksdb.py @@ -1,8 +1,12 @@ +import json from collections.abc import AsyncGenerator from pathlib import Path from tempfile import TemporaryDirectory import pytest +from dirty_equals import IsDatetime +from inline_snapshot import snapshot +from rocksdict import Rdict from typing_extensions import override from key_value.aio.stores.base import BaseStore @@ -10,6 +14,7 @@ from tests.stores.base import BaseStoreTests, ContextManagerStoreTestMixin +@pytest.mark.filterwarnings("ignore:A configured store is unstable and may change in a backwards incompatible way. Use at your own risk.") class TestRocksDBStore(ContextManagerStoreTestMixin, BaseStoreTests): @override @pytest.fixture @@ -59,3 +64,24 @@ async def test_rocksdb_db_connection(self): @pytest.mark.skip(reason="Local disk stores are unbounded") @override async def test_not_unbounded(self, store: BaseStore): ... + + @pytest.fixture + async def rocksdb_client(self, store: RocksDBStore) -> Rdict: + return store._db # pyright: ignore[reportPrivateUsage] + + async def test_value_stored(self, store: RocksDBStore, rocksdb_client: Rdict): + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + value = rocksdb_client.get(key="test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot({"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}}) + + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = rocksdb_client.get(key="test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-aio/tests/stores/simple/test_store.py b/key-value/key-value-aio/tests/stores/simple/test_store.py index 1ac0341f..f48522b1 100644 --- a/key-value/key-value-aio/tests/stores/simple/test_store.py +++ b/key-value/key-value-aio/tests/stores/simple/test_store.py @@ -5,6 +5,7 @@ from tests.stores.base import BaseStoreTests +@pytest.mark.filterwarnings("ignore:A configured store is unstable and may change in a backwards incompatible way. Use at your own risk.") class TestSimpleStore(BaseStoreTests): @override @pytest.fixture diff --git a/key-value/key-value-aio/tests/stores/valkey/test_valkey.py b/key-value/key-value-aio/tests/stores/valkey/test_valkey.py index 9dc6895b..27fe7930 100644 --- a/key-value/key-value-aio/tests/stores/valkey/test_valkey.py +++ b/key-value/key-value-aio/tests/stores/valkey/test_valkey.py @@ -1,7 +1,10 @@ import contextlib +import json from collections.abc import AsyncGenerator import pytest +from dirty_equals import IsDatetime +from inline_snapshot import snapshot from key_value.shared.stores.wait import async_wait_for_true from typing_extensions import override @@ -84,3 +87,26 @@ async def store(self, setup_valkey: None): @pytest.mark.skip(reason="Distributed Caches are unbounded") @override async def test_not_unbounded(self, store: BaseStore): ... 
+ + async def test_value_stored(self, store: BaseStore): + from key_value.aio.stores.valkey import ValkeyStore + + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + assert isinstance(store, ValkeyStore) + + valkey_client = store._connected_client # pyright: ignore[reportPrivateUsage] + assert valkey_client is not None + value = await valkey_client.get(key="test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot({"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}}) + + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = await valkey_client.get(key="test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-aio/tests/stores/vault/test_vault.py b/key-value/key-value-aio/tests/stores/vault/test_vault.py index 81b0cc4d..e9704641 100644 --- a/key-value/key-value-aio/tests/stores/vault/test_vault.py +++ b/key-value/key-value-aio/tests/stores/vault/test_vault.py @@ -30,6 +30,7 @@ class VaultFailedToStartError(Exception): @pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running") +@pytest.mark.filterwarnings("ignore:A configured store is unstable and may change in a backwards incompatible way. Use at your own risk.") class TestVaultStore(BaseStoreTests): def get_vault_client(self): import hvac diff --git a/key-value/key-value-aio/tests/stores/windows_registry/test_windows_registry.py b/key-value/key-value-aio/tests/stores/windows_registry/test_windows_registry.py index 235ff868..d2b1e47f 100644 --- a/key-value/key-value-aio/tests/stores/windows_registry/test_windows_registry.py +++ b/key-value/key-value-aio/tests/stores/windows_registry/test_windows_registry.py @@ -15,6 +15,7 @@ @pytest.mark.skipif(condition=not detect_on_windows(), reason="WindowsRegistryStore is only available on Windows") +@pytest.mark.filterwarnings("ignore:A configured store is unstable and may change in a backwards incompatible way. 
Use at your own risk.") class TestWindowsRegistryStore(BaseStoreTests): def cleanup(self): from winreg import HKEY_CURRENT_USER diff --git a/key-value/key-value-aio/tests/stores/wrappers/test_limit_size.py b/key-value/key-value-aio/tests/stores/wrappers/test_limit_size.py index 8585ceb0..cd4b8ec0 100644 --- a/key-value/key-value-aio/tests/stores/wrappers/test_limit_size.py +++ b/key-value/key-value-aio/tests/stores/wrappers/test_limit_size.py @@ -1,5 +1,6 @@ import pytest from key_value.shared.errors.wrappers.limit_size import EntryTooLargeError, EntryTooSmallError +from key_value.shared.utils.managed_entry import estimate_serialized_size from typing_extensions import override from key_value.aio.stores.memory.store import MemoryStore @@ -137,12 +138,8 @@ async def test_put_many_all_too_large_without_raise(self, memory_store: MemorySt async def test_exact_size_limit(self, memory_store: MemoryStore): # First, determine the exact size of a small value - from key_value.shared.utils.managed_entry import ManagedEntry - test_value = {"test": "value"} - managed_entry = ManagedEntry(value=test_value) - json_str = managed_entry.to_json() - exact_size = len(json_str.encode("utf-8")) + exact_size = estimate_serialized_size(value=test_value) # Create a store with exact size limit limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=exact_size, raise_on_too_large=True) diff --git a/key-value/key-value-shared/pyproject.toml b/key-value/key-value-shared/pyproject.toml index f4054e96..67d570fe 100644 --- a/key-value/key-value-shared/pyproject.toml +++ b/key-value/key-value-shared/pyproject.toml @@ -56,8 +56,3 @@ extend="../../pyproject.toml" [tool.pyright] extends = "../../pyproject.toml" - -executionEnvironments = [ - { root = "tests", reportPrivateUsage = false, extraPaths = ["src"]}, - { root = "src" } -] diff --git a/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py b/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py index 60cafb60..42a08b06 100644 --- a/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py +++ b/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py @@ -2,12 +2,13 @@ from collections.abc import Mapping from dataclasses import dataclass, field from datetime import datetime -from typing import Any, SupportsFloat, cast +from typing import Any, SupportsFloat from typing_extensions import Self from key_value.shared.errors import DeserializationError, SerializationError -from key_value.shared.utils.time_to_live import now, now_plus, prepare_ttl, try_parse_datetime_str +from key_value.shared.type_checking.bear_spray import bear_enforce +from key_value.shared.utils.time_to_live import now, now_plus, seconds_to @dataclass(kw_only=True) @@ -22,22 +23,20 @@ class ManagedEntry: value: Mapping[str, Any] created_at: datetime | None = field(default=None) - ttl: float | None = field(default=None) expires_at: datetime | None = field(default=None) - def __post_init__(self) -> None: - if self.ttl is not None and self.expires_at is None: - self.expires_at = now_plus(seconds=self.ttl) - - elif self.expires_at is not None and self.ttl is None: - self.recalculate_ttl() - @property def is_expired(self) -> bool: if self.expires_at is None: return False return self.expires_at <= now() + @property + def ttl(self) -> float | None: + if self.expires_at is None: + return None + return seconds_to(datetime=self.expires_at) + @property def value_as_json(self) -> str: """Return the value as a JSON string.""" @@ 
-45,102 +44,26 @@ def value_as_json(self) -> str: @property def value_as_dict(self) -> dict[str, Any]: - return dict(self.value) - - def recalculate_ttl(self) -> None: - if self.expires_at is not None and self.ttl is None: - self.ttl = (self.expires_at - now()).total_seconds() - - def to_dict( - self, include_metadata: bool = True, include_expiration: bool = True, include_creation: bool = True, stringify_value: bool = False - ) -> dict[str, Any]: - if not include_metadata: - return dict(self.value) - - data: dict[str, Any] = {"value": self.value_as_json if stringify_value else self.value} - - if include_creation and self.created_at: - data["created_at"] = self.created_at.isoformat() - if include_expiration and self.expires_at: - data["expires_at"] = self.expires_at.isoformat() - - return data - - def to_json( - self, include_metadata: bool = True, include_expiration: bool = True, include_creation: bool = True, stringify_value: bool = False - ) -> str: - return dump_to_json( - obj=self.to_dict( - include_metadata=include_metadata, - include_expiration=include_expiration, - include_creation=include_creation, - stringify_value=stringify_value, - ) - ) + return verify_dict(obj=self.value) - @classmethod - def from_dict( # noqa: PLR0912 - cls, data: dict[str, Any], includes_metadata: bool = True, ttl: SupportsFloat | None = None, stringified_value: bool = False - ) -> Self: - if not includes_metadata: - return cls( - value=data, - ) - - created_at: datetime | None = None - expires_at: datetime | None = None - - if created_at_value := data.get("created_at"): - if isinstance(created_at_value, str): - created_at = try_parse_datetime_str(value=created_at_value) - elif isinstance(created_at_value, datetime): - created_at = created_at_value - else: - msg = "Expected `created_at` field to be a string or datetime" - raise DeserializationError(msg) - - if expires_at_value := data.get("expires_at"): - if isinstance(expires_at_value, str): - expires_at = try_parse_datetime_str(value=expires_at_value) - elif isinstance(expires_at_value, datetime): - expires_at = expires_at_value - else: - msg = "Expected `expires_at` field to be a string or datetime" - raise DeserializationError(msg) - - if not (raw_value := data.get("value")): - msg = "Value is None" - raise DeserializationError(msg) - - value: dict[str, Any] - - if stringified_value: - if not isinstance(raw_value, str): - msg = "Value is not a string" - raise DeserializationError(msg) - value = load_from_json(json_str=raw_value) - else: - if not isinstance(raw_value, dict): - msg = "Value is not a dictionary" - raise DeserializationError(msg) - value = verify_dict(obj=raw_value) - - ttl_seconds: float | None = prepare_ttl(t=ttl) + @property + def created_at_isoformat(self) -> str | None: + return self.created_at.isoformat() if self.created_at else None + @property + def expires_at_isoformat(self) -> str | None: + return self.expires_at.isoformat() if self.expires_at else None + + @classmethod + def from_ttl(cls, *, value: Mapping[str, Any], created_at: datetime | None = None, ttl: SupportsFloat) -> Self: return cls( - created_at=created_at, - expires_at=expires_at, - ttl=ttl_seconds, value=value, + created_at=created_at, + expires_at=(now_plus(seconds=float(ttl)) if ttl else None), ) - @classmethod - def from_json(cls, json_str: str, includes_metadata: bool = True, ttl: SupportsFloat | None = None) -> Self: - data: dict[str, Any] = load_from_json(json_str=json_str) - - return cls.from_dict(data=data, includes_metadata=includes_metadata, ttl=ttl) - 
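With `ttl` demoted from a stored field to a property derived from `expires_at`, a short hedged sketch of the resulting semantics, using only the `from_ttl` constructor and properties added above:

```python
# Sketch of the reworked TTL semantics: ttl is recomputed from expires_at on every
# access rather than stored alongside it.
from key_value.shared.utils.managed_entry import ManagedEntry

entry = ManagedEntry.from_ttl(value={"name": "Alice"}, ttl=10)
assert entry.expires_at is not None               # roughly now + 10 seconds
assert entry.ttl is not None and entry.ttl <= 10  # derived, shrinks as time passes

eternal = ManagedEntry(value={"name": "Alice"})
assert eternal.ttl is None and eternal.is_expired is False
```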
+@bear_enforce def dump_to_json(obj: dict[str, Any]) -> str: try: return json.dumps(obj, sort_keys=True) @@ -149,6 +72,7 @@ def dump_to_json(obj: dict[str, Any]) -> str: raise SerializationError(msg) from e +@bear_enforce def load_from_json(json_str: str) -> dict[str, Any]: try: return verify_dict(obj=json.loads(json_str)) # pyright: ignore[reportAny] @@ -158,13 +82,30 @@ def load_from_json(json_str: str) -> dict[str, Any]: raise DeserializationError(msg) from e +@bear_enforce def verify_dict(obj: Any) -> dict[str, Any]: - if not isinstance(obj, dict): - msg = "Object is not a dictionary" - raise DeserializationError(msg) + if not isinstance(obj, Mapping): + msg = "Object is not a Mapping" + raise TypeError(msg) if not all(isinstance(key, str) for key in obj): # pyright: ignore[reportUnknownVariableType] msg = "Object contains non-string keys" - raise DeserializationError(msg) + raise TypeError(msg) + + return dict(obj) # pyright: ignore[reportUnknownArgumentType] - return cast(typ="dict[str, Any]", val=obj) + +def estimate_serialized_size(value: Mapping[str, Any]) -> int: + """Estimate the serialized size of a value without creating a ManagedEntry. + + This function provides a more efficient way to estimate the size of a value + when serialized to JSON, without the overhead of creating a full ManagedEntry object. + This is useful for size-based checks in wrappers. + + Args: + value: The value mapping to estimate the size for. + + Returns: + The estimated size in bytes when serialized to JSON. + """ + return len(dump_to_json(obj=dict(value))) diff --git a/key-value/key-value-shared/src/key_value/shared/utils/serialization.py b/key-value/key-value-shared/src/key_value/shared/utils/serialization.py new file mode 100644 index 00000000..1350513b --- /dev/null +++ b/key-value/key-value-shared/src/key_value/shared/utils/serialization.py @@ -0,0 +1,151 @@ +"""Serialization adapter base class for converting ManagedEntry objects to/from store-specific formats. + +This module provides the SerializationAdapter ABC that store implementations should use +to define their own serialization strategy. Store-specific adapter implementations +should be defined within their respective store modules. +""" + +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Literal, TypeVar + +from key_value.shared.errors import DeserializationError, SerializationError +from key_value.shared.utils.managed_entry import ManagedEntry, dump_to_json, load_from_json, verify_dict + +T = TypeVar("T") + + +def key_must_be(dictionary: dict[str, Any], /, key: str, expected_type: type[T]) -> T | None: + if key not in dictionary: + return None + if not isinstance(dictionary[key], expected_type): + msg = f"{key} must be a {expected_type.__name__}" + raise TypeError(msg) + return dictionary[key] + + +def parse_datetime_str(value: str) -> datetime: + try: + return datetime.fromisoformat(value) + except ValueError: + msg = f"Invalid datetime string: {value}" + raise DeserializationError(message=msg) from None + + +class SerializationAdapter(ABC): + """Base class for store-specific serialization adapters. + + Adapters encapsulate the logic for converting between ManagedEntry objects + and store-specific storage formats. This provides a consistent interface + while allowing each store to optimize its serialization strategy. 
+ """ + + _date_format: Literal["isoformat", "datetime"] | None = "isoformat" + _value_format: Literal["string", "dict"] | None = "dict" + + def __init__( + self, *, date_format: Literal["isoformat", "datetime"] | None = "isoformat", value_format: Literal["string", "dict"] | None = "dict" + ) -> None: + self._date_format = date_format + self._value_format = value_format + + def load_json(self, json_str: str) -> ManagedEntry: + """Convert a JSON string to a ManagedEntry.""" + loaded_data: dict[str, Any] = load_from_json(json_str=json_str) + + return self.load_dict(data=loaded_data) + + @abstractmethod + def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: + """Prepare data for loading. + + This method is used by subclasses to handle any required transformations before loading the data into a ManagedEntry.""" + + def load_dict(self, data: dict[str, Any]) -> ManagedEntry: + """Convert a dictionary to a ManagedEntry.""" + + data = self.prepare_load(data=data) + + managed_entry_proto: dict[str, Any] = {} + + if self._date_format == "isoformat": + if created_at := key_must_be(data, key="created_at", expected_type=str): + managed_entry_proto["created_at"] = parse_datetime_str(value=created_at) + if expires_at := key_must_be(data, key="expires_at", expected_type=str): + managed_entry_proto["expires_at"] = parse_datetime_str(value=expires_at) + + if self._date_format == "datetime": + if created_at := key_must_be(data, key="created_at", expected_type=datetime): + managed_entry_proto["created_at"] = created_at + if expires_at := key_must_be(data, key="expires_at", expected_type=datetime): + managed_entry_proto["expires_at"] = expires_at + + if "value" not in data: + msg = "Value field not found" + raise DeserializationError(message=msg) + + value = data["value"] + + managed_entry_value: dict[str, Any] = {} + + if isinstance(value, str): + managed_entry_value = load_from_json(json_str=value) + elif isinstance(value, dict): + managed_entry_value = verify_dict(obj=value) + else: + msg = "Value field is not a string or dictionary" + raise DeserializationError(message=msg) + + return ManagedEntry( + value=managed_entry_value, + created_at=managed_entry_proto.get("created_at"), + expires_at=managed_entry_proto.get("expires_at"), + ) + + @abstractmethod + def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: + """Prepare data for dumping to a dictionary. + + This method is used by subclasses to handle any required transformations before dumping the data to a dictionary.""" + + def dump_dict(self, entry: ManagedEntry, exclude_none: bool = True) -> dict[str, Any]: + """Convert a ManagedEntry to a dictionary.""" + + data: dict[str, Any] = { + "value": entry.value_as_dict if self._value_format == "dict" else entry.value_as_json, + } + + if self._date_format == "isoformat": + data["created_at"] = entry.created_at_isoformat + data["expires_at"] = entry.expires_at_isoformat + + if self._date_format == "datetime": + data["created_at"] = entry.created_at + data["expires_at"] = entry.expires_at + + if exclude_none: + data = {k: v for k, v in data.items() if v is not None} + + return self.prepare_dump(data=data) + + def dump_json(self, entry: ManagedEntry, exclude_none: bool = True) -> str: + """Convert a ManagedEntry to a JSON string.""" + if self._date_format == "datetime": + msg = 'dump_json is incompatible with date_format="datetime"; use date_format="isoformat" or dump_dict().' 
+ raise SerializationError(msg) + return dump_to_json(obj=self.dump_dict(entry=entry, exclude_none=exclude_none)) + + +class BasicSerializationAdapter(SerializationAdapter): + """Basic serialization adapter that does not perform any transformations.""" + + def __init__( + self, *, date_format: Literal["isoformat", "datetime"] | None = "isoformat", value_format: Literal["string", "dict"] | None = "dict" + ) -> None: + super().__init__(date_format=date_format, value_format=value_format) + + def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: + return data + + def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: + return data diff --git a/key-value/key-value-shared/tests/utils/test_serialization.py b/key-value/key-value-shared/tests/utils/test_serialization.py new file mode 100644 index 00000000..849b818a --- /dev/null +++ b/key-value/key-value-shared/tests/utils/test_serialization.py @@ -0,0 +1,80 @@ +from datetime import datetime, timedelta, timezone + +import pytest +from inline_snapshot import snapshot + +from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.serialization import BasicSerializationAdapter + +FIXED_DATETIME_ONE = datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc) +FIXED_DATETIME_ONE_ISOFORMAT = FIXED_DATETIME_ONE.isoformat() +FIXED_DATETIME_ONE_PLUS_10_SECONDS = FIXED_DATETIME_ONE + timedelta(seconds=10) +FIXED_DATETIME_ONE_PLUS_10_SECONDS_ISOFORMAT = FIXED_DATETIME_ONE_PLUS_10_SECONDS.isoformat() + +FIXED_DATETIME_TWO = datetime(2025, 1, 1, 0, 0, 1, tzinfo=timezone.utc) +FIXED_DATETIME_TWO_PLUS_10_SECONDS = FIXED_DATETIME_TWO + timedelta(seconds=10) +FIXED_DATETIME_TWO_ISOFORMAT = FIXED_DATETIME_TWO.isoformat() +FIXED_DATETIME_TWO_PLUS_10_SECONDS_ISOFORMAT = FIXED_DATETIME_TWO_PLUS_10_SECONDS.isoformat() + +TEST_DATA_ONE = {"key_one": "value_one", "key_two": "value_two", "key_three": {"nested_key": "nested_value"}} +TEST_ENTRY_ONE = ManagedEntry(value=TEST_DATA_ONE, created_at=FIXED_DATETIME_ONE, expires_at=FIXED_DATETIME_ONE_PLUS_10_SECONDS) +TEST_DATA_TWO = {"key_one": ["value_one", "value_two", "value_three"], "key_two": 123, "key_three": {"nested_key": "nested_value"}} +TEST_ENTRY_TWO = ManagedEntry(value=TEST_DATA_TWO, created_at=FIXED_DATETIME_TWO, expires_at=FIXED_DATETIME_TWO_PLUS_10_SECONDS) + + +@pytest.fixture +def serialization_adapter() -> BasicSerializationAdapter: + return BasicSerializationAdapter() + + +class TestBasicSerializationAdapter: + @pytest.fixture + def adapter(self) -> BasicSerializationAdapter: + return BasicSerializationAdapter() + + def test_empty_dict(self, adapter: BasicSerializationAdapter): + managed_entry = adapter.load_json( + json_str='{"created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00", "value": {}}' + ) + assert managed_entry == snapshot( + ManagedEntry(value={}, created_at=FIXED_DATETIME_ONE, expires_at=FIXED_DATETIME_ONE_PLUS_10_SECONDS) + ) + + managed_entry = adapter.load_dict( + data={"created_at": FIXED_DATETIME_ONE_ISOFORMAT, "expires_at": FIXED_DATETIME_ONE_PLUS_10_SECONDS_ISOFORMAT, "value": {}} + ) + assert managed_entry == snapshot( + ManagedEntry(value={}, created_at=FIXED_DATETIME_ONE, expires_at=FIXED_DATETIME_ONE_PLUS_10_SECONDS) + ) + + def test_entry_one(self, adapter: BasicSerializationAdapter): + assert adapter.dump_dict(entry=TEST_ENTRY_ONE) == snapshot( + { + "value": TEST_DATA_ONE, + "created_at": FIXED_DATETIME_ONE_ISOFORMAT, + "expires_at": FIXED_DATETIME_ONE_PLUS_10_SECONDS_ISOFORMAT, + } + ) + + assert 
adapter.dump_json(entry=TEST_ENTRY_ONE) == snapshot( + '{"created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00", "value": {"key_one": "value_one", "key_three": {"nested_key": "nested_value"}, "key_two": "value_two"}}' + ) + + assert adapter.load_dict(data=adapter.dump_dict(entry=TEST_ENTRY_ONE)) == snapshot(TEST_ENTRY_ONE) + assert adapter.load_json(json_str=adapter.dump_json(entry=TEST_ENTRY_ONE)) == snapshot(TEST_ENTRY_ONE) + + def test_entry_two(self, adapter: BasicSerializationAdapter): + assert adapter.dump_dict(entry=TEST_ENTRY_TWO) == snapshot( + { + "value": TEST_DATA_TWO, + "created_at": FIXED_DATETIME_TWO_ISOFORMAT, + "expires_at": FIXED_DATETIME_TWO_PLUS_10_SECONDS_ISOFORMAT, + } + ) + + assert adapter.dump_json(entry=TEST_ENTRY_TWO) == snapshot( + '{"created_at": "2025-01-01T00:00:01+00:00", "expires_at": "2025-01-01T00:00:11+00:00", "value": {"key_one": ["value_one", "value_two", "value_three"], "key_three": {"nested_key": "nested_value"}, "key_two": 123}}' + ) + + assert adapter.load_dict(data=adapter.dump_dict(entry=TEST_ENTRY_TWO)) == snapshot(TEST_ENTRY_TWO) + assert adapter.load_json(json_str=adapter.dump_json(entry=TEST_ENTRY_TWO)) == snapshot(TEST_ENTRY_TWO) diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py index 1c02abda..53078503 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py @@ -17,6 +17,7 @@ from key_value.shared.errors import StoreSetupError from key_value.shared.type_checking.bear_spray import bear_enforce from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.serialization import BasicSerializationAdapter, SerializationAdapter from key_value.shared.utils.time_to_live import prepare_entry_timestamps from typing_extensions import Self, override @@ -73,14 +74,23 @@ class BaseStore(KeyValueProtocol, ABC): _setup_collection_locks: defaultdict[str, Lock] _setup_collection_complete: defaultdict[str, bool] + _serialization_adapter: SerializationAdapter + _seed: FROZEN_SEED_DATA_TYPE default_collection: str - def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYPE | None = None) -> None: + def __init__( + self, + *, + serialization_adapter: SerializationAdapter | None = None, + default_collection: str | None = None, + seed: SEED_DATA_TYPE | None = None, + ) -> None: """Initialize the managed key-value store. Args: + serialization_adapter: The serialization adapter to use for the store. default_collection: The default collection to use if no collection is provided. Defaults to "default_collection". seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}. 
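Since `BaseStore` now accepts a `serialization_adapter`, a store-specific format only needs to override the two `prepare_*` hooks; the concrete adapters later in this diff (Elasticsearch, MongoDB) follow the same pattern. A hypothetical sketch, with the `payload` wrapper key invented purely for illustration:

```python
# Hypothetical adapter: the "payload" wrapper key is illustrative only and is not used
# by any real store in this change. Only prepare_dump/prepare_load are overridden; the
# base class handles date formatting and value (de)serialization.
from typing import Any

from key_value.shared.utils.serialization import SerializationAdapter


class WrappedValueAdapter(SerializationAdapter):
    def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]:
        data["value"] = {"payload": data.pop("value")}
        return data

    def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]:
        data["value"] = data["value"]["payload"]
        return data
```

A store subclass could then receive such an adapter via the new `serialization_adapter` keyword on `BaseStore.__init__`, falling back to `BasicSerializationAdapter` when none is given.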
@@ -97,6 +107,8 @@ def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYP self.default_collection = default_collection or DEFAULT_COLLECTION_NAME + self._serialization_adapter = serialization_adapter or BasicSerializationAdapter() + if not hasattr(self, "_stable_api"): self._stable_api = False @@ -272,9 +284,9 @@ def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = No collection = collection or self.default_collection self.setup_collection(collection=collection) - (created_at, ttl_seconds, expires_at) = prepare_entry_timestamps(ttl=ttl) + (created_at, _, expires_at) = prepare_entry_timestamps(ttl=ttl) - managed_entry: ManagedEntry = ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at) + managed_entry: ManagedEntry = ManagedEntry(value=value, created_at=created_at, expires_at=expires_at) self._put_managed_entry(collection=collection, key=key, managed_entry=managed_entry) @@ -293,9 +305,7 @@ def put_many( (created_at, ttl_seconds, expires_at) = prepare_entry_timestamps(ttl=ttl) - managed_entries: list[ManagedEntry] = [ - ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at) for value in values - ] + managed_entries: list[ManagedEntry] = [ManagedEntry(value=value, created_at=created_at, expires_at=expires_at) for value in values] self._put_managed_entries( collection=collection, keys=keys, managed_entries=managed_entries, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/multi_store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/multi_store.py index 2c9fafaa..5e579c27 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/multi_store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/multi_store.py @@ -1,13 +1,13 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'multi_store.py' # DO NOT CHANGE! Change the original file instead. 
-import time from collections.abc import Callable +from datetime import timezone from pathlib import Path from typing import overload -from key_value.shared.utils.compound import compound_key -from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.managed_entry import ManagedEntry, datetime +from key_value.shared.utils.serialization import BasicSerializationAdapter from typing_extensions import override from key_value.sync.code_gen.stores.base import BaseContextManagerStore, BaseStore @@ -103,6 +103,7 @@ def default_disk_cache_factory(collection: str) -> Cache: self._cache = {} self._stable_api = True + self._serialization_adapter = BasicSerializationAdapter() super().__init__(default_collection=default_collection) @@ -112,32 +113,27 @@ def _setup_collection(self, *, collection: str) -> None: @override def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None: - combo_key: str = compound_key(collection=collection, key=key) - expire_epoch: float - (managed_entry_str, expire_epoch) = self._cache[collection].get(key=combo_key, expire_time=True) # pyright: ignore[reportAny] + (managed_entry_str, expire_epoch) = self._cache[collection].get(key=key, expire_time=True) # pyright: ignore[reportAny] if not isinstance(managed_entry_str, str): return None - ttl = expire_epoch - time.time() if expire_epoch else None + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str) - managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str, ttl=ttl) + if expire_epoch: + managed_entry.expires_at = datetime.fromtimestamp(expire_epoch, tz=timezone.utc) return managed_entry @override def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: - combo_key: str = compound_key(collection=collection, key=key) - - _ = self._cache[collection].set(key=combo_key, value=managed_entry.to_json(include_expiration=False), expire=managed_entry.ttl) + _ = self._cache[collection].set(key=key, value=self._serialization_adapter.dump_json(entry=managed_entry), expire=managed_entry.ttl) @override def _delete_managed_entry(self, *, key: str, collection: str) -> bool: - combo_key: str = compound_key(collection=collection, key=key) - - return self._cache[collection].delete(key=combo_key, retry=True) + return self._cache[collection].delete(key=key, retry=True) def _sync_close(self) -> None: for cache in self._cache.values(): diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/store.py index d75b32c9..ceabfa5d 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/store.py @@ -1,7 +1,7 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'store.py' # DO NOT CHANGE! Change the original file instead. 
-import time +from datetime import datetime, timezone from pathlib import Path from typing import overload @@ -93,9 +93,10 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non if not isinstance(managed_entry_str, str): return None - ttl = expire_epoch - time.time() if expire_epoch else None + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str) - managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str, ttl=ttl) + if expire_epoch: + managed_entry.expires_at = datetime.fromtimestamp(expire_epoch, tz=timezone.utc) return managed_entry @@ -103,7 +104,7 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: combo_key: str = compound_key(collection=collection, key=key) - _ = self._cache.set(key=combo_key, value=managed_entry.to_json(include_expiration=False), expire=managed_entry.ttl) + _ = self._cache.set(key=combo_key, value=self._serialization_adapter.dump_json(entry=managed_entry), expire=managed_entry.ttl) @override def _delete_managed_entry(self, *, key: str, collection: str) -> bool: diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py index 8ab6f618..2bc36453 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py @@ -9,9 +9,10 @@ from elastic_transport import ObjectApiResponse from elastic_transport import SerializationError as ElasticsearchSerializationError from key_value.shared.errors import DeserializationError, SerializationError -from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict +from key_value.shared.utils.managed_entry import ManagedEntry from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, LOWERCASE_ALPHABET, NUMBERS, sanitize_string -from key_value.shared.utils.time_to_live import now_as_epoch, try_parse_datetime_str +from key_value.shared.utils.serialization import SerializationAdapter +from key_value.shared.utils.time_to_live import now_as_epoch from typing_extensions import override from key_value.sync.code_gen.stores.base import ( @@ -64,48 +65,50 @@ ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "." -def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry, *, native_storage: bool = False) -> dict[str, Any]: - document: dict[str, Any] = {"collection": collection, "key": key, "value": {}} +class ElasticsearchSerializationAdapter(SerializationAdapter): + """Adapter for Elasticsearch with support for native and string storage modes.""" - # Store in appropriate field based on mode - if native_storage: - document["value"]["flattened"] = managed_entry.value_as_dict - else: - document["value"]["string"] = managed_entry.value_as_json + _native_storage: bool - if managed_entry.created_at: - document["created_at"] = managed_entry.created_at.isoformat() - if managed_entry.expires_at: - document["expires_at"] = managed_entry.expires_at.isoformat() + def __init__(self, *, native_storage: bool = True) -> None: + """Initialize the Elasticsearch adapter. - return document + Args: + native_storage: If True (default), store values as flattened dicts. + If False, store values as JSON strings. 
+ """ + super().__init__() + self._native_storage = native_storage + self._date_format = "isoformat" + self._value_format = "dict" if native_storage else "string" -def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry: - value: dict[str, Any] = {} + @override + def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: + value = data.pop("value") - raw_value = source.get("value") + data["value"] = {} - # Try flattened field first, fall back to string field - if not raw_value or not isinstance(raw_value, dict): - msg = "Value field not found or invalid type" - raise DeserializationError(msg) + if self._native_storage: + data["value"]["flattened"] = value + else: + data["value"]["string"] = value - if value_flattened := raw_value.get("flattened"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - value = verify_dict(obj=value_flattened) - elif value_str := raw_value.get("string"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - if not isinstance(value_str, str): - msg = "Value in `value` field is not a string" - raise DeserializationError(msg) - value = load_from_json(value_str) - else: - msg = "Value field not found or invalid type" - raise DeserializationError(msg) + return data - created_at: datetime | None = try_parse_datetime_str(value=source.get("created_at")) - expires_at: datetime | None = try_parse_datetime_str(value=source.get("expires_at")) + @override + def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: + value = data.pop("value") - return ManagedEntry(value=value, created_at=created_at, expires_at=expires_at) + if "flattened" in value: + data["value"] = value["flattened"] + elif "string" in value: + data["value"] = value["string"] + else: + msg = "Value field not found in Elasticsearch document" + raise DeserializationError(message=msg) + + return data class ElasticsearchStore( @@ -121,6 +124,8 @@ class ElasticsearchStore( _native_storage: bool + _adapter: SerializationAdapter + @overload def __init__( self, *, elasticsearch_client: Elasticsearch, index_prefix: str, native_storage: bool = True, default_collection: str | None = None @@ -173,6 +178,7 @@ def __init__( self._index_prefix = index_prefix self._native_storage = native_storage self._is_serverless = False + self._adapter = ElasticsearchSerializationAdapter(native_storage=native_storage) super().__init__(default_collection=default_collection) @@ -220,7 +226,7 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non return None try: - return source_to_managed_entry(source=source) + return self._adapter.load_dict(data=source) except DeserializationError: return None @@ -253,7 +259,7 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ continue try: - entries_by_id[doc_id] = source_to_managed_entry(source=source) + entries_by_id[doc_id] = self._adapter.load_dict(data=source) except DeserializationError as e: logger.error( "Failed to deserialize Elasticsearch document in batch operation", @@ -274,9 +280,7 @@ def _put_managed_entry(self, *, key: str, collection: str, managed_entry: Manage index_name: str = self._sanitize_index_name(collection=collection) document_id: str = self._sanitize_document_id(key=key) - document: dict[str, Any] = managed_entry_to_document( - collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage - ) + document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry) try: _ = self._client.index(index=index_name, 
id=document_id, body=document, refresh=self._should_refresh_on_put) @@ -309,11 +313,10 @@ def _put_managed_entries( index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id) - document: dict[str, Any] = managed_entry_to_document( - collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage - ) + document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry) operations.extend([index_action, document]) + try: _ = self._client.bulk(operations=operations, refresh=self._should_refresh_on_put) # pyright: ignore[reportUnknownMemberType] except ElasticsearchSerializationError as e: diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/keyring/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/keyring/store.py index f3eb41c8..db868392 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/keyring/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/keyring/store.py @@ -69,7 +69,7 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non if json_str is None: return None - return ManagedEntry.from_json(json_str=json_str) + return self._serialization_adapter.load_json(json_str=json_str) @override def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: @@ -78,7 +78,7 @@ def _put_managed_entry(self, *, key: str, collection: str, managed_entry: Manage combo_key: str = compound_key(collection=sanitized_collection, key=sanitized_key) - json_str: str = managed_entry.to_json() + json_str: str = self._serialization_adapter.dump_json(entry=managed_entry) keyring.set_password(service_name=self._service_name, username=combo_key, password=json_str) diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/memory/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/memory/store.py index b206291a..89575d86 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/memory/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/memory/store.py @@ -7,7 +7,8 @@ from typing import Any from key_value.shared.utils.managed_entry import ManagedEntry -from typing_extensions import Self, override +from key_value.shared.utils.serialization import BasicSerializationAdapter +from typing_extensions import override from key_value.sync.code_gen.stores.base import ( SEED_DATA_TYPE, @@ -32,13 +33,6 @@ class MemoryCacheEntry: expires_at: datetime | None - @classmethod - def from_managed_entry(cls, managed_entry: ManagedEntry) -> Self: - return cls(json_str=managed_entry.to_json(), expires_at=managed_entry.expires_at) - - def to_managed_entry(self) -> ManagedEntry: - return ManagedEntry.from_json(json_str=self.json_str) - def _memory_cache_ttu(_key: Any, value: MemoryCacheEntry, _now: float) -> float: """Calculate time-to-use for cache entries based on their expiration time.""" @@ -53,8 +47,6 @@ def _memory_cache_getsizeof(value: MemoryCacheEntry) -> int: return 1 -DEFAULT_MAX_ENTRIES_PER_COLLECTION = 10000 - DEFAULT_PAGE_SIZE = 10000 PAGE_LIMIT = 10000 @@ -62,13 +54,17 @@ def _memory_cache_getsizeof(value: MemoryCacheEntry) -> int: class MemoryCollection: _cache: TLRUCache[str, MemoryCacheEntry] - def __init__(self, max_entries: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION): + def __init__(self, max_entries: int | None = None): """Initialize a fixed-size in-memory collection. 
Args: - max_entries: The maximum number of entries per collection. Defaults to 10,000 entries. + max_entries: The maximum number of entries per collection. Defaults to no limit. """ - self._cache = TLRUCache[str, MemoryCacheEntry](maxsize=max_entries, ttu=_memory_cache_ttu, getsizeof=_memory_cache_getsizeof) + self._cache = TLRUCache[str, MemoryCacheEntry]( + maxsize=max_entries if max_entries is not None else sys.maxsize, ttu=_memory_cache_ttu, getsizeof=_memory_cache_getsizeof + ) + + self._serialization_adapter = BasicSerializationAdapter() def get(self, key: str) -> ManagedEntry | None: managed_entry_str: MemoryCacheEntry | None = self._cache.get(key) @@ -76,12 +72,13 @@ def get(self, key: str) -> ManagedEntry | None: if managed_entry_str is None: return None - managed_entry: ManagedEntry = managed_entry_str.to_managed_entry() + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str.json_str) return managed_entry def put(self, key: str, value: ManagedEntry) -> None: - self._cache[key] = MemoryCacheEntry.from_managed_entry(managed_entry=value) + json_str: str = self._serialization_adapter.dump_json(entry=value) + self._cache[key] = MemoryCacheEntry(json_str=json_str, expires_at=value.expires_at) def delete(self, key: str) -> bool: return self._cache.pop(key, None) is not None @@ -99,23 +96,19 @@ class MemoryStore(BaseDestroyStore, BaseDestroyCollectionStore, BaseEnumerateCol _cache: dict[str, MemoryCollection] def __init__( - self, - *, - max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION, - default_collection: str | None = None, - seed: SEED_DATA_TYPE | None = None, + self, *, max_entries_per_collection: int | None = None, default_collection: str | None = None, seed: SEED_DATA_TYPE | None = None ): """Initialize a fixed-size in-memory store. Args: - max_entries_per_collection: The maximum number of entries per collection. Defaults to 10000. + max_entries_per_collection: The maximum number of entries per collection. Defaults to no limit. default_collection: The default collection to use if no collection is provided. seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}. Each value must be a mapping (dict) that will be stored as the entry's value. Seeding occurs lazily when each collection is first accessed. """ - self.max_entries_per_collection = max_entries_per_collection + self.max_entries_per_collection = max_entries_per_collection if max_entries_per_collection is not None else sys.maxsize self._cache = {} diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py index 753ffb46..8452c51f 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py @@ -2,13 +2,14 @@ # from the original file 'store.py' # DO NOT CHANGE! Change the original file instead. 
from collections.abc import Sequence -from datetime import datetime +from datetime import datetime, timezone from typing import Any, overload -from key_value.shared.errors import DeserializationError -from key_value.shared.utils.managed_entry import ManagedEntry, verify_dict +from bson.errors import InvalidDocument +from key_value.shared.errors import DeserializationError, SerializationError +from key_value.shared.utils.managed_entry import ManagedEntry from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, sanitize_string -from key_value.shared.utils.time_to_live import timezone +from key_value.shared.utils.serialization import SerializationAdapter from typing_extensions import Self, override from key_value.sync.code_gen.stores.base import ( @@ -19,7 +20,7 @@ ) try: - from pymongo import MongoClient + from pymongo import MongoClient, UpdateOne from pymongo.collection import Collection from pymongo.database import Database from pymongo.results import DeleteResult # noqa: TC002 @@ -42,95 +43,56 @@ COLLECTION_ALLOWED_CHARACTERS = ALPHANUMERIC_CHARACTERS + "_" -def document_to_managed_entry(document: dict[str, Any]) -> ManagedEntry: - """Convert a MongoDB document back to a ManagedEntry. +class MongoDBSerializationAdapter(SerializationAdapter): + """Adapter for MongoDB with support for native and string storage modes.""" - This function deserializes a MongoDB document (created by `managed_entry_to_document`) back to a - ManagedEntry object. It supports both native BSON storage (dict in value.object field) and legacy - JSON string storage (string in value.string field) for migration support. - - Args: - document: The MongoDB document to convert. - - Returns: - A ManagedEntry object reconstructed from the document. - """ - if not (value_field := document.get("value")): - msg = "Value field not found" - raise DeserializationError(msg) - - if not isinstance(value_field, dict): - msg = "Expected `value` field to be an object" - raise DeserializationError(msg) - - value_holder: dict[str, Any] = verify_dict(obj=value_field) - - data: dict[str, Any] = {} - - # The Value field is an object with two possible fields: `object` and `string` - # - `object`: The value is a native BSON dict - # - `string`: The value is a JSON string - # Mongo stores datetimes without timezones as UTC so we mark them as UTC - - if created_at_datetime := document.get("created_at"): - if not isinstance(created_at_datetime, datetime): - msg = "Expected `created_at` field to be a datetime" - raise DeserializationError(msg) - data["created_at"] = created_at_datetime.replace(tzinfo=timezone.utc) - - if expires_at_datetime := document.get("expires_at"): - if not isinstance(expires_at_datetime, datetime): - msg = "Expected `expires_at` field to be a datetime" - raise DeserializationError(msg) - data["expires_at"] = expires_at_datetime.replace(tzinfo=timezone.utc) - - if value_object := value_holder.get("object"): - return ManagedEntry.from_dict(data={"value": value_object, **data}) - - if value_string := value_holder.get("string"): - return ManagedEntry.from_dict(data={"value": value_string, **data}, stringified_value=True) + _native_storage: bool - msg = "Expected `value` field to be an object with `object` or `string` subfield" - raise DeserializationError(msg) + def __init__(self, *, native_storage: bool = True) -> None: + """Initialize the MongoDB adapter.""" + super().__init__() + self._native_storage = native_storage + self._date_format = "datetime" + self._value_format = "dict" if native_storage else "string" 
-def managed_entry_to_document(key: str, managed_entry: ManagedEntry, *, native_storage: bool = True) -> dict[str, Any]: - """Convert a ManagedEntry to a MongoDB document for storage. + @override + def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: + value = data.pop("value") - This function serializes a ManagedEntry to a MongoDB document format, including the key and all - metadata (TTL, creation, and expiration timestamps). The value storage format depends on the - native_storage parameter. + data["value"] = {} - Args: - key: The key associated with this entry. - managed_entry: The ManagedEntry to serialize. - native_storage: If True (default), store value as native BSON dict in value.object field. - If False, store as JSON string in value.string field for backward compatibility. + if self._native_storage: + data["value"]["object"] = value + else: + data["value"]["string"] = value - Returns: - A MongoDB document dict containing the key, value, and all metadata. - """ - document: dict[str, Any] = {"key": key, "value": {}} + return data - # We convert to JSON even if we don't need to, this ensures that the value we were provided - # can be serialized to JSON which helps ensure compatibility across stores. For example, - # Mongo can natively handle datetime objects which other stores cannot, if we don't convert to JSON, - # then using py-key-value with Mongo will return different values than if we used another store. - json_str = managed_entry.value_as_json + @override + def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: + value = data.pop("value") - # Store in appropriate field based on mode - if native_storage: - document["value"]["object"] = managed_entry.value_as_dict - else: - document["value"]["string"] = json_str + if "object" in value: + data["value"] = value["object"] + elif "string" in value: + data["value"] = value["string"] + else: + msg = "Value field not found in MongoDB document" + raise DeserializationError(message=msg) - # Add metadata fields - if managed_entry.created_at: - document["created_at"] = managed_entry.created_at - if managed_entry.expires_at: - document["expires_at"] = managed_entry.expires_at + if date_created := data.get("created_at"): + if not isinstance(date_created, datetime): + msg = "Expected `created_at` field to be a datetime" + raise DeserializationError(message=msg) + data["created_at"] = date_created.replace(tzinfo=timezone.utc) + if date_expires := data.get("expires_at"): + if not isinstance(date_expires, datetime): + msg = "Expected `expires_at` field to be a datetime" + raise DeserializationError(message=msg) + data["expires_at"] = date_expires.replace(tzinfo=timezone.utc) - return document + return data class MongoDBStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, BaseContextManagerStore, BaseStore): @@ -139,7 +101,7 @@ class MongoDBStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, Ba _client: MongoClient[dict[str, Any]] _db: Database[dict[str, Any]] _collections_by_name: dict[str, Collection[dict[str, Any]]] - _native_storage: bool + _adapter: SerializationAdapter @overload def __init__( @@ -217,7 +179,7 @@ def __init__( self._db = self._client[db_name] self._collections_by_name = {} - self._native_storage = native_storage + self._adapter = MongoDBSerializationAdapter(native_storage=native_storage) super().__init__(default_collection=default_collection) @@ -274,7 +236,10 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non sanitized_collection = 
self._sanitize_collection_name(collection=collection) if doc := self._collections_by_name[sanitized_collection].find_one(filter={"key": key}): - return document_to_managed_entry(document=doc) + try: + return self._adapter.load_dict(data=doc) + except DeserializationError: + return None return None @@ -292,17 +257,28 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ for doc in cursor: if key := doc.get("key"): - managed_entries_by_key[key] = document_to_managed_entry(document=doc) + try: + managed_entries_by_key[key] = self._adapter.load_dict(data=doc) + except DeserializationError: + managed_entries_by_key[key] = None return [managed_entries_by_key[key] for key in keys] @override def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: - mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry, native_storage=self._native_storage) + mongo_doc = self._adapter.dump_dict(entry=managed_entry) sanitized_collection = self._sanitize_collection_name(collection=collection) - _ = self._collections_by_name[sanitized_collection].update_one(filter={"key": key}, update={"$set": mongo_doc}, upsert=True) + try: + # Ensure that the value is serializable to JSON + _ = managed_entry.value_as_json + _ = self._collections_by_name[sanitized_collection].update_one( + filter={"key": key}, update={"$set": {"key": key, **mongo_doc}}, upsert=True + ) + except InvalidDocument as e: + msg = f"Failed to update MongoDB document: {e}" + raise SerializationError(message=msg) from e @override def _put_managed_entries( @@ -320,14 +296,13 @@ def _put_managed_entries( sanitized_collection = self._sanitize_collection_name(collection=collection) - # Use bulk_write for efficient batch operations - from pymongo import UpdateOne - operations: list[UpdateOne] = [] for key, managed_entry in zip(keys, managed_entries, strict=True): - mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry, native_storage=self._native_storage) + mongo_doc = self._adapter.dump_dict(entry=managed_entry) - operations.append(UpdateOne(filter={"key": key}, update={"$set": mongo_doc}, upsert=True)) + operations.append( + UpdateOne(filter={"key": key}, update={"$set": {"collection": collection, "key": key, **mongo_doc}}, upsert=True) + ) _ = self._collections_by_name[sanitized_collection].bulk_write(operations) # pyright: ignore[reportUnknownMemberType] diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py index d192cc0a..b20b7154 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py @@ -6,9 +6,11 @@ from typing import Any, overload from urllib.parse import urlparse +from key_value.shared.errors import DeserializationError from key_value.shared.type_checking.bear_spray import bear_spray from key_value.shared.utils.compound import compound_key, get_keys_from_compound_keys from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.serialization import BasicSerializationAdapter, SerializationAdapter from typing_extensions import override from key_value.sync.code_gen.stores.base import BaseContextManagerStore, BaseDestroyStore, BaseEnumerateKeysStore, BaseStore @@ -23,41 +25,11 @@ PAGE_LIMIT = 10000 -def managed_entry_to_json(managed_entry: ManagedEntry) -> str: - 
"""Convert a ManagedEntry to a JSON string for Redis storage. - - This function serializes a ManagedEntry to JSON format including all metadata (TTL, creation, - and expiration timestamps). The serialization is designed to preserve all entry information - for round-trip conversion back to a ManagedEntry. - - Args: - managed_entry: The ManagedEntry to serialize. - - Returns: - A JSON string representation of the ManagedEntry with full metadata. - """ - return managed_entry.to_json(include_metadata=True, include_expiration=True, include_creation=True) - - -def json_to_managed_entry(json_str: str) -> ManagedEntry: - """Convert a JSON string from Redis storage back to a ManagedEntry. - - This function deserializes a JSON string (created by `managed_entry_to_json`) back to a - ManagedEntry object, preserving all metadata including TTL, creation, and expiration timestamps. - - Args: - json_str: The JSON string to deserialize. - - Returns: - A ManagedEntry object reconstructed from the JSON string. - """ - return ManagedEntry.from_json(json_str=json_str, includes_metadata=True) - - class RedisStore(BaseDestroyStore, BaseEnumerateKeysStore, BaseContextManagerStore, BaseStore): """Redis-based key-value store.""" _client: Redis + _adapter: SerializationAdapter @overload def __init__(self, *, client: Redis, default_collection: str | None = None) -> None: ... @@ -108,6 +80,7 @@ def __init__( self._client = Redis(host=host, port=port, db=db, password=password, decode_responses=True) self._stable_api = True + self._adapter = BasicSerializationAdapter(date_format="isoformat", value_format="dict") super().__init__(default_collection=default_collection) @@ -120,9 +93,10 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non if not isinstance(redis_response, str): return None - managed_entry: ManagedEntry = json_to_managed_entry(json_str=redis_response) - - return managed_entry + try: + return self._adapter.load_json(json_str=redis_response) + except DeserializationError: + return None @override def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]: @@ -136,7 +110,10 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ entries: list[ManagedEntry | None] = [] for redis_response in redis_responses: if isinstance(redis_response, str): - entries.append(json_to_managed_entry(json_str=redis_response)) + try: + entries.append(self._adapter.load_json(json_str=redis_response)) + except DeserializationError: + entries.append(None) else: entries.append(None) @@ -146,7 +123,7 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry_to_json(managed_entry=managed_entry) + json_value: str = self._adapter.dump_json(entry=managed_entry) if managed_entry.ttl is not None: # Redis does not support <= 0 TTLs @@ -172,10 +149,10 @@ def _put_managed_entries( if ttl is None: # If there is no TTL, we can just do a simple mset - mapping: dict[str, str] = { - compound_key(collection=collection, key=key): managed_entry_to_json(managed_entry=managed_entry) - for (key, managed_entry) in zip(keys, managed_entries, strict=True) - } + mapping: dict[str, str] = {} + for key, managed_entry in zip(keys, managed_entries, strict=True): + json_value = self._adapter.dump_json(entry=managed_entry) + 
mapping[compound_key(collection=collection, key=key)] = json_value self._client.mset(mapping=mapping) @@ -189,7 +166,7 @@ def _put_managed_entries( for key, managed_entry in zip(keys, managed_entries, strict=True): combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry_to_json(managed_entry=managed_entry) + json_value = self._adapter.dump_json(entry=managed_entry) pipeline.setex(name=combo_key, time=ttl_seconds, value=json_value) diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/rocksdb/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/rocksdb/store.py index 4356ac3d..72d510b7 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/rocksdb/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/rocksdb/store.py @@ -106,7 +106,7 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non return None managed_entry_str: str = value.decode("utf-8") - managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str) + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str) return managed_entry @@ -115,7 +115,7 @@ def _put_managed_entry(self, *, key: str, collection: str, managed_entry: Manage self._fail_on_closed_store() combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry.to_json() + json_value: str = self._serialization_adapter.dump_json(entry=managed_entry) self._db[combo_key] = json_value.encode("utf-8") @@ -138,7 +138,7 @@ def _put_managed_entries( batch = WriteBatch() for key, managed_entry in zip(keys, managed_entries, strict=True): combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry.to_json() + json_value: str = self._serialization_adapter.dump_json(entry=managed_entry) batch.put(combo_key, json_value.encode("utf-8")) self._db.write(batch) diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/simple/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/simple/store.py index d24d1d19..8abac4f8 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/simple/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/simple/store.py @@ -1,13 +1,14 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'store.py' # DO NOT CHANGE! Change the original file instead. 
+import sys from collections import defaultdict from dataclasses import dataclass from datetime import datetime from key_value.shared.utils.compound import compound_key, get_collections_from_compound_keys, get_keys_from_compound_keys -from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json -from key_value.shared.utils.time_to_live import seconds_to +from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.serialization import BasicSerializationAdapter from typing_extensions import override from key_value.sync.code_gen.stores.base import BaseDestroyStore, BaseEnumerateCollectionsStore, BaseEnumerateKeysStore, BaseStore @@ -22,20 +23,6 @@ class SimpleStoreEntry: created_at: datetime | None expires_at: datetime | None - @property - def current_ttl(self) -> float | None: - if self.expires_at is None: - return None - - return seconds_to(datetime=self.expires_at) - - def to_managed_entry(self) -> ManagedEntry: - managed_entry: ManagedEntry = ManagedEntry( - value=load_from_json(json_str=self.json_str), expires_at=self.expires_at, created_at=self.created_at - ) - - return managed_entry - class SimpleStore(BaseEnumerateCollectionsStore, BaseEnumerateKeysStore, BaseDestroyStore, BaseStore): """Simple managed dictionary-based key-value store for testing and development.""" @@ -44,18 +31,20 @@ class SimpleStore(BaseEnumerateCollectionsStore, BaseEnumerateKeysStore, BaseDes _data: dict[str, SimpleStoreEntry] - def __init__(self, max_entries: int = DEFAULT_SIMPLE_STORE_MAX_ENTRIES, default_collection: str | None = None): + def __init__(self, max_entries: int | None = None, default_collection: str | None = None): """Initialize the simple store. Args: - max_entries: The maximum number of entries to store. Defaults to 10000. + max_entries: The maximum number of entries to store. Defaults to no limit. default_collection: The default collection to use if no collection is provided. 
""" - self.max_entries = max_entries + self.max_entries = max_entries if max_entries is not None else sys.maxsize self._data = defaultdict[str, SimpleStoreEntry]() + self._serialization_adapter = BasicSerializationAdapter(date_format=None) + super().__init__(default_collection=default_collection) @override @@ -67,7 +56,10 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non if store_entry is None: return None - return store_entry.to_managed_entry() + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=store_entry.json_str) + managed_entry.expires_at = store_entry.expires_at + managed_entry.created_at = store_entry.created_at + return managed_entry @override def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: @@ -77,7 +69,9 @@ def _put_managed_entry(self, *, key: str, collection: str, managed_entry: Manage _ = self._data.pop(next(iter(self._data))) self._data[combo_key] = SimpleStoreEntry( - json_str=managed_entry.to_json(include_metadata=False), expires_at=managed_entry.expires_at, created_at=managed_entry.created_at + json_str=self._serialization_adapter.dump_json(entry=managed_entry), + expires_at=managed_entry.expires_at, + created_at=managed_entry.created_at, ) @override diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/valkey/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/valkey/store.py index 6ec5c0a8..8c8c0a8e 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/valkey/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/valkey/store.py @@ -97,7 +97,7 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non decoded_response: str = response.decode("utf-8") - return ManagedEntry.from_json(json_str=decoded_response) + return self._serialization_adapter.load_json(json_str=decoded_response) @override def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]: @@ -112,7 +112,7 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ for response in responses: if isinstance(response, bytes): decoded_response: str = response.decode("utf-8") - entries.append(ManagedEntry.from_json(json_str=decoded_response)) + entries.append(self._serialization_adapter.load_json(json_str=decoded_response)) else: entries.append(None) @@ -122,7 +122,7 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry.to_json() + json_value: str = self._serialization_adapter.dump_json(entry=managed_entry) expiry: ExpirySet | None = ExpirySet(expiry_type=ExpiryType.SEC, value=int(managed_entry.ttl)) if managed_entry.ttl else None diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/vault/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/vault/store.py index 6cea8ada..abffa9f2 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/vault/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/vault/store.py @@ -83,7 +83,7 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non combo_key: str = compound_key(collection=collection, key=key) try: - response = self._kv_v2.read_secret(path=combo_key, 
mount_point=self._mount_point) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] + response = self._kv_v2.read_secret(path=combo_key, mount_point=self._mount_point, raise_on_deleted_version=True) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] except InvalidPath: return None except Exception: @@ -99,13 +99,13 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non return None json_str: str = secret_data["value"] # pyright: ignore[reportUnknownVariableType] - return ManagedEntry.from_json(json_str=json_str) # pyright: ignore[reportUnknownArgumentType] + return self._serialization_adapter.load_json(json_str=json_str) # pyright: ignore[reportUnknownArgumentType] @override def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: combo_key: str = compound_key(collection=collection, key=key) - json_str: str = managed_entry.to_json() + json_str: str = self._serialization_adapter.dump_json(entry=managed_entry) # Store the JSON string in a 'value' field secret_data = {"value": json_str} diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/windows_registry/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/windows_registry/store.py index 30bec383..20024851 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/windows_registry/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/windows_registry/store.py @@ -89,14 +89,14 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non if not (json_str := get_reg_sz_value(hive=self._hive, sub_key=registry_path, value_name=sanitized_key)): return None - return ManagedEntry.from_json(json_str=json_str) + return self._serialization_adapter.load_json(json_str=json_str) @override def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: sanitized_key = self._sanitize_key(key=key) registry_path = self._get_registry_path(collection=collection) - json_str: str = managed_entry.to_json() + json_str: str = self._serialization_adapter.dump_json(entry=managed_entry) set_reg_sz_value(hive=self._hive, sub_key=registry_path, value_name=sanitized_key, value=json_str) diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/compression/wrapper.py b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/compression/wrapper.py index b18a66db..1bfc3f7f 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/compression/wrapper.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/compression/wrapper.py @@ -7,7 +7,7 @@ from collections.abc import Mapping, Sequence from typing import Any, SupportsFloat -from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.managed_entry import estimate_serialized_size from typing_extensions import override from key_value.sync.code_gen.protocols.key_value import KeyValue @@ -55,7 +55,7 @@ def _should_compress(self, value: dict[str, Any]) -> bool: return False # Check size - item_size: int = len(ManagedEntry(value=value).to_json()) + item_size: int = estimate_serialized_size(value=value) return item_size >= self.min_size_to_compress def _compress_value(self, value: dict[str, Any]) -> dict[str, Any]: diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/limit_size/wrapper.py b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/limit_size/wrapper.py index 5c9ec62f..473219a3 
100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/limit_size/wrapper.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/limit_size/wrapper.py @@ -5,7 +5,7 @@ from typing import Any, SupportsFloat from key_value.shared.errors.wrappers.limit_size import EntryTooLargeError, EntryTooSmallError -from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.managed_entry import estimate_serialized_size from typing_extensions import override from key_value.sync.code_gen.protocols.key_value import KeyValue @@ -68,7 +68,7 @@ def _within_size_limit(self, value: dict[str, Any], *, collection: str | None = EntryTooLargeError: If raise_on_too_large is True and the value exceeds max_size. """ - item_size: int = len(ManagedEntry(value=value).to_json()) + item_size: int = estimate_serialized_size(value=value) if self.min_size is not None and item_size < self.min_size: if self.raise_on_too_small: diff --git a/key-value/key-value-sync/tests/code_gen/stores/base.py b/key-value/key-value-sync/tests/code_gen/stores/base.py index 8f191c55..b78e279d 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/base.py +++ b/key-value/key-value-sync/tests/code_gen/stores/base.py @@ -171,7 +171,7 @@ def test_put_ttl_get_ttl(self, store: BaseStore): assert value == {"test": "test"} assert ttl is not None - assert ttl == IsFloat(approx=100) + assert ttl == IsFloat(approx=100, delta=2), f"TTL should be ~100, but is {ttl}" def test_negative_ttl(self, store: BaseStore): """Tests that a negative ttl will return None when getting the key.""" diff --git a/key-value/key-value-sync/tests/code_gen/stores/disk/test_disk.py b/key-value/key-value-sync/tests/code_gen/stores/disk/test_disk.py index 994738a2..b904f35f 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/disk/test_disk.py +++ b/key-value/key-value-sync/tests/code_gen/stores/disk/test_disk.py @@ -1,10 +1,14 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'test_disk.py' # DO NOT CHANGE! Change the original file instead. 
+import json import tempfile from collections.abc import Generator import pytest +from dirty_equals import IsDatetime +from diskcache.core import Cache +from inline_snapshot import snapshot from typing_extensions import override from key_value.sync.code_gen.stores.disk import DiskStore @@ -25,3 +29,22 @@ def store(self, disk_store: DiskStore) -> DiskStore: disk_store._cache.clear() # pyright: ignore[reportPrivateUsage] return disk_store + + @pytest.fixture + def disk_cache(self, disk_store: DiskStore) -> Cache: + return disk_store._cache # pyright: ignore[reportPrivateUsage] + + def test_value_stored(self, store: DiskStore, disk_cache: Cache): + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + value = disk_cache.get(key="test::test_key") + value_as_dict = json.loads(value) + assert value_as_dict == snapshot({"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}}) + + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = disk_cache.get(key="test::test_key") + value_as_dict = json.loads(value) + assert value_as_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-sync/tests/code_gen/stores/disk/test_multi_disk.py b/key-value/key-value-sync/tests/code_gen/stores/disk/test_multi_disk.py index e6341075..ed68c109 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/disk/test_multi_disk.py +++ b/key-value/key-value-sync/tests/code_gen/stores/disk/test_multi_disk.py @@ -1,16 +1,23 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'test_multi_disk.py' # DO NOT CHANGE! Change the original file instead. 
+import json import tempfile from collections.abc import Generator from pathlib import Path +from typing import TYPE_CHECKING import pytest +from dirty_equals import IsDatetime +from inline_snapshot import snapshot from typing_extensions import override from key_value.sync.code_gen.stores.disk.multi_store import MultiDiskStore from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin +if TYPE_CHECKING: + from diskcache.core import Cache + TEST_SIZE_LIMIT = 100 * 1024 # 100KB @@ -27,3 +34,19 @@ def store(self, multi_disk_store: MultiDiskStore) -> MultiDiskStore: multi_disk_store._cache[collection].clear() # pyright: ignore[reportPrivateUsage] return multi_disk_store + + def test_value_stored(self, store: MultiDiskStore): + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + disk_cache: Cache = store._cache["test"] # pyright: ignore[reportPrivateUsage] + + value = disk_cache.get(key="test_key") + value_as_dict = json.loads(value) + assert value_as_dict == snapshot({"value": {"name": "Alice", "age": 30}, "created_at": IsDatetime(iso_string=True)}) + + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = disk_cache.get(key="test_key") + value_as_dict = json.loads(value) + assert value_as_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py b/key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py index 69851a03..499b194a 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py +++ b/key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py @@ -14,7 +14,7 @@ from key_value.sync.code_gen.stores.base import BaseStore from key_value.sync.code_gen.stores.elasticsearch import ElasticsearchStore -from key_value.sync.code_gen.stores.elasticsearch.store import managed_entry_to_document, source_to_managed_entry +from key_value.sync.code_gen.stores.elasticsearch.store import ElasticsearchSerializationAdapter from tests.code_gen.conftest import docker_container, should_skip_docker_tests from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin @@ -56,19 +56,14 @@ def test_managed_entry_document_conversion(): expires_at = created_at + timedelta(seconds=10) managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(collection="test_collection", key="test_key", managed_entry=managed_entry) + adapter = ElasticsearchSerializationAdapter(native_storage=False) + document = adapter.dump_dict(entry=managed_entry) assert document == snapshot( - { - "collection": "test_collection", - "key": "test_key", - "value": {"string": '{"test": "test"}'}, - "created_at": "2025-01-01T00:00:00+00:00", - "expires_at": "2025-01-01T00:00:10+00:00", - } + {"value": {"string": '{"test": "test"}'}, "created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00"} ) - round_trip_managed_entry = source_to_managed_entry(source=document) + round_trip_managed_entry = adapter.load_dict(data=document) assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at @@ -81,19 +76,14 @@ def test_managed_entry_document_conversion_native_storage(): expires_at = created_at + 
timedelta(seconds=10) managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(collection="test_collection", key="test_key", managed_entry=managed_entry, native_storage=True) + adapter = ElasticsearchSerializationAdapter(native_storage=True) + document = adapter.dump_dict(entry=managed_entry) assert document == snapshot( - { - "collection": "test_collection", - "key": "test_key", - "value": {"flattened": {"test": "test"}}, - "created_at": "2025-01-01T00:00:00+00:00", - "expires_at": "2025-01-01T00:00:10+00:00", - } + {"value": {"flattened": {"test": "test"}}, "created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00"} ) - round_trip_managed_entry = source_to_managed_entry(source=document) + round_trip_managed_entry = adapter.load_dict(data=document) assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at @@ -170,12 +160,7 @@ def test_value_stored_as_flattened_object(self, store: ElasticsearchStore, es_cl response = es_client.get(index=index_name, id=doc_id) assert response.body["_source"] == snapshot( - { - "collection": "test", - "key": "test_key", - "value": {"flattened": {"name": "Alice", "age": 30}}, - "created_at": IsStr(min_length=20, max_length=40), - } + {"value": {"flattened": {"name": "Alice", "age": 30}}, "created_at": IsStr(min_length=20, max_length=40)} ) # Test with TTL @@ -183,8 +168,6 @@ def test_value_stored_as_flattened_object(self, store: ElasticsearchStore, es_cl response = es_client.get(index=index_name, id=doc_id) assert response.body["_source"] == snapshot( { - "collection": "test", - "key": "test_key", "value": {"flattened": {"name": "Bob", "age": 25}}, "created_at": IsStr(min_length=20, max_length=40), "expires_at": IsStr(min_length=20, max_length=40), @@ -223,12 +206,7 @@ def test_value_stored_as_json_string(self, store: ElasticsearchStore, es_client: response = es_client.get(index=index_name, id=doc_id) assert response.body["_source"] == snapshot( - { - "collection": "test", - "key": "test_key", - "value": {"string": '{"age": 30, "name": "Alice"}'}, - "created_at": IsStr(min_length=20, max_length=40), - } + {"value": {"string": '{"age": 30, "name": "Alice"}'}, "created_at": IsStr(min_length=20, max_length=40)} ) # Test with TTL @@ -236,8 +214,6 @@ def test_value_stored_as_json_string(self, store: ElasticsearchStore, es_client: response = es_client.get(index=index_name, id=doc_id) assert response.body["_source"] == snapshot( { - "collection": "test", - "key": "test_key", "value": {"string": '{"age": 25, "name": "Bob"}'}, "created_at": IsStr(min_length=20, max_length=40), "expires_at": IsStr(min_length=20, max_length=40), diff --git a/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py b/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py index a9a3a9a3..969012cf 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py +++ b/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py @@ -17,7 +17,7 @@ from key_value.sync.code_gen.stores.base import BaseStore from key_value.sync.code_gen.stores.mongodb import MongoDBStore -from key_value.sync.code_gen.stores.mongodb.store import document_to_managed_entry, managed_entry_to_document +from key_value.sync.code_gen.stores.mongodb.store import MongoDBSerializationAdapter from tests.code_gen.conftest import docker_container, should_skip_docker_tests from tests.code_gen.stores.base 
import BaseStoreTests, ContextManagerStoreTestMixin @@ -51,18 +51,19 @@ def test_managed_entry_document_conversion_native_mode(): expires_at = created_at + timedelta(seconds=10) managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(key="test", managed_entry=managed_entry, native_storage=True) + + adapter = MongoDBSerializationAdapter(native_storage=True) + document = adapter.dump_dict(entry=managed_entry) assert document == snapshot( { - "key": "test", "value": {"object": {"test": "test"}}, "created_at": datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), "expires_at": datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc), } ) - round_trip_managed_entry = document_to_managed_entry(document=document) + round_trip_managed_entry = adapter.load_dict(data=document) assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at @@ -75,18 +76,18 @@ def test_managed_entry_document_conversion_legacy_mode(): expires_at = created_at + timedelta(seconds=10) managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(key="test", managed_entry=managed_entry, native_storage=False) + adapter = MongoDBSerializationAdapter(native_storage=False) + document = adapter.dump_dict(entry=managed_entry) assert document == snapshot( { - "key": "test", "value": {"string": '{"test": "test"}'}, "created_at": datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), "expires_at": datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc), } ) - round_trip_managed_entry = document_to_managed_entry(document=document) + round_trip_managed_entry = adapter.load_dict(data=document) assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at @@ -99,6 +100,7 @@ def clean_mongodb_database(store: MongoDBStore) -> None: _ = store._client.drop_database(name_or_database=store._db.name) # pyright: ignore[reportPrivateUsage] +@pytest.mark.filterwarnings("ignore:A configured store is unstable and may change in a backwards incompatible way. Use at your own risk.") class BaseMongoDBStoreTests(ContextManagerStoreTestMixin, BaseStoreTests): """Base class for MongoDB store tests.""" diff --git a/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py b/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py index 4620fba8..dd5841e8 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py +++ b/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py @@ -1,20 +1,19 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'test_redis.py' # DO NOT CHANGE! Change the original file instead. 
+import json from collections.abc import Generator -from datetime import datetime, timedelta, timezone +from typing import Any import pytest -from dirty_equals import IsFloat +from dirty_equals import IsDatetime from inline_snapshot import snapshot from key_value.shared.stores.wait import wait_for_true -from key_value.shared.utils.managed_entry import ManagedEntry from redis.client import Redis from typing_extensions import override from key_value.sync.code_gen.stores.base import BaseStore from key_value.sync.code_gen.stores.redis import RedisStore -from key_value.sync.code_gen.stores.redis.store import json_to_managed_entry, managed_entry_to_json from tests.code_gen.conftest import docker_container, should_skip_docker_tests from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin @@ -28,25 +27,6 @@ REDIS_VERSIONS_TO_TEST = ["4.0.0", "7.0.0"] -def test_managed_entry_document_conversion(): - created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) - expires_at = created_at + timedelta(seconds=10) - - managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_json(managed_entry=managed_entry) - - assert document == snapshot( - '{"created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00", "value": {"test": "test"}}' - ) - - round_trip_managed_entry = json_to_managed_entry(json_str=document) - - assert round_trip_managed_entry.value == managed_entry.value - assert round_trip_managed_entry.created_at == created_at - assert round_trip_managed_entry.ttl == IsFloat(lt=0) - assert round_trip_managed_entry.expires_at == expires_at - - def ping_redis() -> bool: client: Redis = Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) try: @@ -59,6 +39,10 @@ class RedisFailedToStartError(Exception): pass +def get_client_from_store(store: RedisStore) -> Redis: + return store._client # pyright: ignore[reportPrivateUsage] + + @pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running") class TestRedisStore(ContextManagerStoreTestMixin, BaseStoreTests): @pytest.fixture(autouse=True, scope="session", params=REDIS_VERSIONS_TO_TEST) @@ -78,14 +62,18 @@ def store(self, setup_redis: RedisStore) -> RedisStore: """Create a Redis store for testing.""" # Create the store with test database redis_store = RedisStore(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB) - _ = redis_store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] + _ = get_client_from_store(store=redis_store).flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] return redis_store + @pytest.fixture + def redis_client(self, store: RedisStore) -> Redis: + return get_client_from_store(store=store) + def test_redis_url_connection(self): """Test Redis store creation with URL.""" redis_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}" store = RedisStore(url=redis_url) - _ = store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] + _ = get_client_from_store(store=store).flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] store.put(collection="test", key="url_test", value={"test": "value"}) result = store.get(collection="test", key="url_test") assert result == {"test": "value"} @@ -97,11 +85,47 @@ def test_redis_client_connection(self): client = Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, 
decode_responses=True) store = RedisStore(client=client) - _ = store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] + _ = get_client_from_store(store=store).flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] store.put(collection="test", key="client_test", value={"test": "value"}) result = store.get(collection="test", key="client_test") assert result == {"test": "value"} + def test_redis_document_format(self, store: RedisStore, redis_client: Redis): + """Test Redis store document format.""" + store.put(collection="test", key="document_format_test_1", value={"test_1": "value_1"}) + store.put(collection="test", key="document_format_test_2", value={"test_2": "value_2"}, ttl=10) + + raw_documents: Any = redis_client.mget(keys=["test::document_format_test_1", "test::document_format_test_2"]) + raw_documents_dicts: list[dict[str, Any]] = [json.loads(raw_document) for raw_document in raw_documents] + assert raw_documents_dicts == snapshot( + [ + {"created_at": IsDatetime(iso_string=True), "value": {"test_1": "value_1"}}, + {"created_at": IsDatetime(iso_string=True), "expires_at": IsDatetime(iso_string=True), "value": {"test_2": "value_2"}}, + ] + ) + + store.put_many( + collection="test", + keys=["document_format_test_3", "document_format_test_4"], + values=[{"test_3": "value_3"}, {"test_4": "value_4"}], + ttl=10, + ) + raw_documents = redis_client.mget(keys=["test::document_format_test_3", "test::document_format_test_4"]) + raw_documents_dicts = [json.loads(raw_document) for raw_document in raw_documents] + assert raw_documents_dicts == snapshot( + [ + {"created_at": IsDatetime(iso_string=True), "expires_at": IsDatetime(iso_string=True), "value": {"test_3": "value_3"}}, + {"created_at": IsDatetime(iso_string=True), "expires_at": IsDatetime(iso_string=True), "value": {"test_4": "value_4"}}, + ] + ) + + store.put(collection="test", key="document_format_test", value={"test": "value"}, ttl=10) + raw_document: Any = redis_client.get(name="test::document_format_test") + raw_document_dict = json.loads(raw_document) + assert raw_document_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "expires_at": IsDatetime(iso_string=True), "value": {"test": "value"}} + ) + @pytest.mark.skip(reason="Distributed Caches are unbounded") @override def test_not_unbounded(self, store: BaseStore): ... diff --git a/key-value/key-value-sync/tests/code_gen/stores/rocksdb/test_rocksdb.py b/key-value/key-value-sync/tests/code_gen/stores/rocksdb/test_rocksdb.py index 2e04cd08..58e3fe48 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/rocksdb/test_rocksdb.py +++ b/key-value/key-value-sync/tests/code_gen/stores/rocksdb/test_rocksdb.py @@ -1,11 +1,15 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'test_rocksdb.py' # DO NOT CHANGE! Change the original file instead. +import json from collections.abc import Generator from pathlib import Path from tempfile import TemporaryDirectory import pytest +from dirty_equals import IsDatetime +from inline_snapshot import snapshot +from rocksdict import Rdict from typing_extensions import override from key_value.sync.code_gen.stores.base import BaseStore @@ -13,6 +17,7 @@ from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin +@pytest.mark.filterwarnings("ignore:A configured store is unstable and may change in a backwards incompatible way. 
Use at your own risk.") class TestRocksDBStore(ContextManagerStoreTestMixin, BaseStoreTests): @override @pytest.fixture @@ -62,3 +67,24 @@ def test_rocksdb_db_connection(self): @pytest.mark.skip(reason="Local disk stores are unbounded") @override def test_not_unbounded(self, store: BaseStore): ... + + @pytest.fixture + def rocksdb_client(self, store: RocksDBStore) -> Rdict: + return store._db # pyright: ignore[reportPrivateUsage] + + def test_value_stored(self, store: RocksDBStore, rocksdb_client: Rdict): + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + value = rocksdb_client.get(key="test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot({"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}}) + + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = rocksdb_client.get(key="test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-sync/tests/code_gen/stores/simple/test_store.py b/key-value/key-value-sync/tests/code_gen/stores/simple/test_store.py index 1ee92614..85600ef8 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/simple/test_store.py +++ b/key-value/key-value-sync/tests/code_gen/stores/simple/test_store.py @@ -8,6 +8,7 @@ from tests.code_gen.stores.base import BaseStoreTests +@pytest.mark.filterwarnings("ignore:A configured store is unstable and may change in a backwards incompatible way. Use at your own risk.") class TestSimpleStore(BaseStoreTests): @override @pytest.fixture diff --git a/key-value/key-value-sync/tests/code_gen/stores/valkey/test_valkey.py b/key-value/key-value-sync/tests/code_gen/stores/valkey/test_valkey.py index 69349194..0e9c2cbb 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/valkey/test_valkey.py +++ b/key-value/key-value-sync/tests/code_gen/stores/valkey/test_valkey.py @@ -2,9 +2,12 @@ # from the original file 'test_valkey.py' # DO NOT CHANGE! Change the original file instead. import contextlib +import json from collections.abc import Generator import pytest +from dirty_equals import IsDatetime +from inline_snapshot import snapshot from key_value.shared.stores.wait import wait_for_true from typing_extensions import override @@ -82,3 +85,26 @@ def store(self, setup_valkey: None): @pytest.mark.skip(reason="Distributed Caches are unbounded") @override def test_not_unbounded(self, store: BaseStore): ... 
+ + def test_value_stored(self, store: BaseStore): + from key_value.sync.code_gen.stores.valkey import ValkeyStore + + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + assert isinstance(store, ValkeyStore) + + valkey_client = store._connected_client # pyright: ignore[reportPrivateUsage] + assert valkey_client is not None + value = valkey_client.get(key="test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot({"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}}) + + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = valkey_client.get(key="test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-sync/tests/code_gen/stores/vault/test_vault.py b/key-value/key-value-sync/tests/code_gen/stores/vault/test_vault.py index c7a5fd74..6eddf1f1 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/vault/test_vault.py +++ b/key-value/key-value-sync/tests/code_gen/stores/vault/test_vault.py @@ -29,6 +29,7 @@ class VaultFailedToStartError(Exception): @pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running") +@pytest.mark.filterwarnings("ignore:A configured store is unstable and may change in a backwards incompatible way. Use at your own risk.") class TestVaultStore(BaseStoreTests): def get_vault_client(self): import hvac diff --git a/key-value/key-value-sync/tests/code_gen/stores/windows_registry/test_windows_registry.py b/key-value/key-value-sync/tests/code_gen/stores/windows_registry/test_windows_registry.py index 65105b16..9d1b7be4 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/windows_registry/test_windows_registry.py +++ b/key-value/key-value-sync/tests/code_gen/stores/windows_registry/test_windows_registry.py @@ -17,6 +17,7 @@ @pytest.mark.skipif(condition=not detect_on_windows(), reason="WindowsRegistryStore is only available on Windows") +@pytest.mark.filterwarnings("ignore:A configured store is unstable and may change in a backwards incompatible way. Use at your own risk.") class TestWindowsRegistryStore(BaseStoreTests): def cleanup(self): from winreg import HKEY_CURRENT_USER diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_limit_size.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_limit_size.py index e404249a..74bbba54 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_limit_size.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_limit_size.py @@ -3,6 +3,7 @@ # DO NOT CHANGE! Change the original file instead. 
import pytest from key_value.shared.errors.wrappers.limit_size import EntryTooLargeError, EntryTooSmallError +from key_value.shared.utils.managed_entry import estimate_serialized_size from typing_extensions import override from key_value.sync.code_gen.stores.memory.store import MemoryStore @@ -140,12 +141,8 @@ def test_put_many_all_too_large_without_raise(self, memory_store: MemoryStore): def test_exact_size_limit(self, memory_store: MemoryStore): # First, determine the exact size of a small value - from key_value.shared.utils.managed_entry import ManagedEntry - test_value = {"test": "value"} - managed_entry = ManagedEntry(value=test_value) - json_str = managed_entry.to_json() - exact_size = len(json_str.encode("utf-8")) + exact_size = estimate_serialized_size(value=test_value) # Create a store with exact size limit limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=exact_size, raise_on_too_large=True)
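Taken together, the hunks above replace the per-store to_json/from_json helpers with a shared SerializationAdapter. A minimal round-trip sketch, assuming the BasicSerializationAdapter and ManagedEntry APIs exactly as the stores and tests in this diff exercise them (dump_json/load_json, and a ManagedEntry built from value/created_at/expires_at):

from datetime import datetime, timedelta, timezone

from key_value.shared.utils.managed_entry import ManagedEntry
from key_value.shared.utils.serialization import BasicSerializationAdapter

adapter = BasicSerializationAdapter()

created_at = datetime(2025, 1, 1, tzinfo=timezone.utc)
entry = ManagedEntry(
    value={"name": "Alice", "age": 30},
    created_at=created_at,
    expires_at=created_at + timedelta(seconds=10),
)

# Serialize to the wire format the stores persist, e.g.
# {"created_at": "...", "expires_at": "...", "value": {"name": "Alice", "age": 30}}
json_str = adapter.dump_json(entry=entry)

# Deserialize back into a ManagedEntry with the same value and timestamps.
round_trip = adapter.load_json(json_str=json_str)
assert round_trip.value == entry.value
assert round_trip.expires_at == entry.expires_at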
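Store-specific document shapes (MongoDB, Elasticsearch) are handled by subclassing SerializationAdapter and overriding prepare_dump/prepare_load. The sketch below only mirrors that pattern: the EnvelopeSerializationAdapter name and the "payload" field are illustrative assumptions, and the private attribute assignments copy the convention used by MongoDBSerializationAdapter above rather than a documented base-class contract.

from typing import Any

from key_value.shared.utils.serialization import SerializationAdapter
from typing_extensions import override


class EnvelopeSerializationAdapter(SerializationAdapter):
    """Hypothetical adapter that nests the value under a 'payload' envelope."""

    def __init__(self) -> None:
        super().__init__()
        self._date_format = "isoformat"  # same pattern as the MongoDB adapter in this diff
        self._value_format = "dict"

    @override
    def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]:
        # Reshape the serialized dict before it is written to the backend.
        data["payload"] = {"value": data.pop("value")}
        return data

    @override
    def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]:
        # Undo the envelope so the base adapter can rebuild the ManagedEntry.
        data["value"] = data.pop("payload")["value"]
        return data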
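The compression and limit-size wrappers now measure values with estimate_serialized_size instead of constructing a throwaway ManagedEntry and serializing it. A short sketch of that check; the helper import is taken from this diff, while the threshold and function name are invented for illustration.

from typing import Any

from key_value.shared.utils.managed_entry import estimate_serialized_size

MIN_SIZE_TO_COMPRESS = 1024  # bytes; illustrative threshold only


def should_compress(value: dict[str, Any]) -> bool:
    # Size is estimated from the value alone, without building a ManagedEntry first.
    return estimate_serialized_size(value=value) >= MIN_SIZE_TO_COMPRESS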