From 796e8ecbe99635a95fe2a2fb1a768cd971483104 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Wed, 29 Oct 2025 03:01:51 +0000 Subject: [PATCH 01/11] refactor: implement serialization adapter pattern for MongoDB and Redis stores MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit introduces a clean adapter pattern for store-specific serialization: **Key Changes:** - Add SerializationAdapter base class with to_storage() and from_storage() methods - Implement MongoDBAdapter with native BSON storage support - Implement RedisAdapter for JSON string storage - Add estimate_serialized_size() utility function to managed_entry.py - Update CompressionWrapper and LimitSizeWrapper to use estimate_serialized_size() - Add error handling for deserialization failures (return None instead of crashing) - Add type checking for adapter outputs to catch contract violations early **Benefits:** - Consistent serialization interface across stores - Better separation of concerns (serialization logic in adapters) - More efficient size calculations without ManagedEntry instantiation - Easier to test and maintain - Improved error handling Co-authored-by: William Easton 🤖 Generated with [Claude Code](https://claude.com/claude-code) --- .../src/key_value/aio/stores/mongodb/store.py | 149 ++++++++- .../src/key_value/aio/stores/redis/store.py | 97 +++++- .../aio/wrappers/compression/wrapper.py | 4 +- .../aio/wrappers/limit_size/wrapper.py | 4 +- .../key_value/shared/utils/managed_entry.py | 16 + .../key_value/sync/code_gen/stores/base.py | 1 + .../sync/code_gen/stores/mongodb/store.py | 149 ++++++++- .../sync/code_gen/stores/redis/store.py | 97 +++++- .../code_gen/wrappers/compression/wrapper.py | 4 +- .../code_gen/wrappers/limit_size/wrapper.py | 4 +- .../tests/code_gen/adapters/test_dataclass.py | 94 +++--- .../tests/code_gen/adapters/test_pydantic.py | 121 ++++---- .../tests/code_gen/adapters/test_raise.py | 38 +-- .../key-value-sync/tests/code_gen/cases.py | 50 +-- .../key-value-sync/tests/code_gen/conftest.py | 89 +++--- .../tests/code_gen/protocols/test_types.py | 19 +- .../tests/code_gen/stores/base.py | 285 ++++++++++-------- .../tests/code_gen/stores/conftest.py | 2 +- .../tests/code_gen/stores/disk/test_disk.py | 6 +- .../code_gen/stores/disk/test_multi_disk.py | 6 +- .../elasticsearch/test_elasticsearch.py | 221 ++++++-------- .../code_gen/stores/keyring/test_keyring.py | 20 +- .../code_gen/stores/memory/test_memory.py | 6 +- .../code_gen/stores/mongodb/test_mongodb.py | 170 +++++------ .../tests/code_gen/stores/redis/test_redis.py | 58 ++-- .../code_gen/stores/rocksdb/test_rocksdb.py | 43 +-- .../code_gen/stores/simple/test_store.py | 1 + .../code_gen/stores/valkey/test_valkey.py | 44 +-- .../tests/code_gen/stores/vault/test_vault.py | 57 ++-- .../windows_registry/test_windows_registry.py | 24 +- .../stores/wrappers/test_compression.py | 110 +++---- .../stores/wrappers/test_default_value.py | 82 +++-- .../stores/wrappers/test_encryption.py | 154 +++++----- .../code_gen/stores/wrappers/test_fallback.py | 53 ++-- .../stores/wrappers/test_limit_size.py | 149 ++++----- .../code_gen/stores/wrappers/test_logging.py | 216 ++++--------- .../stores/wrappers/test_passthrough_cache.py | 10 +- .../stores/wrappers/test_prefix_collection.py | 3 +- .../stores/wrappers/test_prefix_key.py | 3 +- .../stores/wrappers/test_read_only.py | 61 ++-- .../code_gen/stores/wrappers/test_retry.py | 44 +-- 
.../code_gen/stores/wrappers/test_routing.py | 61 ++-- .../stores/wrappers/test_single_collection.py | 3 +- .../stores/wrappers/test_statistics.py | 1 + .../stores/wrappers/test_ttl_clamp.py | 36 ++- 45 files changed, 1643 insertions(+), 1222 deletions(-) diff --git a/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py b/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py index 6db89992..84f19957 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from collections.abc import Sequence from datetime import datetime from typing import Any, overload @@ -35,6 +36,128 @@ COLLECTION_ALLOWED_CHARACTERS = ALPHANUMERIC_CHARACTERS + "_" +class SerializationAdapter(ABC): + """Base class for store-specific serialization adapters. + + Adapters encapsulate the logic for converting between ManagedEntry objects + and store-specific storage formats. This provides a consistent interface + while allowing each store to optimize its serialization strategy. + """ + + @abstractmethod + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any] | str: + """Convert a ManagedEntry to the store's storage format. + + Args: + key: The key associated with this entry. + entry: The ManagedEntry to serialize. + collection: Optional collection name. + + Returns: + The serialized representation (dict or str depending on store). + """ + ... + + @abstractmethod + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert stored data back to a ManagedEntry. + + Args: + data: The stored representation to deserialize. + + Returns: + A ManagedEntry reconstructed from storage. + + Raises: + DeserializationError: If the data cannot be deserialized. + """ + ... + + +class MongoDBAdapter(SerializationAdapter): + """MongoDB-specific serialization adapter. + + Stores entries with native BSON datetime types for TTL indexing, + while maintaining the value.object/value.string structure for compatibility. + """ + + def __init__(self, *, native_storage: bool = True) -> None: + """Initialize the MongoDB adapter. + + Args: + native_storage: If True (default), store value as native BSON dict in value.object field. + If False, store as JSON string in value.string field for backward compatibility. 
+ """ + self.native_storage = native_storage + + @override + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: + """Convert a ManagedEntry to a MongoDB document.""" + document: dict[str, Any] = {"key": key, "value": {}} + + # We convert to JSON even if we don't need to, this ensures that the value we were provided + # can be serialized to JSON which helps ensure compatibility across stores + json_str = entry.value_as_json + + # Store in appropriate field based on mode + if self.native_storage: + document["value"]["object"] = entry.value_as_dict + else: + document["value"]["string"] = json_str + + # Add metadata fields as BSON datetimes for TTL indexing + if entry.created_at: + document["created_at"] = entry.created_at + if entry.expires_at: + document["expires_at"] = entry.expires_at + + return document + + @override + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert a MongoDB document back to a ManagedEntry.""" + if not isinstance(data, dict): + msg = "Expected MongoDB document to be a dict" + raise DeserializationError(msg) + + document = data + + if not (value_field := document.get("value")): + msg = "Value field not found" + raise DeserializationError(msg) + + if not isinstance(value_field, dict): + msg = "Expected `value` field to be an object" + raise DeserializationError(msg) + + value_holder: dict[str, Any] = verify_dict(obj=value_field) + + entry_data: dict[str, Any] = {} + + # Mongo stores datetimes without timezones as UTC so we mark them as UTC + if created_at_datetime := document.get("created_at"): + if not isinstance(created_at_datetime, datetime): + msg = "Expected `created_at` field to be a datetime" + raise DeserializationError(msg) + entry_data["created_at"] = created_at_datetime.replace(tzinfo=timezone.utc) + + if expires_at_datetime := document.get("expires_at"): + if not isinstance(expires_at_datetime, datetime): + msg = "Expected `expires_at` field to be a datetime" + raise DeserializationError(msg) + entry_data["expires_at"] = expires_at_datetime.replace(tzinfo=timezone.utc) + + # Support both native (object) and legacy (string) storage + if value_object := value_holder.get("object"): + return ManagedEntry.from_dict(data={"value": value_object, **entry_data}) + + if value_string := value_holder.get("string"): + return ManagedEntry.from_dict(data={"value": value_string, **entry_data}, stringified_value=True) + + msg = "Expected `value` field to be an object with `object` or `string` subfield" + raise DeserializationError(msg) + + def document_to_managed_entry(document: dict[str, Any]) -> ManagedEntry: """Convert a MongoDB document back to a ManagedEntry. 
@@ -132,7 +255,7 @@ class MongoDBStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, Ba _client: AsyncMongoClient[dict[str, Any]] _db: AsyncDatabase[dict[str, Any]] _collections_by_name: dict[str, AsyncCollection[dict[str, Any]]] - _native_storage: bool + _adapter: SerializationAdapter @overload def __init__( @@ -210,7 +333,7 @@ def __init__( self._db = self._client[db_name] self._collections_by_name = {} - self._native_storage = native_storage + self._adapter = MongoDBAdapter(native_storage=native_storage) super().__init__(default_collection=default_collection) @@ -267,7 +390,10 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry sanitized_collection = self._sanitize_collection_name(collection=collection) if doc := await self._collections_by_name[sanitized_collection].find_one(filter={"key": key}): - return document_to_managed_entry(document=doc) + try: + return self._adapter.from_storage(data=doc) + except DeserializationError: + return None return None @@ -285,7 +411,10 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> async for doc in cursor: if key := doc.get("key"): - managed_entries_by_key[key] = document_to_managed_entry(document=doc) + try: + managed_entries_by_key[key] = self._adapter.from_storage(data=doc) + except DeserializationError: + managed_entries_by_key[key] = None return [managed_entries_by_key[key] for key in keys] @@ -297,7 +426,11 @@ async def _put_managed_entry( collection: str, managed_entry: ManagedEntry, ) -> None: - mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry, native_storage=self._native_storage) + mongo_doc = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) + + if not isinstance(mongo_doc, dict): + msg = "MongoDB adapter must return dict" + raise TypeError(msg) sanitized_collection = self._sanitize_collection_name(collection=collection) @@ -328,7 +461,11 @@ async def _put_managed_entries( operations: list[UpdateOne] = [] for key, managed_entry in zip(keys, managed_entries, strict=True): - mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry, native_storage=self._native_storage) + mongo_doc = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) + + if not isinstance(mongo_doc, dict): + msg = "MongoDB adapter must return dict" + raise TypeError(msg) operations.append( UpdateOne( diff --git a/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py b/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py index c2a21efc..6d060ecc 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py @@ -1,8 +1,10 @@ +from abc import ABC, abstractmethod from collections.abc import Sequence from datetime import datetime from typing import Any, overload from urllib.parse import urlparse +from key_value.shared.errors import DeserializationError from key_value.shared.type_checking.bear_spray import bear_spray from key_value.shared.utils.compound import compound_key, get_keys_from_compound_keys from key_value.shared.utils.managed_entry import ManagedEntry @@ -20,6 +22,64 @@ PAGE_LIMIT = 10000 +class SerializationAdapter(ABC): + """Base class for store-specific serialization adapters. + + Adapters encapsulate the logic for converting between ManagedEntry objects + and store-specific storage formats. 
+ """ + + @abstractmethod + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any] | str: + """Convert a ManagedEntry to the store's storage format. + + Args: + key: The key associated with this entry. + entry: The ManagedEntry to serialize. + collection: Optional collection name. + + Returns: + The serialized representation (dict or str depending on store). + """ + ... + + @abstractmethod + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert stored data back to a ManagedEntry. + + Args: + data: The stored representation to deserialize. + + Returns: + A ManagedEntry reconstructed from storage. + + Raises: + DeserializationError: If the data cannot be deserialized. + """ + ... + + +class RedisAdapter(SerializationAdapter): + """Redis-specific serialization adapter. + + Stores entries as JSON strings in Redis. + """ + + @override + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> str: + """Convert a ManagedEntry to a JSON string for Redis storage.""" + return entry.to_json(include_metadata=True, include_expiration=True, include_creation=True) + + @override + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert a JSON string from Redis storage back to a ManagedEntry.""" + if not isinstance(data, str): + msg = "Expected Redis data to be a string" + raise DeserializationError(msg) + + return ManagedEntry.from_json(json_str=data, includes_metadata=True) + + def managed_entry_to_json(managed_entry: ManagedEntry) -> str: """Convert a ManagedEntry to a JSON string for Redis storage. @@ -55,6 +115,7 @@ class RedisStore(BaseDestroyStore, BaseEnumerateKeysStore, BaseContextManagerSto """Redis-based key-value store.""" _client: Redis + _adapter: SerializationAdapter @overload def __init__(self, *, client: Redis, default_collection: str | None = None) -> None: ... 
@@ -111,6 +172,7 @@ def __init__( ) self._stable_api = True + self._adapter = RedisAdapter() super().__init__(default_collection=default_collection) @@ -123,9 +185,10 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry if not isinstance(redis_response, str): return None - managed_entry: ManagedEntry = json_to_managed_entry(json_str=redis_response) - - return managed_entry + try: + return self._adapter.from_storage(data=redis_response) + except DeserializationError: + return None @override async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]: @@ -139,7 +202,10 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> entries: list[ManagedEntry | None] = [] for redis_response in redis_responses: if isinstance(redis_response, str): - entries.append(json_to_managed_entry(json_str=redis_response)) + try: + entries.append(self._adapter.from_storage(data=redis_response)) + except DeserializationError: + entries.append(None) else: entries.append(None) @@ -155,7 +221,11 @@ async def _put_managed_entry( ) -> None: combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry_to_json(managed_entry=managed_entry) + json_value = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) + + if not isinstance(json_value, str): + msg = "Redis adapter must return str" + raise TypeError(msg) if managed_entry.ttl is not None: # Redis does not support <= 0 TTLs @@ -181,10 +251,13 @@ async def _put_managed_entries( if ttl is None: # If there is no TTL, we can just do a simple mset - mapping: dict[str, str] = { - compound_key(collection=collection, key=key): managed_entry_to_json(managed_entry=managed_entry) - for key, managed_entry in zip(keys, managed_entries, strict=True) - } + mapping: dict[str, str] = {} + for key, managed_entry in zip(keys, managed_entries, strict=True): + json_value = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) + if not isinstance(json_value, str): + msg = "Redis adapter must return str" + raise TypeError(msg) + mapping[compound_key(collection=collection, key=key)] = json_value await self._client.mset(mapping=mapping) @@ -198,7 +271,11 @@ async def _put_managed_entries( for key, managed_entry in zip(keys, managed_entries, strict=True): combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry_to_json(managed_entry=managed_entry) + json_value = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) + + if not isinstance(json_value, str): + msg = "Redis adapter must return str" + raise TypeError(msg) pipeline.setex(name=combo_key, time=ttl_seconds, value=json_value) diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/compression/wrapper.py b/key-value/key-value-aio/src/key_value/aio/wrappers/compression/wrapper.py index 8db9f799..83b58ff8 100644 --- a/key-value/key-value-aio/src/key_value/aio/wrappers/compression/wrapper.py +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/compression/wrapper.py @@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence from typing import Any, SupportsFloat -from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.managed_entry import estimate_serialized_size from typing_extensions import override from key_value.aio.protocols.key_value import AsyncKeyValue @@ -56,7 +56,7 @@ def _should_compress(self, value: dict[str, Any]) -> bool: return False 
# Check size - item_size: int = len(ManagedEntry(value=value).to_json()) + item_size: int = estimate_serialized_size(value=value) return item_size >= self.min_size_to_compress def _compress_value(self, value: dict[str, Any]) -> dict[str, Any]: diff --git a/key-value/key-value-aio/src/key_value/aio/wrappers/limit_size/wrapper.py b/key-value/key-value-aio/src/key_value/aio/wrappers/limit_size/wrapper.py index c537f949..9cea5d0c 100644 --- a/key-value/key-value-aio/src/key_value/aio/wrappers/limit_size/wrapper.py +++ b/key-value/key-value-aio/src/key_value/aio/wrappers/limit_size/wrapper.py @@ -2,7 +2,7 @@ from typing import Any, SupportsFloat from key_value.shared.errors.wrappers.limit_size import EntryTooLargeError, EntryTooSmallError -from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.managed_entry import estimate_serialized_size from typing_extensions import override from key_value.aio.protocols.key_value import AsyncKeyValue @@ -65,7 +65,7 @@ def _within_size_limit(self, value: dict[str, Any], *, collection: str | None = EntryTooLargeError: If raise_on_too_large is True and the value exceeds max_size. """ - item_size: int = len(ManagedEntry(value=value).to_json()) + item_size: int = estimate_serialized_size(value=value) if self.min_size is not None and item_size < self.min_size: if self.raise_on_too_small: diff --git a/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py b/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py index 60cafb60..90cb5490 100644 --- a/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py +++ b/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py @@ -168,3 +168,19 @@ def verify_dict(obj: Any) -> dict[str, Any]: raise DeserializationError(msg) return cast(typ="dict[str, Any]", val=obj) + + +def estimate_serialized_size(value: Mapping[str, Any]) -> int: + """Estimate the serialized size of a value without creating a ManagedEntry. + + This function provides a more efficient way to estimate the size of a value + when serialized to JSON, without the overhead of creating a full ManagedEntry object. + This is useful for size-based checks in wrappers. + + Args: + value: The value mapping to estimate the size for. + + Returns: + The estimated size in bytes when serialized to JSON. + """ + return len(dump_to_json(obj=dict(value))) diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py index 1c02abda..e57cc10c 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py @@ -252,6 +252,7 @@ def _put_managed_entries( created_at: datetime, expires_at: datetime | None, ) -> None: + """Store multiple managed entries by key in the specified collection. Args: diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py index 753ffb46..1d4c18e0 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py @@ -1,6 +1,7 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'store.py' # DO NOT CHANGE! Change the original file instead. 
+from abc import ABC, abstractmethod from collections.abc import Sequence from datetime import datetime from typing import Any, overload @@ -42,6 +43,128 @@ COLLECTION_ALLOWED_CHARACTERS = ALPHANUMERIC_CHARACTERS + "_" +class SerializationAdapter(ABC): + """Base class for store-specific serialization adapters. + + Adapters encapsulate the logic for converting between ManagedEntry objects + and store-specific storage formats. This provides a consistent interface + while allowing each store to optimize its serialization strategy. + """ + + @abstractmethod + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any] | str: + """Convert a ManagedEntry to the store's storage format. + + Args: + key: The key associated with this entry. + entry: The ManagedEntry to serialize. + collection: Optional collection name. + + Returns: + The serialized representation (dict or str depending on store). + """ + ... + + @abstractmethod + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert stored data back to a ManagedEntry. + + Args: + data: The stored representation to deserialize. + + Returns: + A ManagedEntry reconstructed from storage. + + Raises: + DeserializationError: If the data cannot be deserialized. + """ + ... + + +class MongoDBAdapter(SerializationAdapter): + """MongoDB-specific serialization adapter. + + Stores entries with native BSON datetime types for TTL indexing, + while maintaining the value.object/value.string structure for compatibility. + """ + + def __init__(self, *, native_storage: bool = True) -> None: + """Initialize the MongoDB adapter. + + Args: + native_storage: If True (default), store value as native BSON dict in value.object field. + If False, store as JSON string in value.string field for backward compatibility. 
+ """ + self.native_storage = native_storage + + @override + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: + """Convert a ManagedEntry to a MongoDB document.""" + document: dict[str, Any] = {"key": key, "value": {}} + + # We convert to JSON even if we don't need to, this ensures that the value we were provided + # can be serialized to JSON which helps ensure compatibility across stores + json_str = entry.value_as_json + + # Store in appropriate field based on mode + if self.native_storage: + document["value"]["object"] = entry.value_as_dict + else: + document["value"]["string"] = json_str + + # Add metadata fields as BSON datetimes for TTL indexing + if entry.created_at: + document["created_at"] = entry.created_at + if entry.expires_at: + document["expires_at"] = entry.expires_at + + return document + + @override + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert a MongoDB document back to a ManagedEntry.""" + if not isinstance(data, dict): + msg = "Expected MongoDB document to be a dict" + raise DeserializationError(msg) + + document = data + + if not (value_field := document.get("value")): + msg = "Value field not found" + raise DeserializationError(msg) + + if not isinstance(value_field, dict): + msg = "Expected `value` field to be an object" + raise DeserializationError(msg) + + value_holder: dict[str, Any] = verify_dict(obj=value_field) + + entry_data: dict[str, Any] = {} + + # Mongo stores datetimes without timezones as UTC so we mark them as UTC + if created_at_datetime := document.get("created_at"): + if not isinstance(created_at_datetime, datetime): + msg = "Expected `created_at` field to be a datetime" + raise DeserializationError(msg) + entry_data["created_at"] = created_at_datetime.replace(tzinfo=timezone.utc) + + if expires_at_datetime := document.get("expires_at"): + if not isinstance(expires_at_datetime, datetime): + msg = "Expected `expires_at` field to be a datetime" + raise DeserializationError(msg) + entry_data["expires_at"] = expires_at_datetime.replace(tzinfo=timezone.utc) + + # Support both native (object) and legacy (string) storage + if value_object := value_holder.get("object"): + return ManagedEntry.from_dict(data={"value": value_object, **entry_data}) + + if value_string := value_holder.get("string"): + return ManagedEntry.from_dict(data={"value": value_string, **entry_data}, stringified_value=True) + + msg = "Expected `value` field to be an object with `object` or `string` subfield" + raise DeserializationError(msg) + + def document_to_managed_entry(document: dict[str, Any]) -> ManagedEntry: """Convert a MongoDB document back to a ManagedEntry. 
@@ -139,7 +262,7 @@ class MongoDBStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, Ba _client: MongoClient[dict[str, Any]] _db: Database[dict[str, Any]] _collections_by_name: dict[str, Collection[dict[str, Any]]] - _native_storage: bool + _adapter: SerializationAdapter @overload def __init__( @@ -217,7 +340,7 @@ def __init__( self._db = self._client[db_name] self._collections_by_name = {} - self._native_storage = native_storage + self._adapter = MongoDBAdapter(native_storage=native_storage) super().__init__(default_collection=default_collection) @@ -274,7 +397,10 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non sanitized_collection = self._sanitize_collection_name(collection=collection) if doc := self._collections_by_name[sanitized_collection].find_one(filter={"key": key}): - return document_to_managed_entry(document=doc) + try: + return self._adapter.from_storage(data=doc) + except DeserializationError: + return None return None @@ -292,13 +418,20 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ for doc in cursor: if key := doc.get("key"): - managed_entries_by_key[key] = document_to_managed_entry(document=doc) + try: + managed_entries_by_key[key] = self._adapter.from_storage(data=doc) + except DeserializationError: + managed_entries_by_key[key] = None return [managed_entries_by_key[key] for key in keys] @override def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: - mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry, native_storage=self._native_storage) + mongo_doc = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) + + if not isinstance(mongo_doc, dict): + msg = "MongoDB adapter must return dict" + raise TypeError(msg) sanitized_collection = self._sanitize_collection_name(collection=collection) @@ -325,7 +458,11 @@ def _put_managed_entries( operations: list[UpdateOne] = [] for key, managed_entry in zip(keys, managed_entries, strict=True): - mongo_doc: dict[str, Any] = managed_entry_to_document(key=key, managed_entry=managed_entry, native_storage=self._native_storage) + mongo_doc = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) + + if not isinstance(mongo_doc, dict): + msg = "MongoDB adapter must return dict" + raise TypeError(msg) operations.append(UpdateOne(filter={"key": key}, update={"$set": mongo_doc}, upsert=True)) diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py index d192cc0a..cd302601 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py @@ -1,11 +1,13 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'store.py' # DO NOT CHANGE! Change the original file instead. 
+from abc import ABC, abstractmethod from collections.abc import Sequence from datetime import datetime from typing import Any, overload from urllib.parse import urlparse +from key_value.shared.errors import DeserializationError from key_value.shared.type_checking.bear_spray import bear_spray from key_value.shared.utils.compound import compound_key, get_keys_from_compound_keys from key_value.shared.utils.managed_entry import ManagedEntry @@ -23,6 +25,64 @@ PAGE_LIMIT = 10000 +class SerializationAdapter(ABC): + """Base class for store-specific serialization adapters. + + Adapters encapsulate the logic for converting between ManagedEntry objects + and store-specific storage formats. + """ + + @abstractmethod + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any] | str: + """Convert a ManagedEntry to the store's storage format. + + Args: + key: The key associated with this entry. + entry: The ManagedEntry to serialize. + collection: Optional collection name. + + Returns: + The serialized representation (dict or str depending on store). + """ + ... + + @abstractmethod + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert stored data back to a ManagedEntry. + + Args: + data: The stored representation to deserialize. + + Returns: + A ManagedEntry reconstructed from storage. + + Raises: + DeserializationError: If the data cannot be deserialized. + """ + ... + + +class RedisAdapter(SerializationAdapter): + """Redis-specific serialization adapter. + + Stores entries as JSON strings in Redis. + """ + + @override + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> str: + """Convert a ManagedEntry to a JSON string for Redis storage.""" + return entry.to_json(include_metadata=True, include_expiration=True, include_creation=True) + + @override + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert a JSON string from Redis storage back to a ManagedEntry.""" + if not isinstance(data, str): + msg = "Expected Redis data to be a string" + raise DeserializationError(msg) + + return ManagedEntry.from_json(json_str=data, includes_metadata=True) + + def managed_entry_to_json(managed_entry: ManagedEntry) -> str: """Convert a ManagedEntry to a JSON string for Redis storage. @@ -58,6 +118,7 @@ class RedisStore(BaseDestroyStore, BaseEnumerateKeysStore, BaseContextManagerSto """Redis-based key-value store.""" _client: Redis + _adapter: SerializationAdapter @overload def __init__(self, *, client: Redis, default_collection: str | None = None) -> None: ... 
@@ -108,6 +169,7 @@ def __init__( self._client = Redis(host=host, port=port, db=db, password=password, decode_responses=True) self._stable_api = True + self._adapter = RedisAdapter() super().__init__(default_collection=default_collection) @@ -120,9 +182,10 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non if not isinstance(redis_response, str): return None - managed_entry: ManagedEntry = json_to_managed_entry(json_str=redis_response) - - return managed_entry + try: + return self._adapter.from_storage(data=redis_response) + except DeserializationError: + return None @override def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]: @@ -136,7 +199,10 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ entries: list[ManagedEntry | None] = [] for redis_response in redis_responses: if isinstance(redis_response, str): - entries.append(json_to_managed_entry(json_str=redis_response)) + try: + entries.append(self._adapter.from_storage(data=redis_response)) + except DeserializationError: + entries.append(None) else: entries.append(None) @@ -146,7 +212,11 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry_to_json(managed_entry=managed_entry) + json_value = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) + + if not isinstance(json_value, str): + msg = "Redis adapter must return str" + raise TypeError(msg) if managed_entry.ttl is not None: # Redis does not support <= 0 TTLs @@ -172,10 +242,13 @@ def _put_managed_entries( if ttl is None: # If there is no TTL, we can just do a simple mset - mapping: dict[str, str] = { - compound_key(collection=collection, key=key): managed_entry_to_json(managed_entry=managed_entry) - for (key, managed_entry) in zip(keys, managed_entries, strict=True) - } + mapping: dict[str, str] = {} + for key, managed_entry in zip(keys, managed_entries, strict=True): + json_value = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) + if not isinstance(json_value, str): + msg = "Redis adapter must return str" + raise TypeError(msg) + mapping[compound_key(collection=collection, key=key)] = json_value self._client.mset(mapping=mapping) @@ -189,7 +262,11 @@ def _put_managed_entries( for key, managed_entry in zip(keys, managed_entries, strict=True): combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry_to_json(managed_entry=managed_entry) + json_value = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) + + if not isinstance(json_value, str): + msg = "Redis adapter must return str" + raise TypeError(msg) pipeline.setex(name=combo_key, time=ttl_seconds, value=json_value) diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/compression/wrapper.py b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/compression/wrapper.py index b18a66db..1bfc3f7f 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/compression/wrapper.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/compression/wrapper.py @@ -7,7 +7,7 @@ from collections.abc import Mapping, Sequence from typing import Any, SupportsFloat -from key_value.shared.utils.managed_entry import ManagedEntry +from 
key_value.shared.utils.managed_entry import estimate_serialized_size from typing_extensions import override from key_value.sync.code_gen.protocols.key_value import KeyValue @@ -55,7 +55,7 @@ def _should_compress(self, value: dict[str, Any]) -> bool: return False # Check size - item_size: int = len(ManagedEntry(value=value).to_json()) + item_size: int = estimate_serialized_size(value=value) return item_size >= self.min_size_to_compress def _compress_value(self, value: dict[str, Any]) -> dict[str, Any]: diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/limit_size/wrapper.py b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/limit_size/wrapper.py index 5c9ec62f..473219a3 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/limit_size/wrapper.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/wrappers/limit_size/wrapper.py @@ -5,7 +5,7 @@ from typing import Any, SupportsFloat from key_value.shared.errors.wrappers.limit_size import EntryTooLargeError, EntryTooSmallError -from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.managed_entry import estimate_serialized_size from typing_extensions import override from key_value.sync.code_gen.protocols.key_value import KeyValue @@ -68,7 +68,7 @@ def _within_size_limit(self, value: dict[str, Any], *, collection: str | None = EntryTooLargeError: If raise_on_too_large is True and the value exceeds max_size. """ - item_size: int = len(ManagedEntry(value=value).to_json()) + item_size: int = estimate_serialized_size(value=value) if self.min_size is not None and item_size < self.min_size: if self.raise_on_too_small: diff --git a/key-value/key-value-sync/tests/code_gen/adapters/test_dataclass.py b/key-value/key-value-sync/tests/code_gen/adapters/test_dataclass.py index ec618f3d..949f27c4 100644 --- a/key-value/key-value-sync/tests/code_gen/adapters/test_dataclass.py +++ b/key-value/key-value-sync/tests/code_gen/adapters/test_dataclass.py @@ -57,90 +57,95 @@ class Order: product: Product paid: bool = False - FIXED_CREATED_AT: datetime = datetime(year=2021, month=1, day=1, hour=12, minute=0, second=0, tzinfo=timezone.utc) FIXED_UPDATED_AT: datetime = datetime(year=2021, month=1, day=1, hour=15, minute=0, second=0, tzinfo=timezone.utc) -SAMPLE_USER: User = User(name="John Doe", email="john.doe@example.com", age=30) -SAMPLE_USER_2: User = User(name="Jane Doe", email="jane.doe@example.com", age=25) -SAMPLE_PRODUCT: Product = Product(name="Widget", price=29.99, quantity=10) -SAMPLE_ADDRESS: Address = Address(street="123 Main St", city="Springfield", zip_code="12345") -SAMPLE_USER_WITH_ADDRESS: UserWithAddress = UserWithAddress(name="John Doe", age=30, address=SAMPLE_ADDRESS) +SAMPLE_USER: User = User(name='John Doe', email='john.doe@example.com', age=30) +SAMPLE_USER_2: User = User(name='Jane Doe', email='jane.doe@example.com', age=25) +SAMPLE_PRODUCT: Product = Product(name='Widget', price=29.99, quantity=10) +SAMPLE_ADDRESS: Address = Address(street='123 Main St', city='Springfield', zip_code='12345') +SAMPLE_USER_WITH_ADDRESS: UserWithAddress = UserWithAddress(name='John Doe', age=30, address=SAMPLE_ADDRESS) SAMPLE_ORDER: Order = Order(created_at=FIXED_CREATED_AT, updated_at=FIXED_UPDATED_AT, user=SAMPLE_USER, product=SAMPLE_PRODUCT, paid=False) -TEST_COLLECTION: str = "test_collection" -TEST_KEY: str = "test_key" -TEST_KEY_2: str = "test_key_2" +TEST_COLLECTION: str = 'test_collection' +TEST_KEY: str = 'test_key' +TEST_KEY_2: str = 'test_key_2' class 
TestDataclassAdapter: + @pytest.fixture def store(self) -> MemoryStore: return MemoryStore() + @pytest.fixture def user_adapter(self, store: MemoryStore) -> DataclassAdapter[User]: return DataclassAdapter[User](key_value=store, dataclass_type=User) + @pytest.fixture def updated_user_adapter(self, store: MemoryStore) -> DataclassAdapter[UpdatedUser]: return DataclassAdapter[UpdatedUser](key_value=store, dataclass_type=UpdatedUser) + @pytest.fixture def product_adapter(self, store: MemoryStore) -> DataclassAdapter[Product]: return DataclassAdapter[Product](key_value=store, dataclass_type=Product) + @pytest.fixture def product_list_adapter(self, store: MemoryStore) -> DataclassAdapter[list[Product]]: return DataclassAdapter[list[Product]](key_value=store, dataclass_type=list[Product]) + @pytest.fixture def user_with_address_adapter(self, store: MemoryStore) -> DataclassAdapter[UserWithAddress]: return DataclassAdapter[UserWithAddress](key_value=store, dataclass_type=UserWithAddress) + @pytest.fixture def order_adapter(self, store: MemoryStore) -> DataclassAdapter[Order]: return DataclassAdapter[Order](key_value=store, dataclass_type=Order) + def test_simple_adapter(self, user_adapter: DataclassAdapter[User]): """Test basic put/get/delete operations with a simple dataclass.""" user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) cached_user: User | None = user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) assert cached_user == SAMPLE_USER - + assert user_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) - + assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + def test_simple_adapter_with_default(self, user_adapter: DataclassAdapter[User]): """Test default value handling.""" assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER - + user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER_2) assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER_2 + + assert user_adapter.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2], default=SAMPLE_USER) == snapshot([SAMPLE_USER_2, SAMPLE_USER]) + - assert user_adapter.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2], default=SAMPLE_USER) == snapshot( - [SAMPLE_USER_2, SAMPLE_USER] - ) - - def test_simple_adapter_with_validation_error_ignore( - self, user_adapter: DataclassAdapter[User], updated_user_adapter: DataclassAdapter[UpdatedUser] - ): + def test_simple_adapter_with_validation_error_ignore(self, user_adapter: DataclassAdapter[User], updated_user_adapter: DataclassAdapter[UpdatedUser]): """Test that validation errors return None when raise_on_validation_error is False.""" user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) - + # UpdatedUser requires is_admin field which doesn't exist in stored User updated_user = updated_user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) assert updated_user is None + - def test_simple_adapter_with_validation_error_raise( - self, user_adapter: DataclassAdapter[User], updated_user_adapter: DataclassAdapter[UpdatedUser] - ): + def test_simple_adapter_with_validation_error_raise(self, user_adapter: DataclassAdapter[User], updated_user_adapter: DataclassAdapter[UpdatedUser]): """Test that validation errors raise DeserializationError when raise_on_validation_error is True.""" user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) 
updated_user_adapter._raise_on_validation_error = True # pyright: ignore[reportPrivateUsage] with pytest.raises(DeserializationError): updated_user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) + def test_nested_dataclass(self, user_with_address_adapter: DataclassAdapter[UserWithAddress]): """Test that nested dataclasses are properly serialized and deserialized.""" @@ -148,56 +153,58 @@ def test_nested_dataclass(self, user_with_address_adapter: DataclassAdapter[User cached_user: UserWithAddress | None = user_with_address_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) assert cached_user == SAMPLE_USER_WITH_ADDRESS assert cached_user is not None - assert cached_user.address.street == "123 Main St" + assert cached_user.address.street == '123 Main St' + def test_complex_adapter(self, order_adapter: DataclassAdapter[Order]): """Test complex dataclass with nested objects and TTL.""" order_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_ORDER, ttl=10) cached_order: Order | None = order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) assert cached_order == SAMPLE_ORDER - + assert order_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) assert order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + def test_complex_adapter_with_list(self, product_list_adapter: DataclassAdapter[list[Product]], store: MemoryStore): """Test list dataclass serialization with proper wrapping.""" product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[SAMPLE_PRODUCT, SAMPLE_PRODUCT], ttl=10) cached_products: list[Product] | None = product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) assert cached_products == [SAMPLE_PRODUCT, SAMPLE_PRODUCT] - + # We need to ensure our memory store doesn't hold an entry with an array raw_collection = store._cache.get(TEST_COLLECTION) # pyright: ignore[reportPrivateUsage] assert raw_collection is not None - + raw_entry = raw_collection.get(TEST_KEY) assert raw_entry is not None assert isinstance(raw_entry.value, dict) - assert raw_entry.value == snapshot( - {"items": [{"name": "Widget", "price": 29.99, "quantity": 10}, {"name": "Widget", "price": 29.99, "quantity": 10}]} - ) - + assert raw_entry.value == snapshot({'items': [{'name': 'Widget', 'price': 29.99, 'quantity': 10}, {'name': 'Widget', 'price': 29.99, 'quantity': 10}]}) + assert product_list_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) assert product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + def test_batch_operations(self, user_adapter: DataclassAdapter[User]): """Test batch put/get/delete operations.""" keys = [TEST_KEY, TEST_KEY_2] users = [SAMPLE_USER, SAMPLE_USER_2] - + # Test put_many user_adapter.put_many(collection=TEST_COLLECTION, keys=keys, values=users) - + # Test get_many cached_users = user_adapter.get_many(collection=TEST_COLLECTION, keys=keys) assert cached_users == users - + # Test delete_many deleted_count = user_adapter.delete_many(collection=TEST_COLLECTION, keys=keys) assert deleted_count == 2 - + # Verify deletion cached_users_after_delete = user_adapter.get_many(collection=TEST_COLLECTION, keys=keys) assert cached_users_after_delete == [None, None] + def test_ttl_operations(self, user_adapter: DataclassAdapter[User]): """Test TTL-related operations.""" @@ -207,28 +214,31 @@ def test_ttl_operations(self, user_adapter: DataclassAdapter[User]): assert user == SAMPLE_USER assert ttl is not None assert ttl > 0 - + # Test ttl_many user_adapter.put(collection=TEST_COLLECTION, 
key=TEST_KEY_2, value=SAMPLE_USER_2, ttl=20) ttl_results = user_adapter.ttl_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2]) assert len(ttl_results) == 2 assert ttl_results[0][0] == SAMPLE_USER assert ttl_results[1][0] == SAMPLE_USER_2 + def test_dataclass_validation_on_init(self, store: MemoryStore): """Test that non-dataclass types are rejected.""" - with pytest.raises(TypeError, match="is not a dataclass"): + with pytest.raises(TypeError, match='is not a dataclass'): DataclassAdapter[str](key_value=store, dataclass_type=str) # type: ignore[type-var] + def test_default_collection(self, store: MemoryStore): """Test that default collection is used when not specified.""" adapter = DataclassAdapter[User](key_value=store, dataclass_type=User, default_collection=TEST_COLLECTION) - + adapter.put(key=TEST_KEY, value=SAMPLE_USER) cached_user = adapter.get(key=TEST_KEY) assert cached_user == SAMPLE_USER - + assert adapter.delete(key=TEST_KEY) + def test_ttl_with_empty_list(self, product_list_adapter: DataclassAdapter[list[Product]]): """Test that TTL with empty list returns correctly (not None).""" @@ -237,20 +247,22 @@ def test_ttl_with_empty_list(self, product_list_adapter: DataclassAdapter[list[P assert value == [] assert ttl is not None assert ttl > 0 + def test_list_payload_missing_items_returns_none(self, product_list_adapter: DataclassAdapter[list[Product]], store: MemoryStore): """Test that list payload without 'items' wrapper returns None when raise_on_validation_error is False.""" # Manually insert malformed payload without the 'items' wrapper # The payload is a dict but without the expected 'items' key for list models - malformed_payload: dict[str, Any] = {"wrong": []} + malformed_payload: dict[str, Any] = {'wrong': []} store.put(collection=TEST_COLLECTION, key=TEST_KEY, value=malformed_payload) assert product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + def test_list_payload_missing_items_raises(self, product_list_adapter: DataclassAdapter[list[Product]], store: MemoryStore): """Test that list payload without 'items' wrapper raises DeserializationError when configured.""" product_list_adapter._raise_on_validation_error = True # pyright: ignore[reportPrivateUsage] # Manually insert malformed payload without the 'items' wrapper - malformed_payload: dict[str, Any] = {"wrong": []} + malformed_payload: dict[str, Any] = {'wrong': []} store.put(collection=TEST_COLLECTION, key=TEST_KEY, value=malformed_payload) with pytest.raises(DeserializationError, match="missing 'items'"): product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) diff --git a/key-value/key-value-sync/tests/code_gen/adapters/test_pydantic.py b/key-value/key-value-sync/tests/code_gen/adapters/test_pydantic.py index e52f9e8f..58906c21 100644 --- a/key-value/key-value-sync/tests/code_gen/adapters/test_pydantic.py +++ b/key-value/key-value-sync/tests/code_gen/adapters/test_pydantic.py @@ -37,186 +37,183 @@ class Order(BaseModel): product: Product paid: bool - FIXED_CREATED_AT: datetime = datetime(year=2021, month=1, day=1, hour=12, minute=0, second=0, tzinfo=timezone.utc) FIXED_UPDATED_AT: datetime = datetime(year=2021, month=1, day=1, hour=15, minute=0, second=0, tzinfo=timezone.utc) -SAMPLE_USER: User = User(name="John Doe", email="john.doe@example.com", age=30) -SAMPLE_USER_2: User = User(name="Jane Doe", email="jane.doe@example.com", age=25) -SAMPLE_PRODUCT: Product = Product(name="Widget", price=29.99, quantity=10, url=AnyHttpUrl(url="https://example.com")) +SAMPLE_USER: User = 
User(name='John Doe', email='john.doe@example.com', age=30) +SAMPLE_USER_2: User = User(name='Jane Doe', email='jane.doe@example.com', age=25) +SAMPLE_PRODUCT: Product = Product(name='Widget', price=29.99, quantity=10, url=AnyHttpUrl(url='https://example.com')) SAMPLE_ORDER: Order = Order(created_at=datetime.now(), updated_at=datetime.now(), user=SAMPLE_USER, product=SAMPLE_PRODUCT, paid=False) -TEST_COLLECTION: str = "test_collection" -TEST_KEY: str = "test_key" -TEST_KEY_2: str = "test_key_2" +TEST_COLLECTION: str = 'test_collection' +TEST_KEY: str = 'test_key' +TEST_KEY_2: str = 'test_key_2' def model_type_from_log_record(record: LogRecord) -> str: - if not hasattr(record, "model_type"): - msg = "Log record does not have a model_type attribute" + if not hasattr(record, 'model_type'): + msg = 'Log record does not have a model_type attribute' raise ValueError(msg) return record.model_type # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType, reportAttributeAccessIssue] def error_from_log_record(record: LogRecord) -> str: - if not hasattr(record, "error"): - msg = "Log record does not have an error attribute" + if not hasattr(record, 'error'): + msg = 'Log record does not have an error attribute' raise ValueError(msg) return record.error # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType, reportAttributeAccessIssue] def errors_from_log_record(record: LogRecord) -> list[str]: - if not hasattr(record, "errors"): - msg = "Log record does not have an errors attribute" + if not hasattr(record, 'errors'): + msg = 'Log record does not have an errors attribute' raise ValueError(msg) return record.errors # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType, reportAttributeAccessIssue] class TestPydanticAdapter: + @pytest.fixture def store(self) -> MemoryStore: return MemoryStore() + @pytest.fixture def user_adapter(self, store: MemoryStore) -> PydanticAdapter[User]: return PydanticAdapter[User](key_value=store, pydantic_model=User) + @pytest.fixture def updated_user_adapter(self, store: MemoryStore) -> PydanticAdapter[UpdatedUser]: return PydanticAdapter[UpdatedUser](key_value=store, pydantic_model=UpdatedUser) + @pytest.fixture def product_adapter(self, store: MemoryStore) -> PydanticAdapter[Product]: return PydanticAdapter[Product](key_value=store, pydantic_model=Product) + @pytest.fixture def product_list_adapter(self, store: MemoryStore) -> PydanticAdapter[list[Product]]: return PydanticAdapter[list[Product]](key_value=store, pydantic_model=list[Product]) + @pytest.fixture def order_adapter(self, store: MemoryStore) -> PydanticAdapter[Order]: return PydanticAdapter[Order](key_value=store, pydantic_model=Order) + def test_simple_adapter(self, user_adapter: PydanticAdapter[User]): user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) cached_user: User | None = user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) assert cached_user == SAMPLE_USER - + assert user_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) - + assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None + def test_simple_adapter_with_default(self, user_adapter: PydanticAdapter[User]): assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER - + user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER_2) assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER_2 - - assert 
user_adapter.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2], default=SAMPLE_USER) == snapshot(
-            [SAMPLE_USER_2, SAMPLE_USER]
-        )
+
+        assert user_adapter.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2], default=SAMPLE_USER) == snapshot([SAMPLE_USER_2, SAMPLE_USER])
 
+
     def test_simple_adapter_with_list(self, product_list_adapter: PydanticAdapter[list[Product]]):
         product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[SAMPLE_PRODUCT, SAMPLE_PRODUCT])
         cached_products: list[Product] | None = product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY)
         assert cached_products == [SAMPLE_PRODUCT, SAMPLE_PRODUCT]
-
+
         assert product_list_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY)
         assert product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None
 
+
-    def test_simple_adapter_with_validation_error_ignore(
-        self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser]
-    ):
+    def test_simple_adapter_with_validation_error_ignore(self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser]):
         user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER)
-
+
         updated_user = updated_user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY)
         assert updated_user is None
 
+
-    def test_simple_adapter_with_validation_error_raise(
-        self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser]
-    ):
+    def test_simple_adapter_with_validation_error_raise(self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser]):
         user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER)
         updated_user_adapter._raise_on_validation_error = True  # pyright: ignore[reportPrivateUsage]
         with pytest.raises(DeserializationError):
             updated_user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY)
 
+
     def test_complex_adapter(self, order_adapter: PydanticAdapter[Order]):
         order_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_ORDER, ttl=10)
         assert order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) == SAMPLE_ORDER
-
+
         assert order_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY)
         assert order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None
 
+
     def test_complex_adapter_with_list(self, product_list_adapter: PydanticAdapter[list[Product]], store: MemoryStore):
         product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[SAMPLE_PRODUCT, SAMPLE_PRODUCT], ttl=10)
         cached_products: list[Product] | None = product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY)
         assert cached_products == [SAMPLE_PRODUCT, SAMPLE_PRODUCT]
-
+
         # We need to ensure our memory store doesnt hold an entry with an array
         raw_collection = store._cache.get(TEST_COLLECTION)  # pyright: ignore[reportPrivateUsage]
         assert raw_collection is not None
-
+
         raw_entry = raw_collection.get(TEST_KEY)
         assert raw_entry is not None
         assert isinstance(raw_entry.value, dict)
-        assert raw_entry.value == snapshot(
-            {
-                "items": [
-                    {"name": "Widget", "price": 29.99, "quantity": 10, "url": "https://example.com/"},
-                    {"name": "Widget", "price": 29.99, "quantity": 10, "url": "https://example.com/"},
-                ]
-            }
-        )
-
+        assert raw_entry.value == snapshot({'items': [{'name': 'Widget', 'price': 29.99, 'quantity': 10, 'url': 'https://example.com/'}, {'name': 'Widget', 'price': 29.99, 'quantity': 10, 'url': 'https://example.com/'}]})
+
         assert product_list_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY)
         assert product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None
 
+
-    def test_validation_error_logging(
-        self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser], caplog: pytest.LogCaptureFixture
-    ):
+    def test_validation_error_logging(self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser], caplog: pytest.LogCaptureFixture):
         """Test that validation errors are logged when raise_on_validation_error=False."""
         import logging
-
+
         # Store a User, then try to retrieve as UpdatedUser (missing is_admin field)
         user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER)
-
+
         with caplog.at_level(logging.ERROR):
             updated_user = updated_user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY)
-
+
         # Should return None due to validation failure
         assert updated_user is None
-
+
         # Check that an error was logged
         assert len(caplog.records) == 1
         record = caplog.records[0]
-        assert record.levelname == "ERROR"
-        assert "Validation failed" in record.message
-        assert model_type_from_log_record(record) == "Pydantic model"
-
+        assert record.levelname == 'ERROR'
+        assert 'Validation failed' in record.message
+        assert model_type_from_log_record(record) == 'Pydantic model'
+
         errors = errors_from_log_record(record)
         assert len(errors) == 1
-        assert "is_admin" in str(errors[0])
+        assert 'is_admin' in str(errors[0])
 
+
-    def test_list_validation_error_logging(
-        self, product_list_adapter: PydanticAdapter[list[Product]], store: MemoryStore, caplog: pytest.LogCaptureFixture
-    ):
+    def test_list_validation_error_logging(self, product_list_adapter: PydanticAdapter[list[Product]], store: MemoryStore, caplog: pytest.LogCaptureFixture):
         """Test that missing 'items' wrapper is logged for list models."""
         import logging
-
+
         # Manually store invalid data (missing 'items' wrapper)
-        store.put(collection=TEST_COLLECTION, key=TEST_KEY, value={"invalid": "data"})
-
+        store.put(collection=TEST_COLLECTION, key=TEST_KEY, value={'invalid': 'data'})
+
         with caplog.at_level(logging.ERROR):
             result = product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY)
-
+
         # Should return None due to missing 'items' wrapper
         assert result is None
-
+
         # Check that an error was logged
         assert len(caplog.records) == 1
         record = caplog.records[0]
-        assert record.levelname == "ERROR"
+        assert record.levelname == 'ERROR'
         assert "Missing 'items' wrapper" in record.message
-        assert model_type_from_log_record(record) == "Pydantic model"
+        assert model_type_from_log_record(record) == 'Pydantic model'
         error = error_from_log_record(record)
         assert "missing 'items' wrapper" in str(error)
diff --git a/key-value/key-value-sync/tests/code_gen/adapters/test_raise.py b/key-value/key-value-sync/tests/code_gen/adapters/test_raise.py
index 1653f1e4..1712dafc 100644
--- a/key-value/key-value-sync/tests/code_gen/adapters/test_raise.py
+++ b/key-value/key-value-sync/tests/code_gen/adapters/test_raise.py
@@ -19,50 +19,50 @@ def adapter(store: MemoryStore) -> RaiseOnMissingAdapter:
 
 
 def test_get(adapter: RaiseOnMissingAdapter):
-    adapter.put(collection="test", key="test", value={"test": "test"})
-    assert adapter.get(collection="test", key="test") == {"test": "test"}
+    adapter.put(collection='test', key='test', value={'test': 'test'})
+    assert adapter.get(collection='test', key='test') == {'test': 'test'}
 
 
 def test_get_missing(adapter: RaiseOnMissingAdapter):
     with pytest.raises(MissingKeyError):
-        _ = adapter.get(collection="test", key="test", raise_on_missing=True)
+        _ = adapter.get(collection='test', key='test', raise_on_missing=True)
 
 
 def test_get_many(adapter: RaiseOnMissingAdapter):
-    adapter.put(collection="test", key="test", value={"test": "test"})
-    adapter.put(collection="test", key="test_2", value={"test": "test_2"})
-    assert adapter.get_many(collection="test", keys=["test", "test_2"]) == [{"test": "test"}, {"test": "test_2"}]
+    adapter.put(collection='test', key='test', value={'test': 'test'})
+    adapter.put(collection='test', key='test_2', value={'test': 'test_2'})
+    assert adapter.get_many(collection='test', keys=['test', 'test_2']) == [{'test': 'test'}, {'test': 'test_2'}]
 
 
 def test_get_many_missing(adapter: RaiseOnMissingAdapter):
-    adapter.put(collection="test", key="test", value={"test": "test"})
+    adapter.put(collection='test', key='test', value={'test': 'test'})
     with pytest.raises(MissingKeyError):
-        _ = adapter.get_many(collection="test", keys=["test", "test_2"], raise_on_missing=True)
+        _ = adapter.get_many(collection='test', keys=['test', 'test_2'], raise_on_missing=True)
 
 
 def test_ttl(adapter: RaiseOnMissingAdapter):
-    adapter.put(collection="test", key="test", value={"test": "test"}, ttl=60)
-    (value, ttl) = adapter.ttl(collection="test", key="test")
-    assert value == {"test": "test"}
+    adapter.put(collection='test', key='test', value={'test': 'test'}, ttl=60)
+    (value, ttl) = adapter.ttl(collection='test', key='test')
+    assert value == {'test': 'test'}
     assert ttl is not None
 
 
 def test_ttl_missing(adapter: RaiseOnMissingAdapter):
     with pytest.raises(MissingKeyError):
-        _ = adapter.ttl(collection="test", key="test", raise_on_missing=True)
+        _ = adapter.ttl(collection='test', key='test', raise_on_missing=True)
 
 
 def test_ttl_many(adapter: RaiseOnMissingAdapter):
-    adapter.put(collection="test", key="test", value={"test": "test"}, ttl=60)
-    adapter.put(collection="test", key="test_2", value={"test": "test_2"}, ttl=120)
-    results = adapter.ttl_many(collection="test", keys=["test", "test_2"])
-    assert results[0][0] == {"test": "test"}
+    adapter.put(collection='test', key='test', value={'test': 'test'}, ttl=60)
+    adapter.put(collection='test', key='test_2', value={'test': 'test_2'}, ttl=120)
+    results = adapter.ttl_many(collection='test', keys=['test', 'test_2'])
+    assert results[0][0] == {'test': 'test'}
     assert results[0][1] is not None
-    assert results[1][0] == {"test": "test_2"}
+    assert results[1][0] == {'test': 'test_2'}
     assert results[1][1] is not None
 
 
 def test_ttl_many_missing(adapter: RaiseOnMissingAdapter):
-    adapter.put(collection="test", key="test", value={"test": "test"}, ttl=60)
+    adapter.put(collection='test', key='test', value={'test': 'test'}, ttl=60)
     with pytest.raises(MissingKeyError):
-        _ = adapter.ttl_many(collection="test", keys=["test", "test_2"], raise_on_missing=True)
+        _ = adapter.ttl_many(collection='test', keys=['test', 'test_2'], raise_on_missing=True)
diff --git a/key-value/key-value-sync/tests/code_gen/cases.py b/key-value/key-value-sync/tests/code_gen/cases.py
index 0fffeebe..a1aa3250 100644
--- a/key-value/key-value-sync/tests/code_gen/cases.py
+++ b/key-value/key-value-sync/tests/code_gen/cases.py
@@ -7,57 +7,19 @@
 FIXED_DATETIME = datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
 FIXED_TIME = FIXED_DATETIME.time()
 
-LARGE_STRING: str = "a" * 10000  # 10KB
-LARGE_INT: int = 1 * 10**18  # 18 digits
-LARGE_FLOAT: float = 1.0 * 10**63  # 63 digits
+LARGE_STRING: str = 'a' * 10000  # 10KB
+LARGE_INT: int = 1 * 10 ** 18  # 18 digits
+LARGE_FLOAT: float = 1.0 * 10 ** 63  # 63 digits
 
-SIMPLE_CASE:
dict[str, Any] = {
-    "key_1": "value_1",
-    "key_2": 1,
-    "key_3": 1.0,
-    "key_4": [1, 2, 3],
-    "key_5": {"nested": "value"},
-    "key_6": True,
-    "key_7": False,
-    "key_8": None,
-}
+SIMPLE_CASE: dict[str, Any] = {'key_1': 'value_1', 'key_2': 1, 'key_3': 1.0, 'key_4': [1, 2, 3], 'key_5': {'nested': 'value'}, 'key_6': True, 'key_7': False, 'key_8': None}
 
 SIMPLE_CASE_JSON: str = '{"key_1": "value_1", "key_2": 1, "key_3": 1.0, "key_4": [1, 2, 3], "key_5": {"nested": "value"}, "key_6": true, "key_7": false, "key_8": null}'
 
 # ({"key": (1, 2, 3)}, '{"key": [1, 2, 3]}'),
-DICTIONARY_TO_JSON_TEST_CASES: list[tuple[dict[str, Any], str]] = [
-    ({"key": "value"}, '{"key": "value"}'),
-    ({"key": 1}, '{"key": 1}'),
-    ({"key": 1.0}, '{"key": 1.0}'),
-    ({"key": [1, 2, 3]}, '{"key": [1, 2, 3]}'),
-    ({"key": {"nested": "value"}}, '{"key": {"nested": "value"}}'),
-    ({"key": True}, '{"key": true}'),
-    ({"key": False}, '{"key": false}'),
-    ({"key": None}, '{"key": null}'),
-    (
-        {"key": {"int": 1, "float": 1.0, "list": [1, 2, 3], "dict": {"nested": "value"}, "bool": True, "null": None}},
-        '{"key": {"int": 1, "float": 1.0, "list": [1, 2, 3], "dict": {"nested": "value"}, "bool": true, "null": null}}',
-    ),
-    ({"key": LARGE_STRING}, f'{{"key": "{LARGE_STRING}"}}'),
-    ({"key": LARGE_INT}, f'{{"key": {LARGE_INT}}}'),
-    ({"key": LARGE_FLOAT}, f'{{"key": {LARGE_FLOAT}}}'),
-]
+DICTIONARY_TO_JSON_TEST_CASES: list[tuple[dict[str, Any], str]] = [({'key': 'value'}, '{"key": "value"}'), ({'key': 1}, '{"key": 1}'), ({'key': 1.0}, '{"key": 1.0}'), ({'key': [1, 2, 3]}, '{"key": [1, 2, 3]}'), ({'key': {'nested': 'value'}}, '{"key": {"nested": "value"}}'), ({'key': True}, '{"key": true}'), ({'key': False}, '{"key": false}'), ({'key': None}, '{"key": null}'), ({'key': {'int': 1, 'float': 1.0, 'list': [1, 2, 3], 'dict': {'nested': 'value'}, 'bool': True, 'null': None}}, '{"key": {"int": 1, "float": 1.0, "list": [1, 2, 3], "dict": {"nested": "value"}, "bool": true, "null": null}}'), ({'key': LARGE_STRING}, f'{{"key": "{LARGE_STRING}"}}'), ({'key': LARGE_INT}, f'{{"key": {LARGE_INT}}}'), ({'key': LARGE_FLOAT}, f'{{"key": {LARGE_FLOAT}}}')]
 
 # "tuple",
-DICTIONARY_TO_JSON_TEST_CASES_NAMES: list[str] = [
-    "string",
-    "int",
-    "float",
-    "list",
-    "dict",
-    "bool-true",
-    "bool-false",
-    "null",
-    "dict-nested",
-    "large-string",
-    "large-int",
-    "large-float",
-]
+DICTIONARY_TO_JSON_TEST_CASES_NAMES: list[str] = ['string', 'int', 'float', 'list', 'dict', 'bool-true', 'bool-false', 'null', 'dict-nested', 'large-string', 'large-int', 'large-float']
 
 OBJECT_TEST_CASES: list[dict[str, Any]] = [test_case[0] for test_case in DICTIONARY_TO_JSON_TEST_CASES]
diff --git a/key-value/key-value-sync/tests/code_gen/conftest.py b/key-value/key-value-sync/tests/code_gen/conftest.py
index 3d7b2051..f2b9505b 100644
--- a/key-value/key-value-sync/tests/code_gen/conftest.py
+++ b/key-value/key-value-sync/tests/code_gen/conftest.py
@@ -22,10 +22,11 @@
 @contextmanager
 def try_import() -> Iterator[Callable[[], bool]]:
     import_success = False
+
 
     def check_import() -> bool:
         return import_success
-
+
     try:
         yield check_import
     except ImportError:
@@ -43,70 +44,70 @@ def docker_client() -> DockerClient:
     return get_docker_client()
 
 
-def docker_logs(name: str, print_logs: bool = False, raise_on_error: bool = False, log_level: int = logging.INFO) -> list[str]:
+def docker_logs(name: str, print_logs: bool=False, raise_on_error: bool=False, log_level: int=logging.INFO) -> list[str]:
     client = get_docker_client()
     try:
-        logs: list[str] = client.containers.get(name).logs().decode("utf-8").splitlines()
+        logs: list[str] = client.containers.get(name).logs().decode('utf-8').splitlines()
     except Exception:
-        logger.info(f"Container {name} failed to get logs")
+        logger.info(f'Container {name} failed to get logs')
         if raise_on_error:
             raise
         return []
-
+
     if print_logs:
-        logger.info(f"Container {name} logs:")
+        logger.info(f'Container {name} logs:')
         for log in logs:
             logger.log(log_level, log)
-
+
     return logs
 
 
-def docker_get(name: str, raise_on_not_found: bool = False) -> Container | None:
+def docker_get(name: str, raise_on_not_found: bool=False) -> Container | None:
     from docker.errors import NotFound
-
+
     client = get_docker_client()
     try:
         return client.containers.get(name)
     except NotFound:
-        logger.info(f"Container {name} failed to get")
+        logger.info(f'Container {name} failed to get')
         if raise_on_not_found:
             raise
         return None
 
 
-def docker_pull(image: str, raise_on_error: bool = False) -> bool:
-    logger.info(f"Pulling image {image}")
+def docker_pull(image: str, raise_on_error: bool=False) -> bool:
+    logger.info(f'Pulling image {image}')
     client = get_docker_client()
     try:
         client.images.pull(image)
     except Exception:
-        logger.exception(f"Image {image} failed to pull")
+        logger.exception(f'Image {image} failed to pull')
         if raise_on_error:
             raise
         return False
     return True
 
 
-def docker_stop(name: str, raise_on_error: bool = False) -> bool:
-    logger.info(f"Stopping container {name}")
-
+def docker_stop(name: str, raise_on_error: bool=False) -> bool:
+    logger.info(f'Stopping container {name}')
+
     if not (container := docker_get(name=name, raise_on_not_found=False)):
         return False
-
+
     try:
         container.stop()
     except Exception:
-        logger.info(f"Container {name} failed to stop")
+        logger.info(f'Container {name} failed to stop')
         if raise_on_error:
             raise
         return False
-
-    logger.info(f"Container {name} stopped")
+
+    logger.info(f'Container {name} stopped')
     return True
 
 
-def docker_wait_container_gone(name: str, max_tries: int = 10, wait_time: float = 1.0) -> bool:
-    logger.info(f"Waiting for container {name} to be gone")
+def docker_wait_container_gone(name: str, max_tries: int=10, wait_time: float=1.0) -> bool:
+    logger.info(f'Waiting for container {name} to be gone')
     count = 0
     while count < max_tries:
         if not docker_get(name=name, raise_on_not_found=False):
@@ -116,53 +117,51 @@ def docker_wait_container_gone(name: str, max_tries: int = 10, wait_time: float
     return False
 
 
-def docker_rm(name: str, raise_on_error: bool = False) -> bool:
-    logger.info(f"Removing container {name}")
-
+def docker_rm(name: str, raise_on_error: bool=False) -> bool:
+    logger.info(f'Removing container {name}')
+
     if not (container := docker_get(name=name, raise_on_not_found=False)):
         return False
-
+
     try:
         container.remove()
     except Exception:
-        logger.info(f"Container {name} failed to remove")
+        logger.info(f'Container {name} failed to remove')
        if raise_on_error:
            raise
        return False
-    logger.info(f"Container {name} removed")
+    logger.info(f'Container {name} removed')
     return True
 
 
-def docker_run(name: str, image: str, ports: dict[str, int], environment: dict[str, str], raise_on_error: bool = False) -> bool:
-    logger.info(f"Running container {name} with image {image} and ports {ports}")
+def docker_run(name: str, image: str, ports: dict[str, int], environment: dict[str, str], raise_on_error: bool=False) -> bool:
+    logger.info(f'Running container {name} with image {image} and ports {ports}')
     client = get_docker_client()
     try:
         client.containers.run(name=name, image=image, ports=ports, environment=environment, detach=True)
     except Exception:
-        logger.exception(f"Container {name} failed to run")
+        logger.exception(f'Container {name} failed to run')
         if raise_on_error:
             raise
         return False
-    logger.info(f"Container {name} running")
+    logger.info(f'Container {name} running')
     return True
 
 
 @contextmanager
-def docker_container(
-    name: str, image: str, ports: dict[str, int], environment: dict[str, str] | None = None, raise_on_error: bool = True
-) -> Iterator[None]:
-    logger.info(f"Creating container {name} with image {image} and ports {ports}")
+def docker_container(name: str, image: str, ports: dict[str, int], environment: dict[str, str] | None=None, raise_on_error: bool=True) -> Iterator[None]:
+    logger.info(f'Creating container {name} with image {image} and ports {ports}')
     try:
         docker_pull(image=image, raise_on_error=True)
         docker_stop(name=name, raise_on_error=False)
         docker_rm(name=name, raise_on_error=False)
         docker_wait_container_gone(name=name, max_tries=10, wait_time=1.0)
         docker_run(name=name, image=image, ports=ports, environment=environment or {}, raise_on_error=True)
-        logger.info(f"Container {name} created")
+        logger.info(f'Container {name} created')
         yield
         docker_logs(name, print_logs=True, raise_on_error=False)
     except Exception:
-        logger.info(f"Creating container {name} failed")
+        logger.info(f'Creating container {name} failed')
         docker_logs(name, print_logs=True, raise_on_error=False, log_level=logging.ERROR)
         if raise_on_error:
             raise
@@ -171,8 +170,8 @@ def docker_container(
         docker_stop(name, raise_on_error=False)
         docker_rm(name, raise_on_error=False)
         docker_wait_container_gone(name=name, max_tries=10, wait_time=1.0)
-
-        logger.info(f"Container {name} stopped and removed")
+
+        logger.info(f'Container {name} stopped and removed')
     return
 
 
@@ -190,7 +189,7 @@ def running_in_event_loop() -> bool:
 
 def detect_docker() -> bool:
     try:
-        result = subprocess.run(["docker", "ps"], check=False, capture_output=True, text=True)  # noqa: S607
+        result = subprocess.run(['docker', 'ps'], check=False, capture_output=True, text=True)  # noqa: S607
     except Exception:
         return False
     else:
@@ -198,19 +197,19 @@ def detect_docker() -> bool:
 
 
 def detect_on_ci() -> bool:
-    return os.getenv("CI", "false") == "true"
+    return os.getenv('CI', 'false') == 'true'
 
 
 def detect_on_windows() -> bool:
-    return platform.system() == "Windows"
+    return platform.system() == 'Windows'
 
 
 def detect_on_macos() -> bool:
-    return platform.system() == "Darwin"
+    return platform.system() == 'Darwin'
 
 
 def detect_on_linux() -> bool:
-    return platform.system() == "Linux"
+    return platform.system() == 'Linux'
 
 
 def should_run_docker_tests() -> bool:
diff --git a/key-value/key-value-sync/tests/code_gen/protocols/test_types.py b/key-value/key-value-sync/tests/code_gen/protocols/test_types.py
index 2d4abd61..b883d85a 100644
--- a/key-value/key-value-sync/tests/code_gen/protocols/test_types.py
+++ b/key-value/key-value-sync/tests/code_gen/protocols/test_types.py
@@ -6,15 +6,16 @@
 
 
 def test_key_value_protocol():
-    def test_protocol(key_value: KeyValue):
-        assert key_value.get(collection="test", key="test") is None
-        key_value.put(collection="test", key="test", value={"test": "test"})
-        assert key_value.delete(collection="test", key="test")
-        key_value.put(collection="test", key="test_2", value={"test": "test"})
+    def test_protocol(key_value: KeyValue):
+        assert key_value.get(collection='test', key='test') is None
+        key_value.put(collection='test', key='test', value={'test': 'test'})
+        assert key_value.delete(collection='test', key='test')
+ key_value.put(collection='test', key='test_2', value={'test': 'test'}) + memory_store = MemoryStore() - + test_protocol(key_value=memory_store) - - assert memory_store.get(collection="test", key="test") is None - assert memory_store.get(collection="test", key="test_2") == {"test": "test"} + + assert memory_store.get(collection='test', key='test') is None + assert memory_store.get(collection='test', key='test_2') == {'test': 'test'} diff --git a/key-value/key-value-sync/tests/code_gen/stores/base.py b/key-value/key-value-sync/tests/code_gen/stores/base.py index 8f191c55..23fc9bb6 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/base.py +++ b/key-value/key-value-sync/tests/code_gen/stores/base.py @@ -20,253 +20,292 @@ class BaseStoreTests(ABC): + def eventually_consistent(self) -> None: # noqa: B027 - "Subclasses can override this to wait for eventually consistent operations." + 'Subclasses can override this to wait for eventually consistent operations.' + @pytest.fixture @abstractmethod - def store(self) -> BaseStore | Generator[BaseStore, None, None]: ... + def store(self) -> BaseStore | Generator[BaseStore, None, None]: + ... + @pytest.mark.timeout(60) def test_store(self, store: BaseStore): """Tests that the store is a valid KeyValueProtocol.""" assert isinstance(store, KeyValueProtocol) is True + def test_empty_get(self, store: BaseStore): """Tests that the get method returns None from an empty store.""" - assert store.get(collection="test", key="test") is None + assert store.get(collection='test', key='test') is None + def test_empty_put(self, store: BaseStore): """Tests that the put method does not raise an exception when called on a new store.""" - store.put(collection="test", key="test", value={"test": "test"}) + store.put(collection='test', key='test', value={'test': 'test'}) + def test_empty_ttl(self, store: BaseStore): """Tests that the ttl method returns None from an empty store.""" - ttl = store.ttl(collection="test", key="test") + ttl = store.ttl(collection='test', key='test') assert ttl == (None, None) + def test_put_serialization_errors(self, store: BaseStore): """Tests that the put method raises SerializationError for non-JSON-serializable Pydantic types.""" with pytest.raises(SerializationError): - store.put(collection="test", key="test", value={"test": AnyHttpUrl("https://test.com")}) + store.put(collection='test', key='test', value={'test': AnyHttpUrl('https://test.com')}) + def test_get_put_get(self, store: BaseStore): - assert store.get(collection="test", key="test") is None - store.put(collection="test", key="test", value={"test": "test"}) - assert store.get(collection="test", key="test") == {"test": "test"} + assert store.get(collection='test', key='test') is None + store.put(collection='test', key='test', value={'test': 'test'}) + assert store.get(collection='test', key='test') == {'test': 'test'} + @PositiveCases.parametrize(cases=SIMPLE_CASES) def test_models_put_get(self, store: BaseStore, data: dict[str, Any], json: str, round_trip: dict[str, Any]): - store.put(collection="test", key="test", value=data) - retrieved_data = store.get(collection="test", key="test") + store.put(collection='test', key='test', value=data) + retrieved_data = store.get(collection='test', key='test') assert retrieved_data is not None assert retrieved_data == round_trip + @NegativeCases.parametrize(cases=NEGATIVE_SIMPLE_CASES) def test_negative_models_put_get(self, store: BaseStore, data: dict[str, Any], error: type[Exception]): with pytest.raises(error): - 
store.put(collection="test", key="test", value=data) + store.put(collection='test', key='test', value=data) + @PositiveCases.parametrize(cases=[LARGE_DATA_CASES]) def test_get_large_put_get(self, store: BaseStore, data: dict[str, Any], json: str, round_trip: dict[str, Any]): - store.put(collection="test", key="test", value=data) - assert store.get(collection="test", key="test") == round_trip + store.put(collection='test', key='test', value=data) + assert store.get(collection='test', key='test') == round_trip + def test_put_many_get(self, store: BaseStore): - store.put_many(collection="test", keys=["test", "test_2"], values=[{"test": "test"}, {"test": "test_2"}]) - assert store.get(collection="test", key="test") == {"test": "test"} - assert store.get(collection="test", key="test_2") == {"test": "test_2"} + store.put_many(collection='test', keys=['test', 'test_2'], values=[{'test': 'test'}, {'test': 'test_2'}]) + assert store.get(collection='test', key='test') == {'test': 'test'} + assert store.get(collection='test', key='test_2') == {'test': 'test_2'} + def test_put_many_get_many(self, store: BaseStore): - store.put_many(collection="test", keys=["test", "test_2"], values=[{"test": "test"}, {"test": "test_2"}]) - assert store.get_many(collection="test", keys=["test", "test_2"]) == [{"test": "test"}, {"test": "test_2"}] + store.put_many(collection='test', keys=['test', 'test_2'], values=[{'test': 'test'}, {'test': 'test_2'}]) + assert store.get_many(collection='test', keys=['test', 'test_2']) == [{'test': 'test'}, {'test': 'test_2'}] + def test_put_put_get_many(self, store: BaseStore): - store.put(collection="test", key="test", value={"test": "test"}) - store.put(collection="test", key="test_2", value={"test": "test_2"}) - assert store.get_many(collection="test", keys=["test", "test_2"]) == [{"test": "test"}, {"test": "test_2"}] + store.put(collection='test', key='test', value={'test': 'test'}) + store.put(collection='test', key='test_2', value={'test': 'test_2'}) + assert store.get_many(collection='test', keys=['test', 'test_2']) == [{'test': 'test'}, {'test': 'test_2'}] + def test_put_put_get_many_missing_one(self, store: BaseStore): - store.put(collection="test", key="test", value={"test": "test"}) - store.put(collection="test", key="test_2", value={"test": "test_2"}) - assert store.get_many(collection="test", keys=["test", "test_2", "test_3"]) == [{"test": "test"}, {"test": "test_2"}, None] + store.put(collection='test', key='test', value={'test': 'test'}) + store.put(collection='test', key='test_2', value={'test': 'test_2'}) + assert store.get_many(collection='test', keys=['test', 'test_2', 'test_3']) == [{'test': 'test'}, {'test': 'test_2'}, None] + def test_put_get_delete_get(self, store: BaseStore): - store.put(collection="test", key="test", value={"test": "test"}) - assert store.get(collection="test", key="test") == {"test": "test"} - assert store.delete(collection="test", key="test") - assert store.get(collection="test", key="test") is None + store.put(collection='test', key='test', value={'test': 'test'}) + assert store.get(collection='test', key='test') == {'test': 'test'} + assert store.delete(collection='test', key='test') + assert store.get(collection='test', key='test') is None + def test_put_many_get_get_delete_many_get_many(self, store: BaseStore): - store.put_many(collection="test", keys=["test", "test_2"], values=[{"test": "test"}, {"test": "test_2"}]) - assert store.get_many(collection="test", keys=["test", "test_2"]) == [{"test": "test"}, {"test": "test_2"}] - assert 
store.delete_many(collection="test", keys=["test", "test_2"]) == 2 - assert store.get_many(collection="test", keys=["test", "test_2"]) == [None, None] + store.put_many(collection='test', keys=['test', 'test_2'], values=[{'test': 'test'}, {'test': 'test_2'}]) + assert store.get_many(collection='test', keys=['test', 'test_2']) == [{'test': 'test'}, {'test': 'test_2'}] + assert store.delete_many(collection='test', keys=['test', 'test_2']) == 2 + assert store.get_many(collection='test', keys=['test', 'test_2']) == [None, None] + def test_put_many_get_many_delete_many_get_many(self, store: BaseStore): - store.put_many(collection="test", keys=["test", "test_2"], values=[{"test": "test"}, {"test": "test_2"}]) - assert store.get_many(collection="test", keys=["test", "test_2"]) == [{"test": "test"}, {"test": "test_2"}] - assert store.delete_many(collection="test", keys=["test", "test_2"]) == 2 - assert store.get_many(collection="test", keys=["test", "test_2"]) == [None, None] + store.put_many(collection='test', keys=['test', 'test_2'], values=[{'test': 'test'}, {'test': 'test_2'}]) + assert store.get_many(collection='test', keys=['test', 'test_2']) == [{'test': 'test'}, {'test': 'test_2'}] + assert store.delete_many(collection='test', keys=['test', 'test_2']) == 2 + assert store.get_many(collection='test', keys=['test', 'test_2']) == [None, None] + def test_put_many_tuple_get_many(self, store: BaseStore): - store.put_many(collection="test", keys=["test", "test_2"], values=({"test": "test"}, {"test": "test_2"})) - assert store.get_many(collection="test", keys=["test", "test_2"]) == [{"test": "test"}, {"test": "test_2"}] + store.put_many(collection='test', keys=['test', 'test_2'], values=({'test': 'test'}, {'test': 'test_2'})) + assert store.get_many(collection='test', keys=['test', 'test_2']) == [{'test': 'test'}, {'test': 'test_2'}] + def test_delete(self, store: BaseStore): - assert store.delete(collection="test", key="test") is False + assert store.delete(collection='test', key='test') is False + def test_put_delete_delete(self, store: BaseStore): - store.put(collection="test", key="test", value={"test": "test"}) - assert store.delete(collection="test", key="test") - assert store.delete(collection="test", key="test") is False + store.put(collection='test', key='test', value={'test': 'test'}) + assert store.delete(collection='test', key='test') + assert store.delete(collection='test', key='test') is False + def test_delete_many(self, store: BaseStore): - assert store.delete_many(collection="test", keys=["test", "test_2"]) == 0 + assert store.delete_many(collection='test', keys=['test', 'test_2']) == 0 + def test_put_delete_many(self, store: BaseStore): - store.put(collection="test", key="test", value={"test": "test"}) - assert store.delete_many(collection="test", keys=["test", "test_2"]) == 1 + store.put(collection='test', key='test', value={'test': 'test'}) + assert store.delete_many(collection='test', keys=['test', 'test_2']) == 1 + def test_delete_many_delete_many(self, store: BaseStore): - store.put(collection="test", key="test", value={"test": "test"}) - assert store.delete_many(collection="test", keys=["test", "test_2"]) == 1 - assert store.delete_many(collection="test", keys=["test", "test_2"]) == 0 + store.put(collection='test', key='test', value={'test': 'test'}) + assert store.delete_many(collection='test', keys=['test', 'test_2']) == 1 + assert store.delete_many(collection='test', keys=['test', 'test_2']) == 0 + def test_get_put_get_delete_get(self, store: BaseStore): """Tests that the 
get, put, delete, and get methods work together to store and retrieve a value from an empty store.""" - - assert store.get(collection="test", key="test") is None - - store.put(collection="test", key="test", value={"test": "test"}) - - assert store.get(collection="test", key="test") == {"test": "test"} - - assert store.delete(collection="test", key="test") - - assert store.get(collection="test", key="test") is None + + assert store.get(collection='test', key='test') is None + + store.put(collection='test', key='test', value={'test': 'test'}) + + assert store.get(collection='test', key='test') == {'test': 'test'} + + assert store.delete(collection='test', key='test') + + assert store.get(collection='test', key='test') is None + def test_get_put_get_put_delete_get(self, store: BaseStore): """Tests that the get, put, get, put, delete, and get methods work together to store and retrieve a value from an empty store.""" - store.put(collection="test", key="test", value={"test": "test"}) - assert store.get(collection="test", key="test") == {"test": "test"} - - store.put(collection="test", key="test", value={"test": "test_2"}) - - assert store.get(collection="test", key="test") == {"test": "test_2"} - assert store.delete(collection="test", key="test") - assert store.get(collection="test", key="test") is None + store.put(collection='test', key='test', value={'test': 'test'}) + assert store.get(collection='test', key='test') == {'test': 'test'} + + store.put(collection='test', key='test', value={'test': 'test_2'}) + + assert store.get(collection='test', key='test') == {'test': 'test_2'} + assert store.delete(collection='test', key='test') + assert store.get(collection='test', key='test') is None + def test_put_many_delete_delete_get_many(self, store: BaseStore): - store.put_many(collection="test", keys=["test", "test_2"], values=[{"test": "test"}, {"test": "test_2"}]) - assert store.get_many(collection="test", keys=["test", "test_2"]) == [{"test": "test"}, {"test": "test_2"}] - assert store.delete(collection="test", key="test") - assert store.delete(collection="test", key="test_2") - assert store.get_many(collection="test", keys=["test", "test_2"]) == [None, None] + store.put_many(collection='test', keys=['test', 'test_2'], values=[{'test': 'test'}, {'test': 'test_2'}]) + assert store.get_many(collection='test', keys=['test', 'test_2']) == [{'test': 'test'}, {'test': 'test_2'}] + assert store.delete(collection='test', key='test') + assert store.delete(collection='test', key='test_2') + assert store.get_many(collection='test', keys=['test', 'test_2']) == [None, None] + def test_put_ttl_get_ttl(self, store: BaseStore): """Tests that the put and get ttl methods work together to store and retrieve a ttl from an empty store.""" - store.put(collection="test", key="test", value={"test": "test"}, ttl=100) - (value, ttl) = store.ttl(collection="test", key="test") - - assert value == {"test": "test"} + store.put(collection='test', key='test', value={'test': 'test'}, ttl=100) + (value, ttl) = store.ttl(collection='test', key='test') + + assert value == {'test': 'test'} assert ttl is not None assert ttl == IsFloat(approx=100) + def test_negative_ttl(self, store: BaseStore): """Tests that a negative ttl will return None when getting the key.""" with pytest.raises(InvalidTTLError): - store.put(collection="test", key="test", value={"test": "test"}, ttl=-100) + store.put(collection='test', key='test', value={'test': 'test'}, ttl=-100) + @pytest.mark.timeout(10) def test_put_expired_get_none(self, store: BaseStore): 
"""Tests that a put call with a negative ttl will return None when getting the key.""" - store.put(collection="test_collection", key="test_key", value={"test": "test"}, ttl=2) - assert store.get(collection="test_collection", key="test_key") is not None + store.put(collection='test_collection', key='test_key', value={'test': 'test'}, ttl=2) + assert store.get(collection='test_collection', key='test_key') is not None sleep(seconds=1) - + for _ in range(8): sleep(seconds=0.25) - if store.get(collection="test_collection", key="test_key") is None: + if store.get(collection='test_collection', key='test_key') is None: # pass the test return - - pytest.fail("put_expired_get_none test failed, entry did not expire") + + pytest.fail('put_expired_get_none test failed, entry did not expire') + def test_long_collection_name(self, store: BaseStore): """Tests that a long collection name will not raise an error.""" - store.put(collection="test_collection" * 100, key="test_key", value={"test": "test"}) - assert store.get(collection="test_collection" * 100, key="test_key") == {"test": "test"} + store.put(collection='test_collection' * 100, key='test_key', value={'test': 'test'}) + assert store.get(collection='test_collection' * 100, key='test_key') == {'test': 'test'} + def test_special_characters_in_collection_name(self, store: BaseStore): """Tests that a special characters in the collection name will not raise an error.""" - store.put(collection="test_collection!@#$%^&*()", key="test_key", value={"test": "test"}) - assert store.get(collection="test_collection!@#$%^&*()", key="test_key") == {"test": "test"} + store.put(collection='test_collection!@#$%^&*()', key='test_key', value={'test': 'test'}) + assert store.get(collection='test_collection!@#$%^&*()', key='test_key') == {'test': 'test'} + def test_long_key_name(self, store: BaseStore): """Tests that a long key name will not raise an error.""" - store.put(collection="test_collection", key="test_key" * 100, value={"test": "test"}) - assert store.get(collection="test_collection", key="test_key" * 100) == {"test": "test"} + store.put(collection='test_collection', key='test_key' * 100, value={'test': 'test'}) + assert store.get(collection='test_collection', key='test_key' * 100) == {'test': 'test'} + def test_special_characters_in_key_name(self, store: BaseStore): """Tests that a special characters in the key name will not raise an error.""" - store.put(collection="test_collection", key="test_key!@#$%^&*()", value={"test": "test"}) - assert store.get(collection="test_collection", key="test_key!@#$%^&*()") == {"test": "test"} + store.put(collection='test_collection', key='test_key!@#$%^&*()', value={'test': 'test'}) + assert store.get(collection='test_collection', key='test_key!@#$%^&*()') == {'test': 'test'} + @pytest.mark.timeout(20) def test_not_unbounded(self, store: BaseStore): """Tests that the store is not unbounded.""" - + for i in range(1000): - value = hashlib.sha256(f"test_{i}".encode()).hexdigest() - store.put(collection="test_collection", key=f"test_key_{i}", value={"test": value}) - - assert store.get(collection="test_collection", key="test_key_0") is None - assert store.get(collection="test_collection", key="test_key_999") is not None - - @pytest.mark.skipif(condition=not running_in_event_loop(), reason="Cannot run concurrent operations outside of event loop") + value = hashlib.sha256(f'test_{i}'.encode()).hexdigest() + store.put(collection='test_collection', key=f'test_key_{i}', value={'test': value}) + + assert 
store.get(collection='test_collection', key='test_key_0') is None + assert store.get(collection='test_collection', key='test_key_999') is not None + + + @pytest.mark.skipif(condition=not running_in_event_loop(), reason='Cannot run concurrent operations outside of event loop') def test_concurrent_operations(self, store: BaseStore): """Tests that the store can handle concurrent operations.""" + def worker(store: BaseStore, worker_id: int): for i in range(10): - assert store.get(collection="test_collection", key=f"test_{worker_id}_{i}") is None - - store.put(collection="test_collection", key=f"test_{worker_id}_{i}", value={"test": f"test_{i}"}) - assert store.get(collection="test_collection", key=f"test_{worker_id}_{i}") == {"test": f"test_{i}"} - - store.put(collection="test_collection", key=f"test_{worker_id}_{i}", value={"test": f"test_{i}_2"}) - assert store.get(collection="test_collection", key=f"test_{worker_id}_{i}") == {"test": f"test_{i}_2"} - - assert store.delete(collection="test_collection", key=f"test_{worker_id}_{i}") - assert store.get(collection="test_collection", key=f"test_{worker_id}_{i}") is None - + assert store.get(collection='test_collection', key=f'test_{worker_id}_{i}') is None + + store.put(collection='test_collection', key=f'test_{worker_id}_{i}', value={'test': f'test_{i}'}) + assert store.get(collection='test_collection', key=f'test_{worker_id}_{i}') == {'test': f'test_{i}'} + + store.put(collection='test_collection', key=f'test_{worker_id}_{i}', value={'test': f'test_{i}_2'}) + assert store.get(collection='test_collection', key=f'test_{worker_id}_{i}') == {'test': f'test_{i}_2'} + + assert store.delete(collection='test_collection', key=f'test_{worker_id}_{i}') + assert store.get(collection='test_collection', key=f'test_{worker_id}_{i}') is None + _ = gather(*[worker(store, worker_id) for worker_id in range(5)]) + @pytest.mark.timeout(15) def test_minimum_put_many_get_many_performance(self, store: BaseStore): """Tests that the store meets minimum performance requirements.""" - keys = [f"test_{i}" for i in range(10)] - values = [{"test": f"test_{i}"} for i in range(10)] - store.put_many(collection="test_collection", keys=keys, values=values) - assert store.get_many(collection="test_collection", keys=keys) == values + keys = [f'test_{i}' for i in range(10)] + values = [{'test': f'test_{i}'} for i in range(10)] + store.put_many(collection='test_collection', keys=keys, values=values) + assert store.get_many(collection='test_collection', keys=keys) == values + @pytest.mark.timeout(15) def test_minimum_put_many_delete_many_performance(self, store: BaseStore): """Tests that the store meets minimum performance requirements.""" - keys = [f"test_{i}" for i in range(10)] - values = [{"test": f"test_{i}"} for i in range(10)] - store.put_many(collection="test_collection", keys=keys, values=values) - assert store.delete_many(collection="test_collection", keys=keys) == 10 + keys = [f'test_{i}' for i in range(10)] + values = [{'test': f'test_{i}'} for i in range(10)] + store.put_many(collection='test_collection', keys=keys, values=values) + assert store.delete_many(collection='test_collection', keys=keys) == 10 class ContextManagerStoreTestMixin: - @pytest.fixture(params=[True, False], ids=["with_ctx_manager", "no_ctx_manager"], autouse=True) - def enter_exit_store( - self, request: pytest.FixtureRequest, store: BaseContextManagerStore - ) -> Generator[BaseContextManagerStore, None, None]: - context_manager = request.param # pyright: ignore[reportAny] + 
@pytest.fixture(params=[True, False], ids=['with_ctx_manager', 'no_ctx_manager'], autouse=True)
+    def enter_exit_store(self, request: pytest.FixtureRequest, store: BaseContextManagerStore) -> Generator[BaseContextManagerStore, None, None]:
+        context_manager = request.param  # pyright: ignore[reportAny]
+
         if context_manager:
             with store:
                 yield store
diff --git a/key-value/key-value-sync/tests/code_gen/stores/conftest.py b/key-value/key-value-sync/tests/code_gen/stores/conftest.py
index 114e8c57..322fdaa7 100644
--- a/key-value/key-value-sync/tests/code_gen/stores/conftest.py
+++ b/key-value/key-value-sync/tests/code_gen/stores/conftest.py
@@ -21,5 +21,5 @@ def now_plus(seconds: int) -> datetime:
     return now() + timedelta(seconds=seconds)
 
 
-def is_around(value: float, delta: float = 1) -> bool:
+def is_around(value: float, delta: float=1) -> bool:
     return value - delta < value < value + delta
diff --git a/key-value/key-value-sync/tests/code_gen/stores/disk/test_disk.py b/key-value/key-value-sync/tests/code_gen/stores/disk/test_disk.py
index 994738a2..89abc72a 100644
--- a/key-value/key-value-sync/tests/code_gen/stores/disk/test_disk.py
+++ b/key-value/key-value-sync/tests/code_gen/stores/disk/test_disk.py
@@ -14,14 +14,16 @@
 
 
 class TestDiskStore(ContextManagerStoreTestMixin, BaseStoreTests):
-    @pytest.fixture(scope="session")
+
+    @pytest.fixture(scope='session')
     def disk_store(self) -> Generator[DiskStore, None, None]:
         with tempfile.TemporaryDirectory() as temp_dir:
             yield DiskStore(directory=temp_dir, max_size=TEST_SIZE_LIMIT)
 
+
     @override
     @pytest.fixture
     def store(self, disk_store: DiskStore) -> DiskStore:
         disk_store._cache.clear()  # pyright: ignore[reportPrivateUsage]
-
+
         return disk_store
diff --git a/key-value/key-value-sync/tests/code_gen/stores/disk/test_multi_disk.py b/key-value/key-value-sync/tests/code_gen/stores/disk/test_multi_disk.py
index e6341075..b3944761 100644
--- a/key-value/key-value-sync/tests/code_gen/stores/disk/test_multi_disk.py
+++ b/key-value/key-value-sync/tests/code_gen/stores/disk/test_multi_disk.py
@@ -15,15 +15,17 @@
 
 
 class TestMultiDiskStore(ContextManagerStoreTestMixin, BaseStoreTests):
-    @pytest.fixture(scope="session")
+
+    @pytest.fixture(scope='session')
     def multi_disk_store(self) -> Generator[MultiDiskStore, None, None]:
         with tempfile.TemporaryDirectory() as temp_dir:
             yield MultiDiskStore(base_directory=Path(temp_dir), max_size=TEST_SIZE_LIMIT)
 
+
     @override
     @pytest.fixture
     def store(self, multi_disk_store: MultiDiskStore) -> MultiDiskStore:
         for collection in multi_disk_store._cache:  # pyright: ignore[reportPrivateUsage]
             multi_disk_store._cache[collection].clear()  # pyright: ignore[reportPrivateUsage]
-
+
         return multi_disk_store
diff --git a/key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py b/key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py
index 69851a03..033ef522 100644
--- a/key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py
+++ b/key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py
@@ -19,15 +19,15 @@ from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin
 
 TEST_SIZE_LIMIT = 1 * 1024 * 1024  # 1MB
-ES_HOST = "localhost"
+ES_HOST = 'localhost'
 ES_PORT = 9200
-ES_URL = f"http://{ES_HOST}:{ES_PORT}"
+ES_URL = f'http://{ES_HOST}:{ES_PORT}'
 ES_CONTAINER_PORT = 9200
 
 WAIT_FOR_ELASTICSEARCH_TIMEOUT = 30
 
-# Released Apr 2025
+ # Released Apr 2025
 # Released Oct 2025
-ELASTICSEARCH_VERSIONS_TO_TEST = ["9.0.0", "9.2.0"]
+ELASTICSEARCH_VERSIONS_TO_TEST = ['9.0.0', '9.2.0'] def get_elasticsearch_client() -> Elasticsearch: @@ -36,13 +36,13 @@ def get_elasticsearch_client() -> Elasticsearch: def ping_elasticsearch() -> bool: es_client: Elasticsearch = get_elasticsearch_client() - + with es_client: return es_client.ping() def cleanup_elasticsearch_indices(elasticsearch_client: Elasticsearch): - indices = elasticsearch_client.options(ignore_status=404).indices.get(index="kv-store-e2e-test-*") + indices = elasticsearch_client.options(ignore_status=404).indices.get(index='kv-store-e2e-test-*') for index in indices: _ = elasticsearch_client.options(ignore_status=404).indices.delete(index=index) @@ -54,22 +54,14 @@ class ElasticsearchFailedToStartError(Exception): def test_managed_entry_document_conversion(): created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) expires_at = created_at + timedelta(seconds=10) - - managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(collection="test_collection", key="test_key", managed_entry=managed_entry) - - assert document == snapshot( - { - "collection": "test_collection", - "key": "test_key", - "value": {"string": '{"test": "test"}'}, - "created_at": "2025-01-01T00:00:00+00:00", - "expires_at": "2025-01-01T00:00:10+00:00", - } - ) - + + managed_entry = ManagedEntry(value={'test': 'test'}, created_at=created_at, expires_at=expires_at) + document = managed_entry_to_document(collection='test_collection', key='test_key', managed_entry=managed_entry) + + assert document == snapshot({'collection': 'test_collection', 'key': 'test_key', 'value': {'string': '{"test": "test"}'}, 'created_at': '2025-01-01T00:00:00+00:00', 'expires_at': '2025-01-01T00:00:10+00:00'}) + round_trip_managed_entry = source_to_managed_entry(source=document) - + assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at assert round_trip_managed_entry.ttl == IsFloat(lt=0) @@ -79,22 +71,14 @@ def test_managed_entry_document_conversion(): def test_managed_entry_document_conversion_native_storage(): created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) expires_at = created_at + timedelta(seconds=10) - - managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(collection="test_collection", key="test_key", managed_entry=managed_entry, native_storage=True) - - assert document == snapshot( - { - "collection": "test_collection", - "key": "test_key", - "value": {"flattened": {"test": "test"}}, - "created_at": "2025-01-01T00:00:00+00:00", - "expires_at": "2025-01-01T00:00:10+00:00", - } - ) - + + managed_entry = ManagedEntry(value={'test': 'test'}, created_at=created_at, expires_at=expires_at) + document = managed_entry_to_document(collection='test_collection', key='test_key', managed_entry=managed_entry, native_storage=True) + + assert document == snapshot({'collection': 'test_collection', 'key': 'test_key', 'value': {'flattened': {'test': 'test'}}, 'created_at': '2025-01-01T00:00:00+00:00', 'expires_at': '2025-01-01T00:00:10+00:00'}) + round_trip_managed_entry = source_to_managed_entry(source=document) - + assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at assert round_trip_managed_entry.ttl == IsFloat(lt=0) @@ -102,160 +86,133 @@ def 
test_managed_entry_document_conversion_native_storage(): class BaseTestElasticsearchStore(ContextManagerStoreTestMixin, BaseStoreTests): - @pytest.fixture(autouse=True, scope="session", params=ELASTICSEARCH_VERSIONS_TO_TEST) + + @pytest.fixture(autouse=True, scope='session', params=ELASTICSEARCH_VERSIONS_TO_TEST) def setup_elasticsearch(self, request: pytest.FixtureRequest) -> Generator[None, None, None]: version = request.param - es_image = f"docker.elastic.co/elasticsearch/elasticsearch:{version}" - - with docker_container( - f"elasticsearch-test-{version}", - es_image, - {str(ES_CONTAINER_PORT): ES_PORT}, - {"discovery.type": "single-node", "xpack.security.enabled": "false"}, - ): + es_image = f'docker.elastic.co/elasticsearch/elasticsearch:{version}' + + with docker_container(f'elasticsearch-test-{version}', es_image, {str(ES_CONTAINER_PORT): ES_PORT}, {'discovery.type': 'single-node', 'xpack.security.enabled': 'false'}): if not wait_for_true(bool_fn=ping_elasticsearch, tries=WAIT_FOR_ELASTICSEARCH_TIMEOUT, wait_time=2): - msg = f"Elasticsearch {version} failed to start" + msg = f'Elasticsearch {version} failed to start' raise ElasticsearchFailedToStartError(msg) - + yield + @pytest.fixture def es_client(self) -> Generator[Elasticsearch, None, None]: with Elasticsearch(hosts=[ES_URL]) as es_client: yield es_client + @pytest.fixture(autouse=True) def cleanup_elasticsearch_indices(self, es_client: Elasticsearch): cleanup_elasticsearch_indices(elasticsearch_client=es_client) yield cleanup_elasticsearch_indices(elasticsearch_client=es_client) + - @pytest.mark.skip(reason="Distributed Caches are unbounded") + @pytest.mark.skip(reason='Distributed Caches are unbounded') @override - def test_not_unbounded(self, store: BaseStore): ... + def test_not_unbounded(self, store: BaseStore): + ... + - @pytest.mark.skip(reason="Skip concurrent tests on distributed caches") + @pytest.mark.skip(reason='Skip concurrent tests on distributed caches') @override - def test_concurrent_operations(self, store: BaseStore): ... + def test_concurrent_operations(self, store: BaseStore): + ... 
+ def test_put_put_two_indices(self, store: ElasticsearchStore, es_client: Elasticsearch): - store.put(collection="test_collection", key="test_key", value={"test": "test"}) - store.put(collection="test_collection_2", key="test_key", value={"test": "test"}) - assert store.get(collection="test_collection", key="test_key") == {"test": "test"} - assert store.get(collection="test_collection_2", key="test_key") == {"test": "test"} - - indices = es_client.options(ignore_status=404).indices.get(index="kv-store-e2e-test-*") + store.put(collection='test_collection', key='test_key', value={'test': 'test'}) + store.put(collection='test_collection_2', key='test_key', value={'test': 'test'}) + assert store.get(collection='test_collection', key='test_key') == {'test': 'test'} + assert store.get(collection='test_collection_2', key='test_key') == {'test': 'test'} + + indices = es_client.options(ignore_status=404).indices.get(index='kv-store-e2e-test-*') assert len(indices.body) == 2 - assert "kv-store-e2e-test-test_collection" in indices - assert "kv-store-e2e-test-test_collection_2" in indices + assert 'kv-store-e2e-test-test_collection' in indices + assert 'kv-store-e2e-test-test_collection_2' in indices -@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running") +@pytest.mark.skipif(should_skip_docker_tests(), reason='Docker is not running') class TestElasticsearchStoreNativeMode(BaseTestElasticsearchStore): """Test Elasticsearch store in native mode (i.e. it stores flattened objects)""" + @override @pytest.fixture def store(self) -> ElasticsearchStore: - return ElasticsearchStore(url=ES_URL, index_prefix="kv-store-e2e-test", native_storage=True) + return ElasticsearchStore(url=ES_URL, index_prefix='kv-store-e2e-test', native_storage=True) + def test_value_stored_as_flattened_object(self, store: ElasticsearchStore, es_client: Elasticsearch): """Verify values are stored as flattened objects, not JSON strings""" - store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) - + store.put(collection='test', key='test_key', value={'name': 'Alice', 'age': 30}) + # Check raw Elasticsearch document using public sanitization methods # Note: We need to access these internal methods for testing the storage format - index_name = store._sanitize_index_name(collection="test") # pyright: ignore[reportPrivateUsage] - doc_id = store._sanitize_document_id(key="test_key") # pyright: ignore[reportPrivateUsage] - + index_name = store._sanitize_index_name(collection='test') # pyright: ignore[reportPrivateUsage] + doc_id = store._sanitize_document_id(key='test_key') # pyright: ignore[reportPrivateUsage] + response = es_client.get(index=index_name, id=doc_id) - assert response.body["_source"] == snapshot( - { - "collection": "test", - "key": "test_key", - "value": {"flattened": {"name": "Alice", "age": 30}}, - "created_at": IsStr(min_length=20, max_length=40), - } - ) - + assert response.body['_source'] == snapshot({'collection': 'test', 'key': 'test_key', 'value': {'flattened': {'name': 'Alice', 'age': 30}}, 'created_at': IsStr(min_length=20, max_length=40)}) + # Test with TTL - store.put(collection="test", key="test_key", value={"name": "Bob", "age": 25}, ttl=10) + store.put(collection='test', key='test_key', value={'name': 'Bob', 'age': 25}, ttl=10) response = es_client.get(index=index_name, id=doc_id) - assert response.body["_source"] == snapshot( - { - "collection": "test", - "key": "test_key", - "value": {"flattened": {"name": "Bob", "age": 25}}, - "created_at": 
IsStr(min_length=20, max_length=40), - "expires_at": IsStr(min_length=20, max_length=40), - } - ) + assert response.body['_source'] == snapshot({'collection': 'test', 'key': 'test_key', 'value': {'flattened': {'name': 'Bob', 'age': 25}}, 'created_at': IsStr(min_length=20, max_length=40), 'expires_at': IsStr(min_length=20, max_length=40)}) + def test_migration_from_non_native_mode(self, store: ElasticsearchStore, es_client: Elasticsearch): """Verify native mode can read a document with stringified data""" - index_name = store._sanitize_index_name(collection="test") # pyright: ignore[reportPrivateUsage] - doc_id = store._sanitize_document_id(key="legacy_key") # pyright: ignore[reportPrivateUsage] - - es_client.index( - index=index_name, id=doc_id, body={"collection": "test", "key": "legacy_key", "value": {"string": '{"legacy": "data"}'}} - ) + index_name = store._sanitize_index_name(collection='test') # pyright: ignore[reportPrivateUsage] + doc_id = store._sanitize_document_id(key='legacy_key') # pyright: ignore[reportPrivateUsage] + + es_client.index(index=index_name, id=doc_id, body={'collection': 'test', 'key': 'legacy_key', 'value': {'string': '{"legacy": "data"}'}}) es_client.indices.refresh(index=index_name) - - result = store.get(collection="test", key="legacy_key") - assert result == snapshot({"legacy": "data"}) + + result = store.get(collection='test', key='legacy_key') + assert result == snapshot({'legacy': 'data'}) -@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running") +@pytest.mark.skipif(should_skip_docker_tests(), reason='Docker is not running') class TestElasticsearchStoreNonNativeMode(BaseTestElasticsearchStore): """Test Elasticsearch store in non-native mode (i.e. it stores stringified JSON values)""" + @override @pytest.fixture def store(self) -> ElasticsearchStore: - return ElasticsearchStore(url=ES_URL, index_prefix="kv-store-e2e-test", native_storage=False) + return ElasticsearchStore(url=ES_URL, index_prefix='kv-store-e2e-test', native_storage=False) + def test_value_stored_as_json_string(self, store: ElasticsearchStore, es_client: Elasticsearch): """Verify values are stored as JSON strings""" - store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) - - index_name = store._sanitize_index_name(collection="test") # pyright: ignore[reportPrivateUsage] - doc_id = store._sanitize_document_id(key="test_key") # pyright: ignore[reportPrivateUsage] - + store.put(collection='test', key='test_key', value={'name': 'Alice', 'age': 30}) + + index_name = store._sanitize_index_name(collection='test') # pyright: ignore[reportPrivateUsage] + doc_id = store._sanitize_document_id(key='test_key') # pyright: ignore[reportPrivateUsage] + response = es_client.get(index=index_name, id=doc_id) - assert response.body["_source"] == snapshot( - { - "collection": "test", - "key": "test_key", - "value": {"string": '{"age": 30, "name": "Alice"}'}, - "created_at": IsStr(min_length=20, max_length=40), - } - ) - + assert response.body['_source'] == snapshot({'collection': 'test', 'key': 'test_key', 'value': {'string': '{"age": 30, "name": "Alice"}'}, 'created_at': IsStr(min_length=20, max_length=40)}) + # Test with TTL - store.put(collection="test", key="test_key", value={"name": "Bob", "age": 25}, ttl=10) + store.put(collection='test', key='test_key', value={'name': 'Bob', 'age': 25}, ttl=10) response = es_client.get(index=index_name, id=doc_id) - assert response.body["_source"] == snapshot( - { - "collection": "test", - "key": "test_key", - "value": 
{"string": '{"age": 25, "name": "Bob"}'}, - "created_at": IsStr(min_length=20, max_length=40), - "expires_at": IsStr(min_length=20, max_length=40), - } - ) + assert response.body['_source'] == snapshot({'collection': 'test', 'key': 'test_key', 'value': {'string': '{"age": 25, "name": "Bob"}'}, 'created_at': IsStr(min_length=20, max_length=40), 'expires_at': IsStr(min_length=20, max_length=40)}) + def test_migration_from_native_mode(self, store: ElasticsearchStore, es_client: Elasticsearch): """Verify non-native mode can read native mode data""" - index_name = store._sanitize_index_name(collection="test") # pyright: ignore[reportPrivateUsage] - doc_id = store._sanitize_document_id(key="legacy_key") # pyright: ignore[reportPrivateUsage] - - es_client.index( - index=index_name, - id=doc_id, - body={"collection": "test", "key": "legacy_key", "value": {"flattened": {"name": "Alice", "age": 30}}}, - ) - + index_name = store._sanitize_index_name(collection='test') # pyright: ignore[reportPrivateUsage] + doc_id = store._sanitize_document_id(key='legacy_key') # pyright: ignore[reportPrivateUsage] + + es_client.index(index=index_name, id=doc_id, body={'collection': 'test', 'key': 'legacy_key', 'value': {'flattened': {'name': 'Alice', 'age': 30}}}) + es_client.indices.refresh(index=index_name) - - result = store.get(collection="test", key="legacy_key") - assert result == snapshot({"name": "Alice", "age": 30}) + + result = store.get(collection='test', key='legacy_key') + assert result == snapshot({'name': 'Alice', 'age': 30}) diff --git a/key-value/key-value-sync/tests/code_gen/stores/keyring/test_keyring.py b/key-value/key-value-sync/tests/code_gen/stores/keyring/test_keyring.py index c53c4f2b..1c586f2c 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/keyring/test_keyring.py +++ b/key-value/key-value-sync/tests/code_gen/stores/keyring/test_keyring.py @@ -13,24 +13,28 @@ from tests.code_gen.stores.base import BaseStoreTests -@pytest.mark.skipif(condition=detect_on_linux(), reason="KeyringStore is not available on Linux CI") +@pytest.mark.skipif(condition=detect_on_linux(), reason='KeyringStore is not available on Linux CI') class TestKeychainStore(BaseStoreTests): + @override @pytest.fixture def store(self) -> KeyringStore: # Use a test-specific service name to avoid conflicts - store = KeyringStore(service_name="py-key-value-test") - store.delete_many(collection="test", keys=["test", "test_2"]) - store.delete_many(collection="test_collection", keys=["test_key"]) - + store = KeyringStore(service_name='py-key-value-test') + store.delete_many(collection='test', keys=['test', 'test_2']) + store.delete_many(collection='test_collection', keys=['test_key']) + return store + @override - @pytest.mark.skip(reason="We do not test boundedness of keyring stores") - def test_not_unbounded(self, store: BaseStore): ... + @pytest.mark.skip(reason='We do not test boundedness of keyring stores') + def test_not_unbounded(self, store: BaseStore): + ... 
+ @override - @pytest.mark.skipif(condition=detect_on_windows(), reason="Keyrings do not support large values on Windows") + @pytest.mark.skipif(condition=detect_on_windows(), reason='Keyrings do not support large values on Windows') @PositiveCases.parametrize(cases=[LARGE_DATA_CASES]) def test_get_large_put_get(self, store: BaseStore, data: dict[str, Any], json: str, round_trip: dict[str, Any]): super().test_get_large_put_get(store, data, json, round_trip=round_trip) diff --git a/key-value/key-value-sync/tests/code_gen/stores/memory/test_memory.py b/key-value/key-value-sync/tests/code_gen/stores/memory/test_memory.py index 631cf255..1783124e 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/memory/test_memory.py +++ b/key-value/key-value-sync/tests/code_gen/stores/memory/test_memory.py @@ -9,11 +9,13 @@ class TestMemoryStore(BaseStoreTests): + @override @pytest.fixture def store(self) -> MemoryStore: return MemoryStore(max_entries_per_collection=500) + def test_seed(self): - store = MemoryStore(max_entries_per_collection=500, seed={"test_collection": {"test_key": {"obj_key": "obj_value"}}}) - assert store.get(key="test_key", collection="test_collection") == {"obj_key": "obj_value"} + store = MemoryStore(max_entries_per_collection=500, seed={'test_collection': {'test_key': {'obj_key': 'obj_value'}}}) + assert store.get(key='test_key', collection='test_collection') == {'obj_key': 'obj_value'} diff --git a/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py b/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py index a9a3a9a3..7c7bc119 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py +++ b/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py @@ -22,14 +22,14 @@ from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin # MongoDB test configuration -MONGODB_HOST = "localhost" +MONGODB_HOST = 'localhost' MONGODB_HOST_PORT = 27017 -MONGODB_TEST_DB = "kv-store-adapter-tests" +MONGODB_TEST_DB = 'kv-store-adapter-tests' WAIT_FOR_MONGODB_TIMEOUT = 30 -# Older supported version + # Older supported version # Latest stable version -MONGODB_VERSIONS_TO_TEST = ["5.0", "8.0"] +MONGODB_VERSIONS_TO_TEST = ['5.0', '8.0'] def ping_mongodb() -> bool: @@ -38,7 +38,7 @@ def ping_mongodb() -> bool: _ = client.list_database_names() except Exception: return False - + return True @@ -49,21 +49,14 @@ class MongoDBFailedToStartError(Exception): def test_managed_entry_document_conversion_native_mode(): created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) expires_at = created_at + timedelta(seconds=10) - - managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(key="test", managed_entry=managed_entry, native_storage=True) - - assert document == snapshot( - { - "key": "test", - "value": {"object": {"test": "test"}}, - "created_at": datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), - "expires_at": datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc), - } - ) - + + managed_entry = ManagedEntry(value={'test': 'test'}, created_at=created_at, expires_at=expires_at) + document = managed_entry_to_document(key='test', managed_entry=managed_entry, native_storage=True) + + assert document == snapshot({'key': 'test', 'value': {'object': {'test': 'test'}}, 'created_at': datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), 'expires_at': datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc)}) + 
round_trip_managed_entry = document_to_managed_entry(document=document) - + assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at assert round_trip_managed_entry.ttl == IsFloat(lt=0) @@ -73,21 +66,14 @@ def test_managed_entry_document_conversion_native_mode(): def test_managed_entry_document_conversion_legacy_mode(): created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) expires_at = created_at + timedelta(seconds=10) - - managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(key="test", managed_entry=managed_entry, native_storage=False) - - assert document == snapshot( - { - "key": "test", - "value": {"string": '{"test": "test"}'}, - "created_at": datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), - "expires_at": datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc), - } - ) - + + managed_entry = ManagedEntry(value={'test': 'test'}, created_at=created_at, expires_at=expires_at) + document = managed_entry_to_document(key='test', managed_entry=managed_entry, native_storage=False) + + assert document == snapshot({'key': 'test', 'value': {'string': '{"test": "test"}'}, 'created_at': datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), 'expires_at': datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc)}) + round_trip_managed_entry = document_to_managed_entry(document=document) - + assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at assert round_trip_managed_entry.ttl == IsFloat(lt=0) @@ -101,114 +87,110 @@ def clean_mongodb_database(store: MongoDBStore) -> None: class BaseMongoDBStoreTests(ContextManagerStoreTestMixin, BaseStoreTests): """Base class for MongoDB store tests.""" + - @pytest.fixture(autouse=True, scope="session", params=MONGODB_VERSIONS_TO_TEST) + @pytest.fixture(autouse=True, scope='session', params=MONGODB_VERSIONS_TO_TEST) def setup_mongodb(self, request: pytest.FixtureRequest) -> Generator[None, None, None]: version = request.param - - with docker_container(f"mongodb-test-{version}", f"mongo:{version}", {str(MONGODB_HOST_PORT): MONGODB_HOST_PORT}): + + with docker_container(f'mongodb-test-{version}', f'mongo:{version}', {str(MONGODB_HOST_PORT): MONGODB_HOST_PORT}): if not wait_for_true(bool_fn=ping_mongodb, tries=WAIT_FOR_MONGODB_TIMEOUT, wait_time=1): - msg = f"MongoDB {version} failed to start" + msg = f'MongoDB {version} failed to start' raise MongoDBFailedToStartError(msg) - + yield + - @pytest.mark.skip(reason="Distributed Caches are unbounded") + @pytest.mark.skip(reason='Distributed Caches are unbounded') @override - def test_not_unbounded(self, store: BaseStore): ... + def test_not_unbounded(self, store: BaseStore): + ... 
+ def test_mongodb_collection_name_sanitization(self, store: MongoDBStore): """Tests that a special characters in the collection name will not raise an error.""" - store.put(collection="test_collection!@#$%^&*()", key="test_key", value={"test": "test"}) - assert store.get(collection="test_collection!@#$%^&*()", key="test_key") == {"test": "test"} - + store.put(collection='test_collection!@#$%^&*()', key='test_key', value={'test': 'test'}) + assert store.get(collection='test_collection!@#$%^&*()', key='test_key') == {'test': 'test'} + collections = store.collections() - assert collections == snapshot(["test_collection_-daf4a2ec"]) + assert collections == snapshot(['test_collection_-daf4a2ec']) -@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not available") +@pytest.mark.skipif(should_skip_docker_tests(), reason='Docker is not available') class TestMongoDBStoreNativeMode(BaseMongoDBStoreTests): """Test MongoDBStore with native_storage=True (default).""" + @override @pytest.fixture def store(self, setup_mongodb: None) -> MongoDBStore: - store = MongoDBStore(url=f"mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}", db_name=f"{MONGODB_TEST_DB}-native", native_storage=True) - + store = MongoDBStore(url=f'mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}', db_name=f'{MONGODB_TEST_DB}-native', native_storage=True) + clean_mongodb_database(store=store) - + return store + def test_value_stored_as_bson_dict(self, store: MongoDBStore): """Verify values are stored as BSON dicts, not JSON strings.""" - store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) - + store.put(collection='test', key='test_key', value={'name': 'Alice', 'age': 30}) + # Get the raw MongoDB document - store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage] - sanitized_collection = store._sanitize_collection_name(collection="test") # pyright: ignore[reportPrivateUsage] + store._setup_collection(collection='test') # pyright: ignore[reportPrivateUsage] + sanitized_collection = store._sanitize_collection_name(collection='test') # pyright: ignore[reportPrivateUsage] collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage] - doc = collection.find_one({"key": "test_key"}) - - assert doc == snapshot( - { - "_id": IsInstance(expected_type=ObjectId), - "key": "test_key", - "created_at": IsDatetime(), - "value": {"object": {"name": "Alice", "age": 30}}, - } - ) + doc = collection.find_one({'key': 'test_key'}) + + assert doc == snapshot({'_id': IsInstance(expected_type=ObjectId), 'key': 'test_key', 'created_at': IsDatetime(), 'value': {'object': {'name': 'Alice', 'age': 30}}}) + def test_migration_from_legacy_mode(self, store: MongoDBStore): """Verify native mode can read legacy JSON string data.""" - store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage] - sanitized_collection = store._sanitize_collection_name(collection="test") # pyright: ignore[reportPrivateUsage] + store._setup_collection(collection='test') # pyright: ignore[reportPrivateUsage] + sanitized_collection = store._sanitize_collection_name(collection='test') # pyright: ignore[reportPrivateUsage] collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage] - - collection.insert_one({"key": "legacy_key", "value": {"string": '{"legacy": "data"}'}}) - - result = store.get(collection="test", key="legacy_key") - assert result == {"legacy": "data"} + + collection.insert_one({'key': 'legacy_key', 'value': {'string': 
'{"legacy": "data"}'}}) + + result = store.get(collection='test', key='legacy_key') + assert result == {'legacy': 'data'} -@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not available") +@pytest.mark.skipif(should_skip_docker_tests(), reason='Docker is not available') class TestMongoDBStoreNonNativeMode(BaseMongoDBStoreTests): """Test MongoDBStore with native_storage=False (legacy mode) for backward compatibility.""" + @override @pytest.fixture def store(self, setup_mongodb: None) -> MongoDBStore: - store = MongoDBStore(url=f"mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}", db_name=MONGODB_TEST_DB, native_storage=False) - + store = MongoDBStore(url=f'mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}', db_name=MONGODB_TEST_DB, native_storage=False) + clean_mongodb_database(store=store) - + return store + def test_value_stored_as_json(self, store: MongoDBStore): """Verify values are stored as JSON strings.""" - store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) - + store.put(collection='test', key='test_key', value={'name': 'Alice', 'age': 30}) + # Get the raw MongoDB document - store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage] - sanitized_collection = store._sanitize_collection_name(collection="test") # pyright: ignore[reportPrivateUsage] + store._setup_collection(collection='test') # pyright: ignore[reportPrivateUsage] + sanitized_collection = store._sanitize_collection_name(collection='test') # pyright: ignore[reportPrivateUsage] collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage] - doc = collection.find_one({"key": "test_key"}) - - assert doc == snapshot( - { - "_id": IsInstance(expected_type=ObjectId), - "key": "test_key", - "created_at": IsDatetime(), - "value": {"string": '{"age": 30, "name": "Alice"}'}, - } - ) + doc = collection.find_one({'key': 'test_key'}) + + assert doc == snapshot({'_id': IsInstance(expected_type=ObjectId), 'key': 'test_key', 'created_at': IsDatetime(), 'value': {'string': '{"age": 30, "name": "Alice"}'}}) + def test_migration_from_native_mode(self, store: MongoDBStore): """Verify non-native mode can read native mode data.""" - store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage] - sanitized_collection = store._sanitize_collection_name(collection="test") # pyright: ignore[reportPrivateUsage] + store._setup_collection(collection='test') # pyright: ignore[reportPrivateUsage] + sanitized_collection = store._sanitize_collection_name(collection='test') # pyright: ignore[reportPrivateUsage] collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage] - - collection.insert_one({"key": "legacy_key", "value": {"object": {"name": "Alice", "age": 30}}}) - - result = store.get(collection="test", key="legacy_key") - assert result == {"name": "Alice", "age": 30} + + collection.insert_one({'key': 'legacy_key', 'value': {'object': {'name': 'Alice', 'age': 30}}}) + + result = store.get(collection='test', key='legacy_key') + assert result == {'name': 'Alice', 'age': 30} diff --git a/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py b/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py index 4620fba8..a7dd100e 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py +++ b/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py @@ -19,28 +19,26 @@ from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin # 
Redis test configuration -REDIS_HOST = "localhost" +REDIS_HOST = 'localhost' REDIS_PORT = 6379 REDIS_DB = 15 # Use a separate database for tests WAIT_FOR_REDIS_TIMEOUT = 30 -REDIS_VERSIONS_TO_TEST = ["4.0.0", "7.0.0"] +REDIS_VERSIONS_TO_TEST = ['4.0.0', '7.0.0'] def test_managed_entry_document_conversion(): created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) expires_at = created_at + timedelta(seconds=10) - - managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) + + managed_entry = ManagedEntry(value={'test': 'test'}, created_at=created_at, expires_at=expires_at) document = managed_entry_to_json(managed_entry=managed_entry) - - assert document == snapshot( - '{"created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00", "value": {"test": "test"}}' - ) - + + assert document == snapshot('{"created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00", "value": {"test": "test"}}') + round_trip_managed_entry = json_to_managed_entry(json_str=document) - + assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at assert round_trip_managed_entry.ttl == IsFloat(lt=0) @@ -59,18 +57,20 @@ class RedisFailedToStartError(Exception): pass -@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running") +@pytest.mark.skipif(should_skip_docker_tests(), reason='Docker is not running') class TestRedisStore(ContextManagerStoreTestMixin, BaseStoreTests): - @pytest.fixture(autouse=True, scope="session", params=REDIS_VERSIONS_TO_TEST) + + @pytest.fixture(autouse=True, scope='session', params=REDIS_VERSIONS_TO_TEST) def setup_redis(self, request: pytest.FixtureRequest) -> Generator[None, None, None]: version = request.param - - with docker_container("redis-test", f"redis:{version}", {"6379": REDIS_PORT}): + + with docker_container('redis-test', f'redis:{version}', {'6379': REDIS_PORT}): if not wait_for_true(bool_fn=ping_redis, tries=30, wait_time=1): - msg = "Redis failed to start" + msg = 'Redis failed to start' raise RedisFailedToStartError(msg) - + yield + @override @pytest.fixture @@ -80,28 +80,32 @@ def store(self, setup_redis: RedisStore) -> RedisStore: redis_store = RedisStore(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB) _ = redis_store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] return redis_store + def test_redis_url_connection(self): """Test Redis store creation with URL.""" - redis_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}" + redis_url = f'redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}' store = RedisStore(url=redis_url) _ = store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] - store.put(collection="test", key="url_test", value={"test": "value"}) - result = store.get(collection="test", key="url_test") - assert result == {"test": "value"} + store.put(collection='test', key='url_test', value={'test': 'value'}) + result = store.get(collection='test', key='url_test') + assert result == {'test': 'value'} + def test_redis_client_connection(self): """Test Redis store creation with existing client.""" from redis import Redis - + client = Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) store = RedisStore(client=client) - + _ = store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] - store.put(collection="test", 
key="client_test", value={"test": "value"}) - result = store.get(collection="test", key="client_test") - assert result == {"test": "value"} + store.put(collection='test', key='client_test', value={'test': 'value'}) + result = store.get(collection='test', key='client_test') + assert result == {'test': 'value'} + - @pytest.mark.skip(reason="Distributed Caches are unbounded") + @pytest.mark.skip(reason='Distributed Caches are unbounded') @override - def test_not_unbounded(self, store: BaseStore): ... + def test_not_unbounded(self, store: BaseStore): + ... diff --git a/key-value/key-value-sync/tests/code_gen/stores/rocksdb/test_rocksdb.py b/key-value/key-value-sync/tests/code_gen/stores/rocksdb/test_rocksdb.py index 2e04cd08..29ab08bd 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/rocksdb/test_rocksdb.py +++ b/key-value/key-value-sync/tests/code_gen/stores/rocksdb/test_rocksdb.py @@ -14,51 +14,56 @@ class TestRocksDBStore(ContextManagerStoreTestMixin, BaseStoreTests): + @override @pytest.fixture def store(self) -> Generator[RocksDBStore, None, None]: """Create a RocksDB store for testing.""" # Create a temporary directory for the RocksDB database with TemporaryDirectory() as temp_dir: - db_path = Path(temp_dir) / "test_db" + db_path = Path(temp_dir) / 'test_db' rocksdb_store = RocksDBStore(path=db_path) yield rocksdb_store + def test_rocksdb_path_connection(self): """Test RocksDB store creation with path.""" temp_dir = TemporaryDirectory() - db_path = Path(temp_dir.name) / "path_test_db" - + db_path = Path(temp_dir.name) / 'path_test_db' + store = RocksDBStore(path=db_path) - - store.put(collection="test", key="path_test", value={"test": "value"}) - result = store.get(collection="test", key="path_test") - assert result == {"test": "value"} - + + store.put(collection='test', key='path_test', value={'test': 'value'}) + result = store.get(collection='test', key='path_test') + assert result == {'test': 'value'} + store.close() temp_dir.cleanup() + def test_rocksdb_db_connection(self): """Test RocksDB store creation with existing DB instance.""" from rocksdict import Options, Rdict - + temp_dir = TemporaryDirectory() - db_path = Path(temp_dir.name) / "db_test_db" + db_path = Path(temp_dir.name) / 'db_test_db' db_path.mkdir(parents=True, exist_ok=True) - + opts = Options() opts.create_if_missing(True) db = Rdict(str(db_path), options=opts) - + store = RocksDBStore(db=db) - - store.put(collection="test", key="db_test", value={"test": "value"}) - result = store.get(collection="test", key="db_test") - assert result == {"test": "value"} - + + store.put(collection='test', key='db_test', value={'test': 'value'}) + result = store.get(collection='test', key='db_test') + assert result == {'test': 'value'} + store.close() temp_dir.cleanup() + - @pytest.mark.skip(reason="Local disk stores are unbounded") + @pytest.mark.skip(reason='Local disk stores are unbounded') @override - def test_not_unbounded(self, store: BaseStore): ... + def test_not_unbounded(self, store: BaseStore): + ... 
diff --git a/key-value/key-value-sync/tests/code_gen/stores/simple/test_store.py b/key-value/key-value-sync/tests/code_gen/stores/simple/test_store.py index 1ee92614..724cb462 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/simple/test_store.py +++ b/key-value/key-value-sync/tests/code_gen/stores/simple/test_store.py @@ -9,6 +9,7 @@ class TestSimpleStore(BaseStoreTests): + @override @pytest.fixture def store(self) -> SimpleStore: diff --git a/key-value/key-value-sync/tests/code_gen/stores/valkey/test_valkey.py b/key-value/key-value-sync/tests/code_gen/stores/valkey/test_valkey.py index 69349194..720d0293 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/valkey/test_valkey.py +++ b/key-value/key-value-sync/tests/code_gen/stores/valkey/test_valkey.py @@ -13,33 +13,33 @@ from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin # Valkey test configuration -VALKEY_HOST = "localhost" +VALKEY_HOST = 'localhost' VALKEY_PORT = 6380 # normally 6379, avoid clashing with Redis tests VALKEY_DB = 15 VALKEY_CONTAINER_PORT = 6379 WAIT_FOR_VALKEY_TIMEOUT = 30 -# Released Apr 2024 + # Released Apr 2024 # Released Sep 2024 # Released Oct 2025 -VALKEY_VERSIONS_TO_TEST = ["7.2.5", "8.0.0", "9.0.0"] +VALKEY_VERSIONS_TO_TEST = ['7.2.5', '8.0.0', '9.0.0'] class ValkeyFailedToStartError(Exception): pass -@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running") -@pytest.mark.skipif(detect_on_windows(), reason="Valkey is not supported on Windows") +@pytest.mark.skipif(should_skip_docker_tests(), reason='Docker is not running') +@pytest.mark.skipif(detect_on_windows(), reason='Valkey is not supported on Windows') class TestValkeyStore(ContextManagerStoreTestMixin, BaseStoreTests): + def get_valkey_client(self): - from glide_sync.config import GlideClientConfiguration, NodeAddress from glide_sync.glide_client import GlideClient - - client_config: GlideClientConfiguration = GlideClientConfiguration( - addresses=[NodeAddress(host=VALKEY_HOST, port=VALKEY_PORT)], database_id=VALKEY_DB - ) + from glide_sync.config import GlideClientConfiguration, NodeAddress + + client_config: GlideClientConfiguration = GlideClientConfiguration(addresses=[NodeAddress(host=VALKEY_HOST, port=VALKEY_PORT)], database_id=VALKEY_DB) return GlideClient.create(config=client_config) + def ping_valkey(self) -> bool: client = None @@ -54,31 +54,35 @@ def ping_valkey(self) -> bool: if client is not None: with contextlib.suppress(Exception): client.close() + - @pytest.fixture(scope="session", params=VALKEY_VERSIONS_TO_TEST) + @pytest.fixture(scope='session', params=VALKEY_VERSIONS_TO_TEST) def setup_valkey(self, request: pytest.FixtureRequest) -> Generator[None, None, None]: version = request.param - - with docker_container(f"valkey-test-{version}", f"valkey/valkey:{version}", {str(VALKEY_CONTAINER_PORT): VALKEY_PORT}): + + with docker_container(f'valkey-test-{version}', f'valkey/valkey:{version}', {str(VALKEY_CONTAINER_PORT): VALKEY_PORT}): if not wait_for_true(bool_fn=self.ping_valkey, tries=WAIT_FOR_VALKEY_TIMEOUT, wait_time=1): - msg = f"Valkey {version} failed to start" + msg = f'Valkey {version} failed to start' raise ValkeyFailedToStartError(msg) - + yield + @override @pytest.fixture def store(self, setup_valkey: None): from key_value.sync.code_gen.stores.valkey import ValkeyStore - + store: ValkeyStore = ValkeyStore(host=VALKEY_HOST, port=VALKEY_PORT, db=VALKEY_DB) - + # This is a syncronous client client = self.get_valkey_client() _ = client.flushdb() - + return store + - 
@pytest.mark.skip(reason="Distributed Caches are unbounded") + @pytest.mark.skip(reason='Distributed Caches are unbounded') @override - def test_not_unbounded(self, store: BaseStore): ... + def test_not_unbounded(self, store: BaseStore): + ... diff --git a/key-value/key-value-sync/tests/code_gen/stores/vault/test_vault.py b/key-value/key-value-sync/tests/code_gen/stores/vault/test_vault.py index c7a5fd74..7d145ea6 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/vault/test_vault.py +++ b/key-value/key-value-sync/tests/code_gen/stores/vault/test_vault.py @@ -12,28 +12,30 @@ from tests.code_gen.stores.base import BaseStoreTests # Vault test configuration -VAULT_HOST = "localhost" +VAULT_HOST = 'localhost' VAULT_PORT = 8200 -VAULT_TOKEN = "dev-root-token" # noqa: S105 -VAULT_MOUNT_POINT = "secret" +VAULT_TOKEN = 'dev-root-token' # noqa: S105 +VAULT_MOUNT_POINT = 'secret' VAULT_CONTAINER_PORT = 8200 WAIT_FOR_VAULT_TIMEOUT = 30 -# Released Oct 2022 + # Released Oct 2022 # Released Oct 2025 -VAULT_VERSIONS_TO_TEST = ["1.12.0", "1.21.0"] +VAULT_VERSIONS_TO_TEST = ['1.12.0', '1.21.0'] class VaultFailedToStartError(Exception): pass -@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running") +@pytest.mark.skipif(should_skip_docker_tests(), reason='Docker is not running') class TestVaultStore(BaseStoreTests): + def get_vault_client(self): import hvac - - return hvac.Client(url=f"http://{VAULT_HOST}:{VAULT_PORT}", token=VAULT_TOKEN) + + return hvac.Client(url=f'http://{VAULT_HOST}:{VAULT_PORT}', token=VAULT_TOKEN) + def ping_vault(self) -> bool: try: @@ -41,45 +43,44 @@ def ping_vault(self) -> bool: return client.sys.is_initialized() # pyright: ignore[reportUnknownMemberType,reportUnknownReturnType,reportUnknownVariableType] except Exception: return False + - @pytest.fixture(scope="session", params=VAULT_VERSIONS_TO_TEST) + @pytest.fixture(scope='session', params=VAULT_VERSIONS_TO_TEST) def setup_vault(self, request: pytest.FixtureRequest) -> Generator[None, None, None]: version = request.param - - with docker_container( - f"vault-test-{version}", - f"hashicorp/vault:{version}", - {str(VAULT_CONTAINER_PORT): VAULT_PORT}, - environment={"VAULT_DEV_ROOT_TOKEN_ID": VAULT_TOKEN, "VAULT_DEV_LISTEN_ADDRESS": "0.0.0.0:8200"}, - ): + + with docker_container(f'vault-test-{version}', f'hashicorp/vault:{version}', {str(VAULT_CONTAINER_PORT): VAULT_PORT}, environment={'VAULT_DEV_ROOT_TOKEN_ID': VAULT_TOKEN, 'VAULT_DEV_LISTEN_ADDRESS': '0.0.0.0:8200'}): if not wait_for_true(bool_fn=self.ping_vault, tries=WAIT_FOR_VAULT_TIMEOUT, wait_time=1): - msg = f"Vault {version} failed to start" + msg = f'Vault {version} failed to start' raise VaultFailedToStartError(msg) - + yield + @override @pytest.fixture def store(self, setup_vault: None): from key_value.sync.code_gen.stores.vault import VaultStore - - store: VaultStore = VaultStore(url=f"http://{VAULT_HOST}:{VAULT_PORT}", token=VAULT_TOKEN, mount_point=VAULT_MOUNT_POINT) - + + store: VaultStore = VaultStore(url=f'http://{VAULT_HOST}:{VAULT_PORT}', token=VAULT_TOKEN, mount_point=VAULT_MOUNT_POINT) + # Clean up any existing data - best effort, ignore errors client = self.get_vault_client() try: # List all secrets and delete them - secrets_list = client.secrets.kv.v2.list_secrets(path="", mount_point=VAULT_MOUNT_POINT) # pyright: ignore[reportUnknownMemberType,reportUnknownReturnType,reportUnknownVariableType] - if secrets_list and "data" in secrets_list and ("keys" in secrets_list["data"]): - for key in secrets_list["data"]["keys"]: # 
pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] + secrets_list = client.secrets.kv.v2.list_secrets(path='', mount_point=VAULT_MOUNT_POINT) # pyright: ignore[reportUnknownMemberType,reportUnknownReturnType,reportUnknownVariableType] + if secrets_list and 'data' in secrets_list and ('keys' in secrets_list['data']): + for key in secrets_list['data']['keys']: # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] # Best effort cleanup - ignore individual deletion failures - client.secrets.kv.v2.delete_metadata_and_all_versions(path=key.rstrip("/"), mount_point=VAULT_MOUNT_POINT) # pyright: ignore[reportUnknownMemberType,reportUnknownReturnType,reportUnknownVariableType] + client.secrets.kv.v2.delete_metadata_and_all_versions(path=key.rstrip('/'), mount_point=VAULT_MOUNT_POINT) # pyright: ignore[reportUnknownMemberType,reportUnknownReturnType,reportUnknownVariableType] except Exception: # noqa: S110 # Cleanup is best-effort, ignore all errors pass - + return store + - @pytest.mark.skip(reason="Distributed Caches are unbounded") + @pytest.mark.skip(reason='Distributed Caches are unbounded') @override - def test_not_unbounded(self, store: BaseStore): ... + def test_not_unbounded(self, store: BaseStore): + ... diff --git a/key-value/key-value-sync/tests/code_gen/stores/windows_registry/test_windows_registry.py b/key-value/key-value-sync/tests/code_gen/stores/windows_registry/test_windows_registry.py index 65105b16..f90e948c 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/windows_registry/test_windows_registry.py +++ b/key-value/key-value-sync/tests/code_gen/stores/windows_registry/test_windows_registry.py @@ -13,27 +13,31 @@ if TYPE_CHECKING: from key_value.sync.code_gen.stores.windows_registry.store import WindowsRegistryStore -TEST_REGISTRY_PATH = "software\\py-key-value-test" +TEST_REGISTRY_PATH = 'software\\py-key-value-test' -@pytest.mark.skipif(condition=not detect_on_windows(), reason="WindowsRegistryStore is only available on Windows") +@pytest.mark.skipif(condition=not detect_on_windows(), reason='WindowsRegistryStore is only available on Windows') class TestWindowsRegistryStore(BaseStoreTests): + def cleanup(self): from winreg import HKEY_CURRENT_USER - + from key_value.sync.code_gen.stores.windows_registry.utils import delete_sub_keys - + delete_sub_keys(hive=HKEY_CURRENT_USER, sub_key=TEST_REGISTRY_PATH) + @override @pytest.fixture - def store(self) -> "WindowsRegistryStore": + def store(self) -> 'WindowsRegistryStore': from key_value.sync.code_gen.stores.windows_registry.store import WindowsRegistryStore - + self.cleanup() - - return WindowsRegistryStore(registry_path=TEST_REGISTRY_PATH, hive="HKEY_CURRENT_USER") + + return WindowsRegistryStore(registry_path=TEST_REGISTRY_PATH, hive='HKEY_CURRENT_USER') + @override - @pytest.mark.skip(reason="We do not test boundedness of registry stores") - def test_not_unbounded(self, store: BaseStore): ... + @pytest.mark.skip(reason='We do not test boundedness of registry stores') + def test_not_unbounded(self, store: BaseStore): + ... 
diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_compression.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_compression.py index 1def5384..72efb239 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_compression.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_compression.py @@ -10,114 +10,118 @@ class TestCompressionWrapper(BaseStoreTests): + @override @pytest.fixture def store(self, memory_store: MemoryStore) -> CompressionWrapper: # Set min_size to 0 so all values get compressed for testing return CompressionWrapper(key_value=memory_store, min_size_to_compress=0) + def test_compression_small_value_not_compressed(self, memory_store: MemoryStore): # With default min_size (1024), small values shouldn't be compressed compression_store = CompressionWrapper(key_value=memory_store, min_size_to_compress=1024) - - small_value = {"test": "value"} - compression_store.put(collection="test", key="test", value=small_value) - + + small_value = {'test': 'value'} + compression_store.put(collection='test', key='test', value=small_value) + # Check the underlying store - should NOT be compressed - raw_value = memory_store.get(collection="test", key="test") + raw_value = memory_store.get(collection='test', key='test') assert raw_value is not None assert raw_value == small_value - assert "__compressed_data__" not in raw_value - + assert '__compressed_data__' not in raw_value + # Retrieve through wrapper - result = compression_store.get(collection="test", key="test") + result = compression_store.get(collection='test', key='test') assert result == small_value + def test_compression_large_value_compressed(self, memory_store: MemoryStore): compression_store = CompressionWrapper(key_value=memory_store, min_size_to_compress=100) - + # Create a large value - large_value = {"data": "x" * 1000, "more_data": "y" * 1000} - compression_store.put(collection="test", key="test", value=large_value) - + large_value = {'data': 'x' * 1000, 'more_data': 'y' * 1000} + compression_store.put(collection='test', key='test', value=large_value) + # Check the underlying store - should be compressed - raw_value = memory_store.get(collection="test", key="test") + raw_value = memory_store.get(collection='test', key='test') assert raw_value is not None - assert "__compressed_data__" in raw_value - assert "__compression_version__" in raw_value - assert isinstance(raw_value["__compressed_data__"], str) - + assert '__compressed_data__' in raw_value + assert '__compression_version__' in raw_value + assert isinstance(raw_value['__compressed_data__'], str) + # Retrieve through wrapper - should decompress automatically - result = compression_store.get(collection="test", key="test") + result = compression_store.get(collection='test', key='test') assert result == large_value + def test_compression_many_operations(self, memory_store: MemoryStore): compression_store = CompressionWrapper(key_value=memory_store, min_size_to_compress=0) - - keys = ["k1", "k2", "k3"] - values = [{"data": "value1"}, {"data": "value2"}, {"data": "value3"}] - - compression_store.put_many(collection="test", keys=keys, values=values) - + + keys = ['k1', 'k2', 'k3'] + values = [{'data': 'value1'}, {'data': 'value2'}, {'data': 'value3'}] + + compression_store.put_many(collection='test', keys=keys, values=values) + # Check underlying store - all should be compressed for key in keys: - raw_value = memory_store.get(collection="test", key=key) + raw_value = memory_store.get(collection='test', 
key=key) assert raw_value is not None - assert "__compressed_data__" in raw_value - + assert '__compressed_data__' in raw_value + # Retrieve through wrapper - results = compression_store.get_many(collection="test", keys=keys) + results = compression_store.get_many(collection='test', keys=keys) assert results == values + def test_compression_already_compressed_not_recompressed(self, memory_store: MemoryStore): compression_store = CompressionWrapper(key_value=memory_store, min_size_to_compress=0) - + # Manually create a compressed value - compressed_value = { - "__compressed_data__": "H4sIAAAAAAAAA6tWKkktLlGyUlAqS8wpTtVRKi1OLUpVslIqLU4tUqoFAJRxMHkfAAAA", - "__compression_version__": 1, - "__compression_algorithm__": "gzip", - } - + compressed_value = {'__compressed_data__': 'H4sIAAAAAAAAA6tWKkktLlGyUlAqS8wpTtVRKi1OLUpVslIqLU4tUqoFAJRxMHkfAAAA', '__compression_version__': 1, '__compression_algorithm__': 'gzip'} + # Should not try to compress again result = compression_store._compress_value(value=compressed_value) # pyright: ignore[reportPrivateUsage] assert result == compressed_value + def test_decompression_handles_uncompressed_data(self, memory_store: MemoryStore): compression_store = CompressionWrapper(key_value=memory_store, min_size_to_compress=0) - + # Store uncompressed data directly in underlying store - uncompressed_value = {"test": "value"} - memory_store.put(collection="test", key="test", value=uncompressed_value) - + uncompressed_value = {'test': 'value'} + memory_store.put(collection='test', key='test', value=uncompressed_value) + # Should return as-is when retrieved through compression wrapper - result = compression_store.get(collection="test", key="test") + result = compression_store.get(collection='test', key='test') assert result == uncompressed_value + def test_decompression_handles_corrupted_data(self, memory_store: MemoryStore): compression_store = CompressionWrapper(key_value=memory_store, min_size_to_compress=0) - + # Store corrupted compressed data - corrupted_value = {"__compressed_data__": "invalid-base64-data!!!", "__compression_version__": 1} - memory_store.put(collection="test", key="test", value=corrupted_value) - + corrupted_value = {'__compressed_data__': 'invalid-base64-data!!!', '__compression_version__': 1} + memory_store.put(collection='test', key='test', value=corrupted_value) + # Should return the corrupted value as-is rather than crashing - result = compression_store.get(collection="test", key="test") + result = compression_store.get(collection='test', key='test') assert result == corrupted_value + def test_compression_size_reduction(self, memory_store: MemoryStore): compression_store = CompressionWrapper(key_value=memory_store, min_size_to_compress=0) - + # Create a highly compressible value (repeated data) - large_value = {"data": "x" * 10000} - - compression_store.put(collection="test", key="test", value=large_value) - + large_value = {'data': 'x' * 10000} + + compression_store.put(collection='test', key='test', value=large_value) + # Check the compressed size - raw_value = memory_store.get(collection="test", key="test") + raw_value = memory_store.get(collection='test', key='test') assert raw_value is not None - compressed_data = raw_value["__compressed_data__"] - + compressed_data = raw_value['__compressed_data__'] + # Compressed data should be significantly smaller than original # Original is ~10KB, compressed should be much smaller due to repetition assert len(compressed_data) < 1000 # Should be less than 1KB diff --git 
a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_default_value.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_default_value.py index 40323f9f..e4ee1018 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_default_value.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_default_value.py @@ -10,85 +10,107 @@ from key_value.sync.code_gen.wrappers.default_value import DefaultValueWrapper from tests.code_gen.stores.base import BaseStoreTests -TEST_KEY_1 = "test_key_1" -TEST_KEY_2 = "test_key_2" -TEST_COLLECTION = "test_collection" -DEFAULT_VALUE = {"obj_key": "obj_value"} +TEST_KEY_1 = 'test_key_1' +TEST_KEY_2 = 'test_key_2' +TEST_COLLECTION = 'test_collection' +DEFAULT_VALUE = {'obj_key': 'obj_value'} DEFAULT_TTL = 100 class TestDefaultValueWrapper(BaseStoreTests): + @override @pytest.fixture def store(self, memory_store: MemoryStore) -> DefaultValueWrapper: return DefaultValueWrapper(key_value=memory_store, default_value=DEFAULT_VALUE, default_ttl=DEFAULT_TTL) + def test_default_value(self, store: BaseStore): assert store.get(collection=TEST_COLLECTION, key=TEST_KEY_1) == DEFAULT_VALUE assert store.ttl(collection=TEST_COLLECTION, key=TEST_KEY_1) == (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)) assert store.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [DEFAULT_VALUE, DEFAULT_VALUE] - assert store.ttl_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [ - (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)), - (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)), - ] - - store.put(collection=TEST_COLLECTION, key=TEST_KEY_2, value={"key_2": "value_2"}, ttl=200) - assert store.get(collection=TEST_COLLECTION, key=TEST_KEY_2) == {"key_2": "value_2"} - assert store.ttl(collection=TEST_COLLECTION, key=TEST_KEY_2) == ({"key_2": "value_2"}, IsFloat(approx=200)) - assert store.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [DEFAULT_VALUE, {"key_2": "value_2"}] - assert store.ttl_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [ - (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)), - ({"key_2": "value_2"}, IsFloat(approx=200)), - ] + assert store.ttl_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [(DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)), (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL))] + + store.put(collection=TEST_COLLECTION, key=TEST_KEY_2, value={'key_2': 'value_2'}, ttl=200) + assert store.get(collection=TEST_COLLECTION, key=TEST_KEY_2) == {'key_2': 'value_2'} + assert store.ttl(collection=TEST_COLLECTION, key=TEST_KEY_2) == ({'key_2': 'value_2'}, IsFloat(approx=200)) + assert store.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [DEFAULT_VALUE, {'key_2': 'value_2'}] + assert store.ttl_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [(DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)), ({'key_2': 'value_2'}, IsFloat(approx=200))] + @override @pytest.mark.skip - def test_empty_get(self, store: BaseStore): ... + def test_empty_get(self, store: BaseStore): + ... + @override @pytest.mark.skip - def test_put_put_get_many_missing_one(self, store: BaseStore): ... + def test_put_put_get_many_missing_one(self, store: BaseStore): + ... + @override @pytest.mark.skip - def test_empty_ttl(self, store: BaseStore): ... + def test_empty_ttl(self, store: BaseStore): + ... + @override @pytest.mark.skip - def test_get_put_get(self, store: BaseStore): ... + def test_get_put_get(self, store: BaseStore): + ... 
+ @override @pytest.mark.skip - def test_get_put_get_delete_get(self, store: BaseStore): ... + def test_get_put_get_delete_get(self, store: BaseStore): + ... + @override @pytest.mark.skip - def test_put_get_delete_get(self, store: BaseStore): ... + def test_put_get_delete_get(self, store: BaseStore): + ... + @override @pytest.mark.skip - def test_put_many_get_get_delete_many_get_many(self, store: BaseStore): ... + def test_put_many_get_get_delete_many_get_many(self, store: BaseStore): + ... + @override @pytest.mark.skip - def test_put_many_get_many_delete_many_get_many(self, store: BaseStore): ... + def test_put_many_get_many_delete_many_get_many(self, store: BaseStore): + ... + @override @pytest.mark.skip - def test_get_put_get_put_delete_get(self, store: BaseStore): ... + def test_get_put_get_put_delete_get(self, store: BaseStore): + ... + @override @pytest.mark.skip - def test_put_many_delete_delete_get_many(self, store: BaseStore): ... + def test_put_many_delete_delete_get_many(self, store: BaseStore): + ... + @override @pytest.mark.skip - def test_put_expired_get_none(self, store: BaseStore): ... + def test_put_expired_get_none(self, store: BaseStore): + ... + @override @pytest.mark.skip - def test_not_unbounded(self, store: BaseStore): ... + def test_not_unbounded(self, store: BaseStore): + ... + @override @pytest.mark.skip - def test_concurrent_operations(self, store: BaseStore): ... + def test_concurrent_operations(self, store: BaseStore): + ... diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_encryption.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_encryption.py index e20cb7c8..a64d2500 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_encryption.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_encryption.py @@ -20,150 +20,160 @@ def fernet() -> Fernet: class TestFernetEncryptionWrapper(BaseStoreTests): + @override @pytest.fixture def store(self, memory_store: MemoryStore, fernet: Fernet) -> FernetEncryptionWrapper: return FernetEncryptionWrapper(key_value=memory_store, fernet=fernet) + def test_encryption_encrypts_value(self, store: FernetEncryptionWrapper, memory_store: MemoryStore): """Test that values are actually encrypted in the underlying store.""" - original_value = {"test": "value", "number": 123} - store.put(collection="test", key="test", value=original_value) - + original_value = {'test': 'value', 'number': 123} + store.put(collection='test', key='test', value=original_value) + # Check the underlying store - should be encrypted - raw_value = memory_store.get(collection="test", key="test") + raw_value = memory_store.get(collection='test', key='test') assert raw_value is not None - assert "__encrypted_data__" in raw_value - assert "__encryption_version__" in raw_value - assert isinstance(raw_value["__encrypted_data__"], str) - + assert '__encrypted_data__' in raw_value + assert '__encryption_version__' in raw_value + assert isinstance(raw_value['__encrypted_data__'], str) + # The encrypted data should not contain the original value - assert "test" not in str(raw_value) - assert "value" not in str(raw_value) - + assert 'test' not in str(raw_value) + assert 'value' not in str(raw_value) + # Retrieve through wrapper - should decrypt automatically - result = store.get(collection="test", key="test") + result = store.get(collection='test', key='test') assert result == original_value + def test_encryption_with_wrong_encryption_version(self, store: FernetEncryptionWrapper): """Test that 
encryption fails with the wrong encryption version.""" store.encryption_version = 2 - original_value = {"test": "value"} - store.put(collection="test", key="test", value=original_value) - - assert store.get(collection="test", key="test") is not None + original_value = {'test': 'value'} + store.put(collection='test', key='test', value=original_value) + + assert store.get(collection='test', key='test') is not None store.encryption_version = 1 - + with pytest.raises(DecryptionError): - store.get(collection="test", key="test") + store.get(collection='test', key='test') + def test_encryption_with_string_key(self, store: FernetEncryptionWrapper, memory_store: MemoryStore): """Test that encryption works with a string key.""" - original_value = {"test": "value"} - store.put(collection="test", key="test", value=original_value) - - round_trip_value = store.get(collection="test", key="test") + original_value = {'test': 'value'} + store.put(collection='test', key='test', value=original_value) + + round_trip_value = store.get(collection='test', key='test') assert round_trip_value == original_value - - raw_result = memory_store.get(collection="test", key="test") - assert raw_result == snapshot({"__encrypted_data__": IsStr(min_length=32), "__encryption_version__": 1}) + + raw_result = memory_store.get(collection='test', key='test') + assert raw_result == snapshot({'__encrypted_data__': IsStr(min_length=32), '__encryption_version__': 1}) + def test_encryption_many_operations(self, store: FernetEncryptionWrapper, memory_store: MemoryStore): """Test that encryption works with put_many and get_many.""" - keys = ["k1", "k2", "k3"] - values = [{"data": "value1"}, {"data": "value2"}, {"data": "value3"}] - - store.put_many(collection="test", keys=keys, values=values) - + keys = ['k1', 'k2', 'k3'] + values = [{'data': 'value1'}, {'data': 'value2'}, {'data': 'value3'}] + + store.put_many(collection='test', keys=keys, values=values) + # Check underlying store - all should be encrypted for key in keys: - raw_value = memory_store.get(collection="test", key=key) + raw_value = memory_store.get(collection='test', key=key) assert raw_value is not None - assert "__encrypted_data__" in raw_value - + assert '__encrypted_data__' in raw_value + # Retrieve through wrapper - results = store.get_many(collection="test", keys=keys) + results = store.get_many(collection='test', keys=keys) assert results == values + def test_decryption_handles_unencrypted_data(self, store: FernetEncryptionWrapper, memory_store: MemoryStore): """Test that unencrypted data is returned as-is.""" # Store unencrypted data directly in underlying store - unencrypted_value = {"test": "value"} - memory_store.put(collection="test", key="test", value=unencrypted_value) - + unencrypted_value = {'test': 'value'} + memory_store.put(collection='test', key='test', value=unencrypted_value) + # Should return as-is when retrieved through encryption wrapper - result = store.get(collection="test", key="test") + result = store.get(collection='test', key='test') assert result == unencrypted_value + def test_decryption_handles_corrupted_data(self, store: FernetEncryptionWrapper, memory_store: MemoryStore): """Test that corrupted encrypted data is handled gracefully.""" - + # Store corrupted encrypted data - corrupted_value = {"__encrypted_data__": "invalid-encrypted-data!!!", "__encryption_version__": 1} - memory_store.put(collection="test", key="test", value=corrupted_value) - + corrupted_value = {'__encrypted_data__': 'invalid-encrypted-data!!!', 
'__encryption_version__': 1} + memory_store.put(collection='test', key='test', value=corrupted_value) + with pytest.raises(DecryptionError): - store.get(collection="test", key="test") + store.get(collection='test', key='test') + def test_decryption_ignores_corrupted_data(self, memory_store: MemoryStore, fernet: Fernet): """Test that corrupted encrypted data is ignored.""" store = FernetEncryptionWrapper(key_value=memory_store, fernet=fernet, raise_on_decryption_error=False) - + # Store corrupted encrypted data - corrupted_value = {"__encrypted_data__": "invalid-encrypted-data!!!", "__encryption_version__": 1} - memory_store.put(collection="test", key="test", value=corrupted_value) - - assert store.get(collection="test", key="test") is None + corrupted_value = {'__encrypted_data__': 'invalid-encrypted-data!!!', '__encryption_version__': 1} + memory_store.put(collection='test', key='test', value=corrupted_value) + + assert store.get(collection='test', key='test') is None + def test_decryption_with_multi_fernet(self, memory_store: MemoryStore): """Test that decryption works with a MultiFernet.""" first_fernet = Fernet(key=Fernet.generate_key()) first_fernet_store = FernetEncryptionWrapper(key_value=memory_store, fernet=first_fernet) - original_value = {"test": "value"} - first_fernet_store.put(collection="test", key="test", value=original_value) - assert first_fernet_store.get(collection="test", key="test") == original_value - + original_value = {'test': 'value'} + first_fernet_store.put(collection='test', key='test', value=original_value) + assert first_fernet_store.get(collection='test', key='test') == original_value + second_fernet = Fernet(key=Fernet.generate_key()) multi_fernet = MultiFernet([second_fernet, first_fernet]) multi_fernet_store = FernetEncryptionWrapper(key_value=memory_store, fernet=multi_fernet) - assert multi_fernet_store.get(collection="test", key="test") == original_value + assert multi_fernet_store.get(collection='test', key='test') == original_value + def test_decryption_with_wrong_key_raises_error(self, memory_store: MemoryStore): """Test that decryption with the wrong key raises an error.""" fernet1 = Fernet(key=Fernet.generate_key()) fernet2 = Fernet(key=Fernet.generate_key()) - + store1 = FernetEncryptionWrapper(key_value=memory_store, fernet=fernet1) store2 = FernetEncryptionWrapper(key_value=memory_store, fernet=fernet2) - - original_value = {"test": "value"} - store1.put(collection="test", key="test", value=original_value) - + + original_value = {'test': 'value'} + store1.put(collection='test', key='test', value=original_value) + with pytest.raises(DecryptionError): - store2.get(collection="test", key="test") + store2.get(collection='test', key='test') def test_key_generation(): """Test that key generation works with a source material and salt and that different source materials and salts produce different keys.""" - - source_material = "test-source-material" - salt = "test-salt" + + source_material = 'test-source-material' + salt = 'test-salt' key = _generate_encryption_key(source_material=source_material, salt=salt) key_str_one = key.decode() - assert key_str_one == snapshot("znx7rVYt4roVgu3ymt5sIYFmfMNGEPbm8AShXQv6CY4=") - - source_material = "different-source-material" - salt = "test-salt" + assert key_str_one == snapshot('znx7rVYt4roVgu3ymt5sIYFmfMNGEPbm8AShXQv6CY4=') + + source_material = 'different-source-material' + salt = 'test-salt' key = _generate_encryption_key(source_material=source_material, salt=salt) key_str_two = key.decode() - assert 
key_str_two == snapshot("1TLRpjxQm4Op699i9hAXFVfyz6PqPXbuvwKaWB48tS8=") - - source_material = "test-source-material" - salt = "different-salt" + assert key_str_two == snapshot('1TLRpjxQm4Op699i9hAXFVfyz6PqPXbuvwKaWB48tS8=') + + source_material = 'test-source-material' + salt = 'different-salt' key = _generate_encryption_key(source_material=source_material, salt=salt) key_str_three = key.decode() - assert key_str_three == snapshot("oLz_g5NoLhANNh2_-ZwbgchDZ1q23VFx90kUQDjracc=") - + assert key_str_three == snapshot('oLz_g5NoLhANNh2_-ZwbgchDZ1q23VFx90kUQDjracc=') + assert key_str_one != key_str_two assert key_str_one != key_str_three assert key_str_two != key_str_three diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_fallback.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_fallback.py index e35c1804..6a9f4b27 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_fallback.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_fallback.py @@ -14,69 +14,76 @@ class FailingStore(MemoryStore): """A store that always fails.""" + @override - def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None: - msg = "Primary store unavailable" + def get(self, key: str, *, collection: str | None=None) -> dict[str, Any] | None: + msg = 'Primary store unavailable' raise ConnectionError(msg) + @override - def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None): - msg = "Primary store unavailable" + def put(self, key: str, value: Mapping[str, Any], *, collection: str | None=None, ttl: SupportsFloat | None=None): + msg = 'Primary store unavailable' raise ConnectionError(msg) class TestFallbackWrapper(BaseStoreTests): + @override @pytest.fixture def store(self, memory_store: MemoryStore) -> FallbackWrapper: fallback_store = MemoryStore() return FallbackWrapper(primary_key_value=memory_store, fallback_key_value=fallback_store) + def test_fallback_on_primary_failure(self): primary_store = FailingStore() fallback_store = MemoryStore() wrapper = FallbackWrapper(primary_key_value=primary_store, fallback_key_value=fallback_store) - + # Put data in fallback store directly - fallback_store.put(collection="test", key="test", value={"test": "fallback_value"}) - + fallback_store.put(collection='test', key='test', value={'test': 'fallback_value'}) + # Should fall back to secondary store - result = wrapper.get(collection="test", key="test") - assert result == {"test": "fallback_value"} + result = wrapper.get(collection='test', key='test') + assert result == {'test': 'fallback_value'} + def test_primary_success_no_fallback(self): primary_store = MemoryStore() fallback_store = MemoryStore() wrapper = FallbackWrapper(primary_key_value=primary_store, fallback_key_value=fallback_store) - + # Put data in primary store - primary_store.put(collection="test", key="test", value={"test": "primary_value"}) - + primary_store.put(collection='test', key='test', value={'test': 'primary_value'}) + # Put different data in fallback store - fallback_store.put(collection="test", key="test", value={"test": "fallback_value"}) - + fallback_store.put(collection='test', key='test', value={'test': 'fallback_value'}) + # Should use primary store - result = wrapper.get(collection="test", key="test") - assert result == {"test": "primary_value"} + result = wrapper.get(collection='test', key='test') + assert result == {'test': 'primary_value'} + def test_write_to_fallback_disabled(self): 
primary_store = FailingStore() fallback_store = MemoryStore() wrapper = FallbackWrapper(primary_key_value=primary_store, fallback_key_value=fallback_store, write_to_fallback=False) - + # Writes should fail without falling back with pytest.raises(ConnectionError): - wrapper.put(collection="test", key="test", value={"test": "value"}) + wrapper.put(collection='test', key='test', value={'test': 'value'}) + def test_write_to_fallback_enabled(self): primary_store = FailingStore() fallback_store = MemoryStore() wrapper = FallbackWrapper(primary_key_value=primary_store, fallback_key_value=fallback_store, write_to_fallback=True) - + # Writes should fall back to secondary - wrapper.put(collection="test", key="test", value={"test": "value"}) - + wrapper.put(collection='test', key='test', value={'test': 'value'}) + # Verify it was written to fallback - result = fallback_store.get(collection="test", key="test") - assert result == {"test": "value"} + result = fallback_store.get(collection='test', key='test') + assert result == {'test': 'value'} diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_limit_size.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_limit_size.py index e404249a..0cd0d77f 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_limit_size.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_limit_size.py @@ -11,153 +11,160 @@ class TestLimitSizeWrapper(BaseStoreTests): + @override @pytest.fixture def store(self, memory_store: MemoryStore) -> LimitSizeWrapper: # Set a reasonable max size for normal test operations (20KB to handle large test strings) return LimitSizeWrapper(key_value=memory_store, max_size=20 * 1024, raise_on_too_large=False, raise_on_too_small=False) + def test_put_within_limit(self, memory_store: MemoryStore): - limit_size_store: LimitSizeWrapper = LimitSizeWrapper( - key_value=memory_store, max_size=1024, raise_on_too_large=True, raise_on_too_small=False - ) - + limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=1024, raise_on_too_large=True, raise_on_too_small=False) + # Small value should succeed - limit_size_store.put(collection="test", key="test", value={"test": "test"}) - result = limit_size_store.get(collection="test", key="test") + limit_size_store.put(collection='test', key='test', value={'test': 'test'}) + result = limit_size_store.get(collection='test', key='test') assert result is not None - assert result["test"] == "test" + assert result['test'] == 'test' + def test_put_exceeds_limit_with_raise(self, memory_store: MemoryStore): limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=100, raise_on_too_large=True) - + # Large value should raise an error - large_value = {"data": "x" * 1000} + large_value = {'data': 'x' * 1000} with pytest.raises(EntryTooLargeError): - limit_size_store.put(collection="test", key="test", value=large_value) - + limit_size_store.put(collection='test', key='test', value=large_value) + # Verify nothing was stored - result = limit_size_store.get(collection="test", key="test") + result = limit_size_store.get(collection='test', key='test') assert result is None + def test_put_exceeds_limit_without_raise(self, memory_store: MemoryStore): limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=100, raise_on_too_large=False) - + # Large value should be silently ignored - large_value = {"data": "x" * 1000} - limit_size_store.put(collection="test", key="test", 
value=large_value) - + large_value = {'data': 'x' * 1000} + limit_size_store.put(collection='test', key='test', value=large_value) + # Verify nothing was stored - result = limit_size_store.get(collection="test", key="test") + result = limit_size_store.get(collection='test', key='test') assert result is None + def test_put_below_min_size_with_raise_on_too_small(self, memory_store: MemoryStore): limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, min_size=100, raise_on_too_small=True) - + # Small value should raise an error - small_value = {"data": "x"} + small_value = {'data': 'x'} with pytest.raises(EntryTooSmallError): - limit_size_store.put(collection="test", key="test", value=small_value) - + limit_size_store.put(collection='test', key='test', value=small_value) + # Verify nothing was stored - result = limit_size_store.get(collection="test", key="test") + result = limit_size_store.get(collection='test', key='test') assert result is None + def test_put_below_min_size_without_raise_on_too_small(self, memory_store: MemoryStore): limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, min_size=100, raise_on_too_small=False) - + # Small value should be silently ignored - small_value = {"data": "x"} - limit_size_store.put(collection="test", key="test", value=small_value) - + small_value = {'data': 'x'} + limit_size_store.put(collection='test', key='test', value=small_value) + # Verify nothing was stored - result = limit_size_store.get(collection="test", key="test") + result = limit_size_store.get(collection='test', key='test') assert result is None + def test_put_many_mixed_sizes_with_raise(self, memory_store: MemoryStore): limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=100, raise_on_too_large=True) - + # Mix of small and large values - keys = ["small1", "large1", "small2"] - values = [{"data": "x"}, {"data": "x" * 1000}, {"data": "y"}] - + keys = ['small1', 'large1', 'small2'] + values = [{'data': 'x'}, {'data': 'x' * 1000}, {'data': 'y'}] + # Should raise on the large value with pytest.raises(EntryTooLargeError): - limit_size_store.put_many(collection="test", keys=keys, values=values) - + limit_size_store.put_many(collection='test', keys=keys, values=values) + # Verify nothing was stored due to the error - results = limit_size_store.get_many(collection="test", keys=keys) + results = limit_size_store.get_many(collection='test', keys=keys) assert results[0] is None assert results[1] is None assert results[2] is None + def test_put_many_mixed_sizes_without_raise(self, memory_store: MemoryStore): limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=100, raise_on_too_large=False) - + # Mix of small and large values - keys = ["small1", "large1", "small2"] - values = [{"data": "x"}, {"data": "x" * 1000}, {"data": "y"}] - + keys = ['small1', 'large1', 'small2'] + values = [{'data': 'x'}, {'data': 'x' * 1000}, {'data': 'y'}] + # Should silently filter out large value - limit_size_store.put_many(collection="test", keys=keys, values=values) - + limit_size_store.put_many(collection='test', keys=keys, values=values) + # Verify only small values were stored - results = limit_size_store.get_many(collection="test", keys=keys) - assert results[0] == {"data": "x"} + results = limit_size_store.get_many(collection='test', keys=keys) + assert results[0] == {'data': 'x'} assert results[1] is None # Large value was filtered out - assert results[2] == {"data": "y"} + assert results[2] == 
{'data': 'y'} + def test_put_many_with_ttl_sequence(self, memory_store: MemoryStore): limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=100, raise_on_too_large=False) - + # Mix of small and large values with single TTL - keys = ["small1", "large1", "small2"] - values = [{"data": "x"}, {"data": "x" * 1000}, {"data": "y"}] - + keys = ['small1', 'large1', 'small2'] + values = [{'data': 'x'}, {'data': 'x' * 1000}, {'data': 'y'}] + # Should filter out large value - limit_size_store.put_many(collection="test", keys=keys, values=values, ttl=100) - + limit_size_store.put_many(collection='test', keys=keys, values=values, ttl=100) + # Verify only small values were stored - results = limit_size_store.get_many(collection="test", keys=keys) - assert results[0] == {"data": "x"} + results = limit_size_store.get_many(collection='test', keys=keys) + assert results[0] == {'data': 'x'} assert results[1] is None # Large value was filtered out - assert results[2] == {"data": "y"} + assert results[2] == {'data': 'y'} + def test_put_many_all_too_large_without_raise(self, memory_store: MemoryStore): limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=10, raise_on_too_large=False) - + # All values too large - keys = ["key1", "key2"] - values = [{"data": "x" * 1000}, {"data": "y" * 1000}] - + keys = ['key1', 'key2'] + values = [{'data': 'x' * 1000}, {'data': 'y' * 1000}] + # Should not raise, but nothing should be stored - limit_size_store.put_many(collection="test", keys=keys, values=values) - + limit_size_store.put_many(collection='test', keys=keys, values=values) + # Verify nothing was stored - results = limit_size_store.get_many(collection="test", keys=keys) + results = limit_size_store.get_many(collection='test', keys=keys) assert results[0] is None assert results[1] is None + def test_exact_size_limit(self, memory_store: MemoryStore): # First, determine the exact size of a small value from key_value.shared.utils.managed_entry import ManagedEntry - - test_value = {"test": "value"} + + test_value = {'test': 'value'} managed_entry = ManagedEntry(value=test_value) json_str = managed_entry.to_json() - exact_size = len(json_str.encode("utf-8")) - + exact_size = len(json_str.encode('utf-8')) + # Create a store with exact size limit limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=exact_size, raise_on_too_large=True) - + # Should succeed at exact limit - limit_size_store.put(collection="test", key="test", value=test_value) - result = limit_size_store.get(collection="test", key="test") + limit_size_store.put(collection='test', key='test', value=test_value) + result = limit_size_store.get(collection='test', key='test') assert result == test_value - + # Should fail if one byte over - limit_size_store_under: LimitSizeWrapper = LimitSizeWrapper( - key_value=memory_store, max_size=exact_size - 1, raise_on_too_large=True - ) + limit_size_store_under: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=exact_size - 1, raise_on_too_large=True) with pytest.raises(EntryTooLargeError): - limit_size_store_under.put(collection="test", key="test2", value=test_value) + limit_size_store_under.put(collection='test', key='test2', value=test_value) diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_logging.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_logging.py index 2acf806a..74961355 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_logging.py 
+++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_logging.py @@ -20,184 +20,102 @@ def get_messages_from_caplog(caplog: pytest.LogCaptureFixture) -> list[str]: class TestLoggingWrapper(BaseStoreTests): + @override @pytest.fixture def store(self) -> LoggingWrapper: return LoggingWrapper(key_value=MemoryStore(max_entries_per_collection=500), log_level=logging.INFO) + @override @pytest.fixture def structured_logs_store(self) -> LoggingWrapper: return LoggingWrapper(key_value=MemoryStore(max_entries_per_collection=500), log_level=logging.INFO, structured_logs=True) + @pytest.fixture def capture_logs(self, caplog: pytest.LogCaptureFixture) -> Generator[LogCaptureFixture, Any, Any]: with caplog.at_level(logging.INFO): yield caplog + def test_logging_get_operations(self, store: LoggingWrapper, structured_logs_store: LoggingWrapper, capture_logs: LogCaptureFixture): - store.get(collection="test", key="test") - assert get_messages_from_caplog(capture_logs) == snapshot( - ["Start GET collection='test' keys='test'", "Finish GET collection='test' keys='test' ({'hit': False})"] - ) - + store.get(collection='test', key='test') + assert get_messages_from_caplog(capture_logs) == snapshot(["Start GET collection='test' keys='test'", "Finish GET collection='test' keys='test' ({'hit': False})"]) + capture_logs.clear() - - structured_logs_store.get(collection="test", key="test") - assert get_messages_from_caplog(capture_logs) == snapshot( - [ - '{"status": "start", "action": "GET", "collection": "test", "keys": "test"}', - '{"status": "finish", "action": "GET", "collection": "test", "keys": "test", "extra": {"hit": false}}', - ] - ) - + + structured_logs_store.get(collection='test', key='test') + assert get_messages_from_caplog(capture_logs) == snapshot(['{"status": "start", "action": "GET", "collection": "test", "keys": "test"}', '{"status": "finish", "action": "GET", "collection": "test", "keys": "test", "extra": {"hit": false}}']) + capture_logs.clear() - - store.get_many(collection="test", keys=["test", "test_2"]) - assert get_messages_from_caplog(capture_logs) == snapshot( - [ - "Start GET_MANY collection='test' keys='['test', 'test_2']' ({'keys': ['test', 'test_2']})", - "Finish GET_MANY collection='test' keys='['test', 'test_2']' ({'hits': 0, 'misses': 2})", - ] - ) - + + store.get_many(collection='test', keys=['test', 'test_2']) + assert get_messages_from_caplog(capture_logs) == snapshot(["Start GET_MANY collection='test' keys='['test', 'test_2']' ({'keys': ['test', 'test_2']})", "Finish GET_MANY collection='test' keys='['test', 'test_2']' ({'hits': 0, 'misses': 2})"]) + capture_logs.clear() - - structured_logs_store.get_many(collection="test", keys=["test", "test_2"]) - assert get_messages_from_caplog(capture_logs) == snapshot( - [ - '{"status": "start", "action": "GET_MANY", "collection": "test", "keys": ["test", "test_2"], "extra": {"keys": ["test", "test_2"]}}', - '{"status": "finish", "action": "GET_MANY", "collection": "test", "keys": ["test", "test_2"], "extra": {"hits": 0, "misses": 2}}', - ] - ) + + structured_logs_store.get_many(collection='test', keys=['test', 'test_2']) + assert get_messages_from_caplog(capture_logs) == snapshot(['{"status": "start", "action": "GET_MANY", "collection": "test", "keys": ["test", "test_2"], "extra": {"keys": ["test", "test_2"]}}', '{"status": "finish", "action": "GET_MANY", "collection": "test", "keys": ["test", "test_2"], "extra": {"hits": 0, "misses": 2}}']) + def test_logging_put_operations(self, store: LoggingWrapper, 
structured_logs_store: LoggingWrapper, capture_logs: LogCaptureFixture): logging_store = LoggingWrapper(key_value=store, log_level=logging.INFO) - - logging_store.put(collection="test", key="test", value={"test": "value"}) - assert get_messages_from_caplog(capture_logs) == snapshot( - [ - "Start PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", - "Start PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", - "Finish PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", - "Finish PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", - ] - ) - + + logging_store.put(collection='test', key='test', value={'test': 'value'}) + assert get_messages_from_caplog(capture_logs) == snapshot(["Start PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", "Start PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", "Finish PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", "Finish PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})"]) + capture_logs.clear() - - structured_logs_store.put(collection="test", key="test", value={"test": "value"}) - assert get_messages_from_caplog(capture_logs) == snapshot( - [ - '{"status": "start", "action": "PUT", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"ttl": null}}', - '{"status": "finish", "action": "PUT", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"ttl": null}}', - ] - ) - + + structured_logs_store.put(collection='test', key='test', value={'test': 'value'}) + assert get_messages_from_caplog(capture_logs) == snapshot(['{"status": "start", "action": "PUT", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"ttl": null}}', '{"status": "finish", "action": "PUT", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"ttl": null}}']) + capture_logs.clear() - - logging_store.put_many(collection="test", keys=["test", "test_2"], values=[{"test": "value"}, {"test": "value_2"}]) - assert get_messages_from_caplog(capture_logs) == snapshot( - [ - "Start PUT_MANY collection='test' keys='['test', 'test_2']' value=[{'test': 'value'}, {'test': 'value_2'}] ({'ttl': None})", - "Start PUT_MANY collection='test' keys='['test', 'test_2']' value=[{'test': 'value'}, {'test': 'value_2'}] ({'ttl': None})", - "Finish PUT_MANY collection='test' keys='['test', 'test_2']' value=[{'test': 'value'}, {'test': 'value_2'}] ({'ttl': None})", - "Finish PUT_MANY collection='test' keys='['test', 'test_2']' value=[{'test': 'value'}, {'test': 'value_2'}] ({'ttl': None})", - ] - ) - + + logging_store.put_many(collection='test', keys=['test', 'test_2'], values=[{'test': 'value'}, {'test': 'value_2'}]) + assert get_messages_from_caplog(capture_logs) == snapshot(["Start PUT_MANY collection='test' keys='['test', 'test_2']' value=[{'test': 'value'}, {'test': 'value_2'}] ({'ttl': None})", "Start PUT_MANY collection='test' keys='['test', 'test_2']' value=[{'test': 'value'}, {'test': 'value_2'}] ({'ttl': None})", "Finish PUT_MANY collection='test' keys='['test', 'test_2']' value=[{'test': 'value'}, {'test': 'value_2'}] ({'ttl': None})", "Finish PUT_MANY collection='test' keys='['test', 'test_2']' value=[{'test': 'value'}, {'test': 'value_2'}] ({'ttl': None})"]) + capture_logs.clear() - - structured_logs_store.put_many(collection="test", keys=["test", "test_2"], values=[{"test": "value"}, {"test": "value_2"}]) - assert 
get_messages_from_caplog(capture_logs) == snapshot( - [ - '{"status": "start", "action": "PUT_MANY", "collection": "test", "keys": ["test", "test_2"], "value": [{"test": "value"}, {"test": "value_2"}], "extra": {"ttl": null}}', - '{"status": "finish", "action": "PUT_MANY", "collection": "test", "keys": ["test", "test_2"], "value": [{"test": "value"}, {"test": "value_2"}], "extra": {"ttl": null}}', - ] - ) + + structured_logs_store.put_many(collection='test', keys=['test', 'test_2'], values=[{'test': 'value'}, {'test': 'value_2'}]) + assert get_messages_from_caplog(capture_logs) == snapshot(['{"status": "start", "action": "PUT_MANY", "collection": "test", "keys": ["test", "test_2"], "value": [{"test": "value"}, {"test": "value_2"}], "extra": {"ttl": null}}', '{"status": "finish", "action": "PUT_MANY", "collection": "test", "keys": ["test", "test_2"], "value": [{"test": "value"}, {"test": "value_2"}], "extra": {"ttl": null}}']) + def test_logging_delete_operations(self, store: LoggingWrapper, structured_logs_store: LoggingWrapper, capture_logs: LogCaptureFixture): logging_store = LoggingWrapper(key_value=store, log_level=logging.INFO) - - logging_store.delete(collection="test", key="test") - assert get_messages_from_caplog(capture_logs) == snapshot( - [ - "Start DELETE collection='test' keys='test'", - "Start DELETE collection='test' keys='test'", - "Finish DELETE collection='test' keys='test' ({'deleted': False})", - "Finish DELETE collection='test' keys='test' ({'deleted': False})", - ] - ) - + + logging_store.delete(collection='test', key='test') + assert get_messages_from_caplog(capture_logs) == snapshot(["Start DELETE collection='test' keys='test'", "Start DELETE collection='test' keys='test'", "Finish DELETE collection='test' keys='test' ({'deleted': False})", "Finish DELETE collection='test' keys='test' ({'deleted': False})"]) + capture_logs.clear() - - structured_logs_store.delete(collection="test", key="test") - assert get_messages_from_caplog(capture_logs) == snapshot( - [ - '{"status": "start", "action": "DELETE", "collection": "test", "keys": "test"}', - '{"status": "finish", "action": "DELETE", "collection": "test", "keys": "test", "extra": {"deleted": false}}', - ] - ) - + + structured_logs_store.delete(collection='test', key='test') + assert get_messages_from_caplog(capture_logs) == snapshot(['{"status": "start", "action": "DELETE", "collection": "test", "keys": "test"}', '{"status": "finish", "action": "DELETE", "collection": "test", "keys": "test", "extra": {"deleted": false}}']) + capture_logs.clear() - - logging_store.delete_many(collection="test", keys=["test", "test_2"]) - assert get_messages_from_caplog(capture_logs) == snapshot( - [ - "Start DELETE_MANY collection='test' keys='['test', 'test_2']' ({'keys': ['test', 'test_2']})", - "Start DELETE_MANY collection='test' keys='['test', 'test_2']' ({'keys': ['test', 'test_2']})", - "Finish DELETE_MANY collection='test' keys='['test', 'test_2']' ({'deleted': 0})", - "Finish DELETE_MANY collection='test' keys='['test', 'test_2']' ({'deleted': 0})", - ] - ) - + + logging_store.delete_many(collection='test', keys=['test', 'test_2']) + assert get_messages_from_caplog(capture_logs) == snapshot(["Start DELETE_MANY collection='test' keys='['test', 'test_2']' ({'keys': ['test', 'test_2']})", "Start DELETE_MANY collection='test' keys='['test', 'test_2']' ({'keys': ['test', 'test_2']})", "Finish DELETE_MANY collection='test' keys='['test', 'test_2']' ({'deleted': 0})", "Finish DELETE_MANY collection='test' keys='['test', 'test_2']' 
({'deleted': 0})"]) + capture_logs.clear() - - structured_logs_store.delete_many(collection="test", keys=["test", "test_2"]) - assert get_messages_from_caplog(capture_logs) == snapshot( - [ - '{"status": "start", "action": "DELETE_MANY", "collection": "test", "keys": ["test", "test_2"], "extra": {"keys": ["test", "test_2"]}}', - '{"status": "finish", "action": "DELETE_MANY", "collection": "test", "keys": ["test", "test_2"], "extra": {"deleted": 0}}', - ] - ) - - def test_put_get_delete_get_logging( - self, store: LoggingWrapper, structured_logs_store: LoggingWrapper, capture_logs: LogCaptureFixture - ): - store.put(collection="test", key="test", value={"test": "value"}) - assert store.get(collection="test", key="test") == {"test": "value"} - assert store.delete(collection="test", key="test") - assert store.get(collection="test", key="test") is None - - assert get_messages_from_caplog(capture_logs) == snapshot( - [ - "Start PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", - "Finish PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", - "Start GET collection='test' keys='test'", - "Finish GET collection='test' keys='test' value={'test': 'value'} ({'hit': True})", - "Start DELETE collection='test' keys='test'", - "Finish DELETE collection='test' keys='test' ({'deleted': True})", - "Start GET collection='test' keys='test'", - "Finish GET collection='test' keys='test' ({'hit': False})", - ] - ) - + + structured_logs_store.delete_many(collection='test', keys=['test', 'test_2']) + assert get_messages_from_caplog(capture_logs) == snapshot(['{"status": "start", "action": "DELETE_MANY", "collection": "test", "keys": ["test", "test_2"], "extra": {"keys": ["test", "test_2"]}}', '{"status": "finish", "action": "DELETE_MANY", "collection": "test", "keys": ["test", "test_2"], "extra": {"deleted": 0}}']) + + + def test_put_get_delete_get_logging(self, store: LoggingWrapper, structured_logs_store: LoggingWrapper, capture_logs: LogCaptureFixture): + store.put(collection='test', key='test', value={'test': 'value'}) + assert store.get(collection='test', key='test') == {'test': 'value'} + assert store.delete(collection='test', key='test') + assert store.get(collection='test', key='test') is None + + assert get_messages_from_caplog(capture_logs) == snapshot(["Start PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", "Finish PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", "Start GET collection='test' keys='test'", "Finish GET collection='test' keys='test' value={'test': 'value'} ({'hit': True})", "Start DELETE collection='test' keys='test'", "Finish DELETE collection='test' keys='test' ({'deleted': True})", "Start GET collection='test' keys='test'", "Finish GET collection='test' keys='test' ({'hit': False})"]) + capture_logs.clear() - - structured_logs_store.put(collection="test", key="test", value={"test": "value"}) - assert structured_logs_store.get(collection="test", key="test") == {"test": "value"} - assert structured_logs_store.delete(collection="test", key="test") - assert structured_logs_store.get(collection="test", key="test") is None - - assert get_messages_from_caplog(capture_logs) == snapshot( - [ - '{"status": "start", "action": "PUT", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"ttl": null}}', - '{"status": "finish", "action": "PUT", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"ttl": null}}', - '{"status": "start", "action": "GET", 
"collection": "test", "keys": "test"}', - '{"status": "finish", "action": "GET", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"hit": true}}', - '{"status": "start", "action": "DELETE", "collection": "test", "keys": "test"}', - '{"status": "finish", "action": "DELETE", "collection": "test", "keys": "test", "extra": {"deleted": true}}', - '{"status": "start", "action": "GET", "collection": "test", "keys": "test"}', - '{"status": "finish", "action": "GET", "collection": "test", "keys": "test", "extra": {"hit": false}}', - ] - ) + + structured_logs_store.put(collection='test', key='test', value={'test': 'value'}) + assert structured_logs_store.get(collection='test', key='test') == {'test': 'value'} + assert structured_logs_store.delete(collection='test', key='test') + assert structured_logs_store.get(collection='test', key='test') is None + + assert get_messages_from_caplog(capture_logs) == snapshot(['{"status": "start", "action": "PUT", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"ttl": null}}', '{"status": "finish", "action": "PUT", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"ttl": null}}', '{"status": "start", "action": "GET", "collection": "test", "keys": "test"}', '{"status": "finish", "action": "GET", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"hit": true}}', '{"status": "start", "action": "DELETE", "collection": "test", "keys": "test"}', '{"status": "finish", "action": "DELETE", "collection": "test", "keys": "test", "extra": {"deleted": true}}', '{"status": "start", "action": "GET", "collection": "test", "keys": "test"}', '{"status": "finish", "action": "GET", "collection": "test", "keys": "test", "extra": {"hit": false}}']) diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_passthrough_cache.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_passthrough_cache.py index f71fce63..529e39fd 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_passthrough_cache.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_passthrough_cache.py @@ -16,14 +16,18 @@ class TestPassthroughCacheWrapper(BaseStoreTests): - @pytest.fixture(scope="session") + + @pytest.fixture(scope='session') def primary_store(self) -> Generator[DiskStore, None, None]: - with tempfile.TemporaryDirectory() as temp_dir, DiskStore(directory=temp_dir, max_size=DISK_STORE_SIZE_LIMIT) as disk_store: - yield disk_store + with tempfile.TemporaryDirectory() as temp_dir: + with DiskStore(directory=temp_dir, max_size=DISK_STORE_SIZE_LIMIT) as disk_store: + yield disk_store + @pytest.fixture def cache_store(self, memory_store: MemoryStore) -> MemoryStore: return memory_store + @override @pytest.fixture diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_prefix_collection.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_prefix_collection.py index 6a31f566..efe487c2 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_prefix_collection.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_prefix_collection.py @@ -10,7 +10,8 @@ class TestPrefixCollectionWrapper(BaseStoreTests): + @override @pytest.fixture def store(self, memory_store: MemoryStore) -> PrefixCollectionsWrapper: - return PrefixCollectionsWrapper(key_value=memory_store, prefix="collection_prefix") + return PrefixCollectionsWrapper(key_value=memory_store, prefix='collection_prefix') diff --git 
a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_prefix_key.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_prefix_key.py index 8d64c8d2..77e37107 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_prefix_key.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_prefix_key.py @@ -10,7 +10,8 @@ class TestPrefixKeyWrapper(BaseStoreTests): + @override @pytest.fixture def store(self, memory_store: MemoryStore) -> PrefixKeysWrapper: - return PrefixKeysWrapper(key_value=memory_store, prefix="key_prefix") + return PrefixKeysWrapper(key_value=memory_store, prefix='key_prefix') diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_read_only.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_read_only.py index 56e4c959..6a0f94fb 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_read_only.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_read_only.py @@ -10,63 +10,68 @@ class TestReadOnlyWrapper: + @pytest.fixture def memory_store(self) -> MemoryStore: return MemoryStore() + @override @pytest.fixture def store(self, memory_store: MemoryStore) -> ReadOnlyWrapper: # Pre-populate the store with test data - memory_store.put(collection="test", key="test", value={"test": "test"}) + memory_store.put(collection='test', key='test', value={'test': 'test'}) return ReadOnlyWrapper(key_value=memory_store, raise_on_write=False) + def test_read_operations_allowed(self, memory_store: MemoryStore): # Pre-populate store - memory_store.put(collection="test", key="test", value={"test": "value"}) - + memory_store.put(collection='test', key='test', value={'test': 'value'}) + read_only_store = ReadOnlyWrapper(key_value=memory_store, raise_on_write=True) - + # Read operations should work - result = read_only_store.get(collection="test", key="test") - assert result == {"test": "value"} - - results = read_only_store.get_many(collection="test", keys=["test"]) - assert results == [{"test": "value"}] - - (value, _) = read_only_store.ttl(collection="test", key="test") - assert value == {"test": "value"} + result = read_only_store.get(collection='test', key='test') + assert result == {'test': 'value'} + + results = read_only_store.get_many(collection='test', keys=['test']) + assert results == [{'test': 'value'}] + + (value, _) = read_only_store.ttl(collection='test', key='test') + assert value == {'test': 'value'} + def test_write_operations_raise_error(self, memory_store: MemoryStore): read_only_store = ReadOnlyWrapper(key_value=memory_store, raise_on_write=True) - + # Write operations should raise ReadOnlyError with pytest.raises(ReadOnlyError): - read_only_store.put(collection="test", key="test", value={"test": "value"}) - + read_only_store.put(collection='test', key='test', value={'test': 'value'}) + with pytest.raises(ReadOnlyError): - read_only_store.put_many(collection="test", keys=["test"], values=[{"test": "value"}]) - + read_only_store.put_many(collection='test', keys=['test'], values=[{'test': 'value'}]) + with pytest.raises(ReadOnlyError): - read_only_store.delete(collection="test", key="test") - + read_only_store.delete(collection='test', key='test') + with pytest.raises(ReadOnlyError): - read_only_store.delete_many(collection="test", keys=["test"]) + read_only_store.delete_many(collection='test', keys=['test']) + def test_write_operations_silent_ignore(self, memory_store: MemoryStore): read_only_store = ReadOnlyWrapper(key_value=memory_store, raise_on_write=False) - + # 
Write operations should be silently ignored - read_only_store.put(collection="test", key="new_key", value={"test": "value"}) - + read_only_store.put(collection='test', key='new_key', value={'test': 'value'}) + # Verify nothing was written - result = memory_store.get(collection="test", key="new_key") + result = memory_store.get(collection='test', key='new_key') assert result is None - + # Delete should return False - deleted = read_only_store.delete(collection="test", key="test") + deleted = read_only_store.delete(collection='test', key='test') assert deleted is False - + # Delete many should return 0 - deleted_count = read_only_store.delete_many(collection="test", keys=["test"]) + deleted_count = read_only_store.delete_many(collection='test', keys=['test']) assert deleted_count == 0 diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_retry.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_retry.py index aa12b5b8..267414cf 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_retry.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_retry.py @@ -11,67 +11,75 @@ class FailingStore(MemoryStore): """A store that fails a certain number of times before succeeding.""" + - def __init__(self, failures_before_success: int = 2): + def __init__(self, failures_before_success: int=2): super().__init__() self.failures_before_success = failures_before_success self.attempt_count = 0 + - def get(self, key: str, *, collection: str | None = None): + def get(self, key: str, *, collection: str | None=None): self.attempt_count += 1 if self.attempt_count <= self.failures_before_success: - msg = "Simulated connection error" + msg = 'Simulated connection error' raise ConnectionError(msg) return super().get(key=key, collection=collection) + def reset_attempts(self): self.attempt_count = 0 class TestRetryWrapper(BaseStoreTests): + @override @pytest.fixture def store(self, memory_store: MemoryStore) -> RetryWrapper: return RetryWrapper(key_value=memory_store, max_retries=3, initial_delay=0.01) + def test_retry_succeeds_after_failures(self): failing_store = FailingStore(failures_before_success=2) retry_store = RetryWrapper(key_value=failing_store, max_retries=3, initial_delay=0.01) - + # Store a value first - retry_store.put(collection="test", key="test", value={"test": "value"}) + retry_store.put(collection='test', key='test', value={'test': 'value'}) failing_store.reset_attempts() - + # Should succeed after 2 failures - result = retry_store.get(collection="test", key="test") - assert result == {"test": "value"} + result = retry_store.get(collection='test', key='test') + assert result == {'test': 'value'} assert failing_store.attempt_count == 3 # 2 failures + 1 success + def test_retry_fails_after_max_retries(self): failing_store = FailingStore(failures_before_success=10) # More failures than max_retries retry_store = RetryWrapper(key_value=failing_store, max_retries=2, initial_delay=0.01) - + # Should fail after exhausting retries with pytest.raises(ConnectionError): - retry_store.get(collection="test", key="test") - + retry_store.get(collection='test', key='test') + assert failing_store.attempt_count == 3 # Initial attempt + 2 retries + def test_retry_with_different_exception(self): failing_store = FailingStore(failures_before_success=1) # Only retry on TimeoutError, not ConnectionError retry_store = RetryWrapper(key_value=failing_store, max_retries=3, initial_delay=0.01, retry_on=(TimeoutError,)) - + # Should fail immediately without 
retries with pytest.raises(ConnectionError): - retry_store.get(collection="test", key="test") - + retry_store.get(collection='test', key='test') + assert failing_store.attempt_count == 1 # No retries + def test_retry_no_failures(self, memory_store: MemoryStore): retry_store = RetryWrapper(key_value=memory_store, max_retries=3, initial_delay=0.01) - + # Normal operation should work without retries - retry_store.put(collection="test", key="test", value={"test": "value"}) - result = retry_store.get(collection="test", key="test") - assert result == {"test": "value"} + retry_store.put(collection='test', key='test', value={'test': 'value'}) + result = retry_store.get(collection='test', key='test') + assert result == {'test': 'value'} diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_routing.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_routing.py index 9a75430d..5494eb27 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_routing.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_routing.py @@ -11,33 +11,37 @@ from key_value.sync.code_gen.wrappers.routing import CollectionRoutingWrapper, RoutingWrapper from tests.code_gen.stores.base import BaseStoreTests -KEY_ONE = "key1" -VALUE_ONE = {"this_key_1": "this_value_1"} -COLLECTION_ONE = "first" +KEY_ONE = 'key1' +VALUE_ONE = {'this_key_1': 'this_value_1'} +COLLECTION_ONE = 'first' -KEY_TWO = "key2" -VALUE_TWO = {"this_key_2": "this_value_2"} -COLLECTION_TWO = "second" +KEY_TWO = 'key2' +VALUE_TWO = {'this_key_2': 'this_value_2'} +COLLECTION_TWO = 'second' -KEY_UNMAPPED = "key3" -VALUE_UNMAPPED = {"this_key_3": "this_value_3"} -COLLECTION_UNMAPPED = "unmapped" +KEY_UNMAPPED = 'key3' +VALUE_UNMAPPED = {'this_key_3': 'this_value_3'} +COLLECTION_UNMAPPED = 'unmapped' ALL_KEYS = [KEY_ONE, KEY_TWO, KEY_UNMAPPED] class TestRoutingWrapper(BaseStoreTests): + @pytest.fixture def second_store(self) -> MemoryStore: return MemoryStore() + @pytest.fixture def default_store(self) -> MemoryStore: return MemoryStore() + @pytest.fixture def store(self, memory_store: MemoryStore, second_store: MemoryStore, default_store: MemoryStore) -> RoutingWrapper: first_store = memory_store + def route(collection: str | None) -> KeyValue | None: if collection == COLLECTION_ONE: @@ -45,8 +49,9 @@ def route(collection: str | None) -> KeyValue | None: if collection == COLLECTION_TWO: return second_store return None - + return RoutingWrapper(routing_function=route, default_store=default_store) + @pytest.fixture def store_with_data(self, store: RoutingWrapper) -> RoutingWrapper: @@ -54,74 +59,74 @@ def store_with_data(self, store: RoutingWrapper) -> RoutingWrapper: store.put(key=KEY_TWO, value=VALUE_TWO, collection=COLLECTION_TWO) store.put(key=KEY_UNMAPPED, value=VALUE_UNMAPPED, collection=COLLECTION_UNMAPPED) return store + @override - @pytest.mark.skip(reason="RoutingWrapper is unbounded") - def test_not_unbounded(self, store: BaseStore): ... + @pytest.mark.skip(reason='RoutingWrapper is unbounded') + def test_not_unbounded(self, store: BaseStore): + ... 
+ - def test_routing_get_and_get_many( - self, store_with_data: RoutingWrapper, memory_store: MemoryStore, second_store: MemoryStore, default_store: MemoryStore - ): + def test_routing_get_and_get_many(self, store_with_data: RoutingWrapper, memory_store: MemoryStore, second_store: MemoryStore, default_store: MemoryStore): """Test basic routing sends gets""" assert memory_store.get(key=KEY_ONE, collection=COLLECTION_ONE) == VALUE_ONE assert memory_store.get(key=KEY_TWO, collection=COLLECTION_TWO) is None assert memory_store.get(key=KEY_UNMAPPED, collection=COLLECTION_UNMAPPED) is None assert memory_store.get_many(keys=ALL_KEYS, collection=COLLECTION_ONE) == [VALUE_ONE, None, None] - + assert second_store.get(key=KEY_ONE, collection=COLLECTION_ONE) is None assert second_store.get(key=KEY_TWO, collection=COLLECTION_TWO) == VALUE_TWO assert second_store.get(key=KEY_UNMAPPED, collection=COLLECTION_UNMAPPED) is None assert second_store.get_many(keys=ALL_KEYS, collection=COLLECTION_TWO) == [None, VALUE_TWO, None] - + assert default_store.get(key=KEY_ONE, collection=COLLECTION_ONE) is None assert default_store.get(key=KEY_TWO, collection=COLLECTION_TWO) is None assert default_store.get(key=KEY_UNMAPPED, collection=COLLECTION_UNMAPPED) == VALUE_UNMAPPED assert default_store.get_many(keys=ALL_KEYS, collection=COLLECTION_UNMAPPED) == [None, None, VALUE_UNMAPPED] + - def test_routing_delete( - self, store_with_data: RoutingWrapper, memory_store: MemoryStore, second_store: MemoryStore, default_store: MemoryStore - ): + def test_routing_delete(self, store_with_data: RoutingWrapper, memory_store: MemoryStore, second_store: MemoryStore, default_store: MemoryStore): """Test delete operations route correctly.""" - + assert store_with_data.get(key=KEY_ONE, collection=COLLECTION_ONE) == VALUE_ONE store_with_data.delete(key=KEY_ONE, collection=COLLECTION_ONE) assert memory_store.get(key=KEY_ONE, collection=COLLECTION_ONE) is None assert memory_store.get_many(keys=ALL_KEYS, collection=COLLECTION_ONE) == [None, None, None] assert second_store.get_many(keys=ALL_KEYS, collection=COLLECTION_ONE) == [None, None, None] assert default_store.get_many(keys=ALL_KEYS, collection=COLLECTION_ONE) == [None, None, None] - + assert store_with_data.get(key=KEY_TWO, collection=COLLECTION_TWO) == VALUE_TWO store_with_data.delete(key=KEY_TWO, collection=COLLECTION_TWO) assert memory_store.get(key=KEY_TWO, collection=COLLECTION_TWO) is None assert memory_store.get_many(keys=ALL_KEYS, collection=COLLECTION_TWO) == [None, None, None] assert second_store.get_many(keys=ALL_KEYS, collection=COLLECTION_TWO) == [None, None, None] assert default_store.get_many(keys=ALL_KEYS, collection=COLLECTION_TWO) == [None, None, None] - + assert store_with_data.get(key=KEY_UNMAPPED, collection=COLLECTION_UNMAPPED) == VALUE_UNMAPPED store_with_data.delete(key=KEY_UNMAPPED, collection=COLLECTION_UNMAPPED) assert memory_store.get(key=KEY_UNMAPPED, collection=COLLECTION_UNMAPPED) is None assert memory_store.get_many(keys=ALL_KEYS, collection=COLLECTION_UNMAPPED) == [None, None, None] assert second_store.get_many(keys=ALL_KEYS, collection=COLLECTION_UNMAPPED) == [None, None, None] assert default_store.get_many(keys=ALL_KEYS, collection=COLLECTION_UNMAPPED) == [None, None, None] + def test_routing_ttl(self, store: RoutingWrapper, memory_store: MemoryStore, second_store: MemoryStore, default_store: MemoryStore): """Test TTL operations route correctly.""" key_one_ttl = 1800 key_two_ttl = 2700 key_unmapped_ttl = 7200 - + store.put(key=KEY_ONE, value=VALUE_ONE, 
collection=COLLECTION_ONE, ttl=key_one_ttl) store.put(key=KEY_TWO, value=VALUE_TWO, collection=COLLECTION_TWO, ttl=key_two_ttl) store.put(key=KEY_UNMAPPED, value=VALUE_UNMAPPED, collection=COLLECTION_UNMAPPED, ttl=key_unmapped_ttl) - + assert store.ttl(key=KEY_ONE, collection=COLLECTION_ONE) == (VALUE_ONE, IsFloat(approx=key_one_ttl)) assert store.ttl(key=KEY_TWO, collection=COLLECTION_TWO) == (VALUE_TWO, IsFloat(approx=key_two_ttl)) assert store.ttl(key=KEY_UNMAPPED, collection=COLLECTION_UNMAPPED) == (VALUE_UNMAPPED, IsFloat(approx=key_unmapped_ttl)) class TestCollectionRoutingWrapper(TestRoutingWrapper): + @pytest.fixture def store(self, memory_store: MemoryStore, second_store: MemoryStore, default_store: MemoryStore) -> CollectionRoutingWrapper: - return CollectionRoutingWrapper( - collection_map={COLLECTION_ONE: memory_store, COLLECTION_TWO: second_store}, default_store=default_store - ) + return CollectionRoutingWrapper(collection_map={COLLECTION_ONE: memory_store, COLLECTION_TWO: second_store}, default_store=default_store) diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_single_collection.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_single_collection.py index f4e70e52..c8ee4512 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_single_collection.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_single_collection.py @@ -10,7 +10,8 @@ class TestSingleCollectionWrapper(BaseStoreTests): + @override @pytest.fixture def store(self, memory_store: MemoryStore) -> SingleCollectionWrapper: - return SingleCollectionWrapper(key_value=memory_store, single_collection="test") + return SingleCollectionWrapper(key_value=memory_store, single_collection='test') diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_statistics.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_statistics.py index ddf43f0b..6dfe9e39 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_statistics.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_statistics.py @@ -10,6 +10,7 @@ class TestStatisticsWrapper(BaseStoreTests): + @override @pytest.fixture def store(self, memory_store: MemoryStore) -> StatisticsWrapper: diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_ttl_clamp.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_ttl_clamp.py index 59ac6473..175ffbbf 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_ttl_clamp.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_ttl_clamp.py @@ -11,41 +11,45 @@ class TestTTLClampWrapper(BaseStoreTests): + @override @pytest.fixture def store(self, memory_store: MemoryStore) -> TTLClampWrapper: return TTLClampWrapper(key_value=memory_store, min_ttl=0, max_ttl=100) + def test_put_below_min_ttl(self, memory_store: MemoryStore): ttl_clamp_store: TTLClampWrapper = TTLClampWrapper(key_value=memory_store, min_ttl=50, max_ttl=100) - - ttl_clamp_store.put(collection="test", key="test", value={"test": "test"}, ttl=5) - assert ttl_clamp_store.get(collection="test", key="test") is not None - - (value, ttl) = ttl_clamp_store.ttl(collection="test", key="test") + + ttl_clamp_store.put(collection='test', key='test', value={'test': 'test'}, ttl=5) + assert ttl_clamp_store.get(collection='test', key='test') is not None + + (value, ttl) = ttl_clamp_store.ttl(collection='test', key='test') assert value is not None assert ttl is not None assert ttl == 
IsFloat(approx=50)
+
     def test_put_above_max_ttl(self, memory_store: MemoryStore):
         ttl_clamp_store: TTLClampWrapper = TTLClampWrapper(key_value=memory_store, min_ttl=0, max_ttl=100)
-
-        ttl_clamp_store.put(collection="test", key="test", value={"test": "test"}, ttl=1000)
-        assert ttl_clamp_store.get(collection="test", key="test") is not None
-
-        (value, ttl) = ttl_clamp_store.ttl(collection="test", key="test")
+
+        ttl_clamp_store.put(collection='test', key='test', value={'test': 'test'}, ttl=1000)
+        assert ttl_clamp_store.get(collection='test', key='test') is not None
+
+        (value, ttl) = ttl_clamp_store.ttl(collection='test', key='test')
         assert value is not None
         assert ttl is not None
         assert ttl == IsFloat(approx=100)
+
     def test_put_missing_ttl(self, memory_store: MemoryStore):
         ttl_clamp_store: TTLClampWrapper = TTLClampWrapper(key_value=memory_store, min_ttl=0, max_ttl=100, missing_ttl=50)
-
-        ttl_clamp_store.put(collection="test", key="test", value={"test": "test"}, ttl=None)
-        assert ttl_clamp_store.get(collection="test", key="test") is not None
-
-        (value, ttl) = ttl_clamp_store.ttl(collection="test", key="test")
+
+        ttl_clamp_store.put(collection='test', key='test', value={'test': 'test'}, ttl=None)
+        assert ttl_clamp_store.get(collection='test', key='test') is not None
+
+        (value, ttl) = ttl_clamp_store.ttl(collection='test', key='test')
         assert value is not None
         assert ttl is not None
-
+
         assert ttl == IsFloat(approx=50)

From 0ab6766915eb879a7bf51321069e36520992fbeb Mon Sep 17 00:00:00 2001
From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com>
Date: Wed, 29 Oct 2025 15:52:08 +0000
Subject: [PATCH 02/11] refactor: consolidate serialization adapters into shared module

- Create shared SerializationAdapter ABC in key_value.shared.utils.serialization
- Add common adapter implementations:
  - FullJsonAdapter: For Redis/Valkey (full JSON strings)
  - StringifiedDictAdapter: For stores preferring dict with stringified values
  - MongoDBAdapter: With native BSON datetime support for TTL indexing
  - ElasticsearchAdapter: With native/string storage mode support
- Update MongoDB, Redis, and Elasticsearch stores to use shared adapters
- Remove duplicate SerializationAdapter definitions from each store
- Remove store-specific helper functions (managed_entry_to_document, etc.)
- Run codegen to sync changes to sync library
- All linting passes

This eliminates code duplication and provides a consistent serialization
interface across all stores, making it easier to add new stores or modify
serialization behavior.
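A minimal usage sketch of the shared adapter interface (illustrative only: it assumes
FullJsonAdapter mirrors the RedisAdapter it replaces, and the key name is made up):

    from key_value.shared.utils.managed_entry import ManagedEntry
    from key_value.shared.utils.serialization import FullJsonAdapter

    adapter = FullJsonAdapter()
    entry = ManagedEntry(value={"user": "alice"})

    # Write path: serialize the entry into the backend's storage format
    # (a full JSON string for Redis/Valkey-style stores).
    payload = adapter.to_storage(key="user:1", entry=entry)

    # Read path: reconstruct the ManagedEntry from whatever the backend returned.
    restored = adapter.from_storage(data=payload)
    assert restored.value_as_dict == {"user": "alice"}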
Co-authored-by: William Easton --- .../aio/stores/elasticsearch/store.py | 74 +--- .../src/key_value/aio/stores/mongodb/store.py | 218 +---------- .../src/key_value/aio/stores/redis/store.py | 93 +---- .../key_value/shared/utils/serialization.py | 346 ++++++++++++++++++ .../key_value/sync/code_gen/stores/base.py | 1 - .../code_gen/stores/elasticsearch/store.py | 70 +--- .../sync/code_gen/stores/mongodb/store.py | 218 +---------- .../sync/code_gen/stores/redis/store.py | 93 +---- .../tests/code_gen/adapters/test_dataclass.py | 94 +++-- .../tests/code_gen/adapters/test_pydantic.py | 121 +++--- .../tests/code_gen/adapters/test_raise.py | 38 +- .../key-value-sync/tests/code_gen/cases.py | 50 ++- .../key-value-sync/tests/code_gen/conftest.py | 89 ++--- .../tests/code_gen/protocols/test_types.py | 19 +- .../tests/code_gen/stores/base.py | 285 +++++++-------- .../tests/code_gen/stores/conftest.py | 2 +- .../tests/code_gen/stores/disk/test_disk.py | 6 +- .../code_gen/stores/disk/test_multi_disk.py | 6 +- .../elasticsearch/test_elasticsearch.py | 221 ++++++----- .../code_gen/stores/keyring/test_keyring.py | 20 +- .../code_gen/stores/memory/test_memory.py | 6 +- .../code_gen/stores/mongodb/test_mongodb.py | 170 +++++---- .../tests/code_gen/stores/redis/test_redis.py | 58 ++- .../code_gen/stores/rocksdb/test_rocksdb.py | 43 +-- .../code_gen/stores/simple/test_store.py | 1 - .../code_gen/stores/valkey/test_valkey.py | 44 +-- .../tests/code_gen/stores/vault/test_vault.py | 57 ++- .../windows_registry/test_windows_registry.py | 24 +- .../stores/wrappers/test_compression.py | 110 +++--- .../stores/wrappers/test_default_value.py | 82 ++--- .../stores/wrappers/test_encryption.py | 154 ++++---- .../code_gen/stores/wrappers/test_fallback.py | 53 ++- .../stores/wrappers/test_limit_size.py | 149 ++++---- .../code_gen/stores/wrappers/test_logging.py | 216 +++++++---- .../stores/wrappers/test_passthrough_cache.py | 10 +- .../stores/wrappers/test_prefix_collection.py | 3 +- .../stores/wrappers/test_prefix_key.py | 3 +- .../stores/wrappers/test_read_only.py | 61 ++- .../code_gen/stores/wrappers/test_retry.py | 44 +-- .../code_gen/stores/wrappers/test_routing.py | 61 ++- .../stores/wrappers/test_single_collection.py | 3 +- .../stores/wrappers/test_statistics.py | 1 - .../stores/wrappers/test_ttl_clamp.py | 36 +- 43 files changed, 1568 insertions(+), 1885 deletions(-) create mode 100644 key-value/key-value-shared/src/key_value/shared/utils/serialization.py diff --git a/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py b/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py index ad58aca0..740c74f4 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py @@ -6,14 +6,15 @@ from elastic_transport import ObjectApiResponse from elastic_transport import SerializationError as ElasticsearchSerializationError from key_value.shared.errors import DeserializationError, SerializationError -from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict +from key_value.shared.utils.managed_entry import ManagedEntry from key_value.shared.utils.sanitize import ( ALPHANUMERIC_CHARACTERS, LOWERCASE_ALPHABET, NUMBERS, sanitize_string, ) -from key_value.shared.utils.time_to_live import now_as_epoch, try_parse_datetime_str +from key_value.shared.utils.serialization import ElasticsearchAdapter, SerializationAdapter +from key_value.shared.utils.time_to_live import 
now_as_epoch from typing_extensions import override from key_value.aio.stores.base import ( @@ -84,54 +85,6 @@ ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "." -def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry, *, native_storage: bool = False) -> dict[str, Any]: - document: dict[str, Any] = {"collection": collection, "key": key, "value": {}} - - # Store in appropriate field based on mode - if native_storage: - document["value"]["flattened"] = managed_entry.value_as_dict - else: - document["value"]["string"] = managed_entry.value_as_json - - if managed_entry.created_at: - document["created_at"] = managed_entry.created_at.isoformat() - if managed_entry.expires_at: - document["expires_at"] = managed_entry.expires_at.isoformat() - - return document - - -def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry: - value: dict[str, Any] = {} - - raw_value = source.get("value") - - # Try flattened field first, fall back to string field - if not raw_value or not isinstance(raw_value, dict): - msg = "Value field not found or invalid type" - raise DeserializationError(msg) - - if value_flattened := raw_value.get("flattened"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - value = verify_dict(obj=value_flattened) - elif value_str := raw_value.get("string"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - if not isinstance(value_str, str): - msg = "Value in `value` field is not a string" - raise DeserializationError(msg) - value = load_from_json(value_str) - else: - msg = "Value field not found or invalid type" - raise DeserializationError(msg) - - created_at: datetime | None = try_parse_datetime_str(value=source.get("created_at")) - expires_at: datetime | None = try_parse_datetime_str(value=source.get("expires_at")) - - return ManagedEntry( - value=value, - created_at=created_at, - expires_at=expires_at, - ) - - class ElasticsearchStore( BaseEnumerateCollectionsStore, BaseEnumerateKeysStore, BaseDestroyCollectionStore, BaseCullStore, BaseContextManagerStore, BaseStore ): @@ -145,6 +98,8 @@ class ElasticsearchStore( _native_storage: bool + _adapter: SerializationAdapter + @overload def __init__( self, @@ -208,6 +163,7 @@ def __init__( self._index_prefix = index_prefix self._native_storage = native_storage self._is_serverless = False + self._adapter = ElasticsearchAdapter(native_storage=native_storage) super().__init__(default_collection=default_collection) @@ -260,7 +216,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry return None try: - return source_to_managed_entry(source=source) + return self._adapter.from_storage(data=source) except DeserializationError: return None @@ -293,7 +249,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> continue try: - entries_by_id[doc_id] = source_to_managed_entry(source=source) + entries_by_id[doc_id] = self._adapter.from_storage(data=source) except DeserializationError as e: logger.error( "Failed to deserialize Elasticsearch document in batch operation", @@ -324,9 +280,10 @@ async def _put_managed_entry( index_name: str = self._sanitize_index_name(collection=collection) document_id: str = self._sanitize_document_id(key=key) - document: dict[str, Any] = managed_entry_to_document( - collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage - ) + document: dict[str, Any] = self._adapter.to_storage(key=key, entry=managed_entry, 
collection=collection) + if not isinstance(document, dict): + msg = "Elasticsearch adapter must return dict" + raise TypeError(msg) try: _ = await self._client.index( @@ -364,9 +321,10 @@ async def _put_managed_entries( index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id) - document: dict[str, Any] = managed_entry_to_document( - collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage - ) + document: dict[str, Any] = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) + if not isinstance(document, dict): + msg = "Elasticsearch adapter must return dict" + raise TypeError(msg) operations.extend([index_action, document]) try: diff --git a/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py b/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py index 84f19957..b1a3a8c0 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py @@ -1,12 +1,11 @@ -from abc import ABC, abstractmethod from collections.abc import Sequence from datetime import datetime from typing import Any, overload from key_value.shared.errors import DeserializationError -from key_value.shared.utils.managed_entry import ManagedEntry, verify_dict +from key_value.shared.utils.managed_entry import ManagedEntry from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, sanitize_string -from key_value.shared.utils.time_to_live import timezone +from key_value.shared.utils.serialization import MongoDBAdapter, SerializationAdapter from typing_extensions import Self, override from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyCollectionStore, BaseEnumerateCollectionsStore, BaseStore @@ -36,219 +35,6 @@ COLLECTION_ALLOWED_CHARACTERS = ALPHANUMERIC_CHARACTERS + "_" -class SerializationAdapter(ABC): - """Base class for store-specific serialization adapters. - - Adapters encapsulate the logic for converting between ManagedEntry objects - and store-specific storage formats. This provides a consistent interface - while allowing each store to optimize its serialization strategy. - """ - - @abstractmethod - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any] | str: - """Convert a ManagedEntry to the store's storage format. - - Args: - key: The key associated with this entry. - entry: The ManagedEntry to serialize. - collection: Optional collection name. - - Returns: - The serialized representation (dict or str depending on store). - """ - ... - - @abstractmethod - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert stored data back to a ManagedEntry. - - Args: - data: The stored representation to deserialize. - - Returns: - A ManagedEntry reconstructed from storage. - - Raises: - DeserializationError: If the data cannot be deserialized. - """ - ... - - -class MongoDBAdapter(SerializationAdapter): - """MongoDB-specific serialization adapter. - - Stores entries with native BSON datetime types for TTL indexing, - while maintaining the value.object/value.string structure for compatibility. - """ - - def __init__(self, *, native_storage: bool = True) -> None: - """Initialize the MongoDB adapter. - - Args: - native_storage: If True (default), store value as native BSON dict in value.object field. - If False, store as JSON string in value.string field for backward compatibility. 
- """ - self.native_storage = native_storage - - @override - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: - """Convert a ManagedEntry to a MongoDB document.""" - document: dict[str, Any] = {"key": key, "value": {}} - - # We convert to JSON even if we don't need to, this ensures that the value we were provided - # can be serialized to JSON which helps ensure compatibility across stores - json_str = entry.value_as_json - - # Store in appropriate field based on mode - if self.native_storage: - document["value"]["object"] = entry.value_as_dict - else: - document["value"]["string"] = json_str - - # Add metadata fields as BSON datetimes for TTL indexing - if entry.created_at: - document["created_at"] = entry.created_at - if entry.expires_at: - document["expires_at"] = entry.expires_at - - return document - - @override - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert a MongoDB document back to a ManagedEntry.""" - if not isinstance(data, dict): - msg = "Expected MongoDB document to be a dict" - raise DeserializationError(msg) - - document = data - - if not (value_field := document.get("value")): - msg = "Value field not found" - raise DeserializationError(msg) - - if not isinstance(value_field, dict): - msg = "Expected `value` field to be an object" - raise DeserializationError(msg) - - value_holder: dict[str, Any] = verify_dict(obj=value_field) - - entry_data: dict[str, Any] = {} - - # Mongo stores datetimes without timezones as UTC so we mark them as UTC - if created_at_datetime := document.get("created_at"): - if not isinstance(created_at_datetime, datetime): - msg = "Expected `created_at` field to be a datetime" - raise DeserializationError(msg) - entry_data["created_at"] = created_at_datetime.replace(tzinfo=timezone.utc) - - if expires_at_datetime := document.get("expires_at"): - if not isinstance(expires_at_datetime, datetime): - msg = "Expected `expires_at` field to be a datetime" - raise DeserializationError(msg) - entry_data["expires_at"] = expires_at_datetime.replace(tzinfo=timezone.utc) - - # Support both native (object) and legacy (string) storage - if value_object := value_holder.get("object"): - return ManagedEntry.from_dict(data={"value": value_object, **entry_data}) - - if value_string := value_holder.get("string"): - return ManagedEntry.from_dict(data={"value": value_string, **entry_data}, stringified_value=True) - - msg = "Expected `value` field to be an object with `object` or `string` subfield" - raise DeserializationError(msg) - - -def document_to_managed_entry(document: dict[str, Any]) -> ManagedEntry: - """Convert a MongoDB document back to a ManagedEntry. - - This function deserializes a MongoDB document (created by `managed_entry_to_document`) back to a - ManagedEntry object. It supports both native BSON storage (dict in value.object field) and legacy - JSON string storage (string in value.string field) for migration support. - - Args: - document: The MongoDB document to convert. - - Returns: - A ManagedEntry object reconstructed from the document. 
- """ - if not (value_field := document.get("value")): - msg = "Value field not found" - raise DeserializationError(msg) - - if not isinstance(value_field, dict): - msg = "Expected `value` field to be an object" - raise DeserializationError(msg) - - value_holder: dict[str, Any] = verify_dict(obj=value_field) - - data: dict[str, Any] = {} - - # The Value field is an object with two possible fields: `object` and `string` - # - `object`: The value is a native BSON dict - # - `string`: The value is a JSON string - # Mongo stores datetimes without timezones as UTC so we mark them as UTC - - if created_at_datetime := document.get("created_at"): - if not isinstance(created_at_datetime, datetime): - msg = "Expected `created_at` field to be a datetime" - raise DeserializationError(msg) - data["created_at"] = created_at_datetime.replace(tzinfo=timezone.utc) - - if expires_at_datetime := document.get("expires_at"): - if not isinstance(expires_at_datetime, datetime): - msg = "Expected `expires_at` field to be a datetime" - raise DeserializationError(msg) - data["expires_at"] = expires_at_datetime.replace(tzinfo=timezone.utc) - - if value_object := value_holder.get("object"): - return ManagedEntry.from_dict(data={"value": value_object, **data}) - - if value_string := value_holder.get("string"): - return ManagedEntry.from_dict(data={"value": value_string, **data}, stringified_value=True) - - msg = "Expected `value` field to be an object with `object` or `string` subfield" - raise DeserializationError(msg) - - -def managed_entry_to_document(key: str, managed_entry: ManagedEntry, *, native_storage: bool = True) -> dict[str, Any]: - """Convert a ManagedEntry to a MongoDB document for storage. - - This function serializes a ManagedEntry to a MongoDB document format, including the key and all - metadata (TTL, creation, and expiration timestamps). The value storage format depends on the - native_storage parameter. - - Args: - key: The key associated with this entry. - managed_entry: The ManagedEntry to serialize. - native_storage: If True (default), store value as native BSON dict in value.object field. - If False, store as JSON string in value.string field for backward compatibility. - - Returns: - A MongoDB document dict containing the key, value, and all metadata. - """ - document: dict[str, Any] = {"key": key, "value": {}} - - # We convert to JSON even if we don't need to, this ensures that the value we were provided - # can be serialized to JSON which helps ensure compatibility across stores. For example, - # Mongo can natively handle datetime objects which other stores cannot, if we don't convert to JSON, - # then using py-key-value with Mongo will return different values than if we used another store. 
- json_str = managed_entry.value_as_json - - # Store in appropriate field based on mode - if native_storage: - document["value"]["object"] = managed_entry.value_as_dict - else: - document["value"]["string"] = json_str - - # Add metadata fields - if managed_entry.created_at: - document["created_at"] = managed_entry.created_at - if managed_entry.expires_at: - document["expires_at"] = managed_entry.expires_at - - return document - - class MongoDBStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, BaseContextManagerStore, BaseStore): """MongoDB-based key-value store using pymongo.""" diff --git a/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py b/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py index 6d060ecc..d8bee19a 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod from collections.abc import Sequence from datetime import datetime from typing import Any, overload @@ -8,6 +7,7 @@ from key_value.shared.type_checking.bear_spray import bear_spray from key_value.shared.utils.compound import compound_key, get_keys_from_compound_keys from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.serialization import FullJsonAdapter, SerializationAdapter from typing_extensions import override from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyStore, BaseEnumerateKeysStore, BaseStore @@ -22,95 +22,6 @@ PAGE_LIMIT = 10000 -class SerializationAdapter(ABC): - """Base class for store-specific serialization adapters. - - Adapters encapsulate the logic for converting between ManagedEntry objects - and store-specific storage formats. - """ - - @abstractmethod - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any] | str: - """Convert a ManagedEntry to the store's storage format. - - Args: - key: The key associated with this entry. - entry: The ManagedEntry to serialize. - collection: Optional collection name. - - Returns: - The serialized representation (dict or str depending on store). - """ - ... - - @abstractmethod - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert stored data back to a ManagedEntry. - - Args: - data: The stored representation to deserialize. - - Returns: - A ManagedEntry reconstructed from storage. - - Raises: - DeserializationError: If the data cannot be deserialized. - """ - ... - - -class RedisAdapter(SerializationAdapter): - """Redis-specific serialization adapter. - - Stores entries as JSON strings in Redis. - """ - - @override - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> str: - """Convert a ManagedEntry to a JSON string for Redis storage.""" - return entry.to_json(include_metadata=True, include_expiration=True, include_creation=True) - - @override - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert a JSON string from Redis storage back to a ManagedEntry.""" - if not isinstance(data, str): - msg = "Expected Redis data to be a string" - raise DeserializationError(msg) - - return ManagedEntry.from_json(json_str=data, includes_metadata=True) - - -def managed_entry_to_json(managed_entry: ManagedEntry) -> str: - """Convert a ManagedEntry to a JSON string for Redis storage. 
- - This function serializes a ManagedEntry to JSON format including all metadata (TTL, creation, - and expiration timestamps). The serialization is designed to preserve all entry information - for round-trip conversion back to a ManagedEntry. - - Args: - managed_entry: The ManagedEntry to serialize. - - Returns: - A JSON string representation of the ManagedEntry with full metadata. - """ - return managed_entry.to_json(include_metadata=True, include_expiration=True, include_creation=True) - - -def json_to_managed_entry(json_str: str) -> ManagedEntry: - """Convert a JSON string from Redis storage back to a ManagedEntry. - - This function deserializes a JSON string (created by `managed_entry_to_json`) back to a - ManagedEntry object, preserving all metadata including TTL, creation, and expiration timestamps. - - Args: - json_str: The JSON string to deserialize. - - Returns: - A ManagedEntry object reconstructed from the JSON string. - """ - return ManagedEntry.from_json(json_str=json_str, includes_metadata=True) - - class RedisStore(BaseDestroyStore, BaseEnumerateKeysStore, BaseContextManagerStore, BaseStore): """Redis-based key-value store.""" @@ -172,7 +83,7 @@ def __init__( ) self._stable_api = True - self._adapter = RedisAdapter() + self._adapter = FullJsonAdapter() super().__init__(default_collection=default_collection) diff --git a/key-value/key-value-shared/src/key_value/shared/utils/serialization.py b/key-value/key-value-shared/src/key_value/shared/utils/serialization.py new file mode 100644 index 00000000..86649455 --- /dev/null +++ b/key-value/key-value-shared/src/key_value/shared/utils/serialization.py @@ -0,0 +1,346 @@ +"""Serialization adapters for converting ManagedEntry objects to/from store-specific formats. + +This module provides a base SerializationAdapter ABC and common adapter implementations +that can be reused across different key-value stores. +""" + +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any + +from key_value.shared.errors.key_value import DeserializationError +from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict +from key_value.shared.utils.time_to_live import try_parse_datetime_str + + +class SerializationAdapter(ABC): + """Base class for store-specific serialization adapters. + + Adapters encapsulate the logic for converting between ManagedEntry objects + and store-specific storage formats. This provides a consistent interface + while allowing each store to optimize its serialization strategy. + """ + + @abstractmethod + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any] | str: + """Convert a ManagedEntry to the store's storage format. + + Args: + key: The key associated with this entry. + entry: The ManagedEntry to serialize. + collection: Optional collection name. + + Returns: + The serialized representation (dict or str depending on store). + """ + ... + + @abstractmethod + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert stored data back to a ManagedEntry. + + Args: + data: The stored representation to deserialize. + + Returns: + A ManagedEntry reconstructed from storage. + + Raises: + DeserializationError: If the data cannot be deserialized. + """ + ... + + +class FullJsonAdapter(SerializationAdapter): + """Adapter that serializes entries as complete JSON strings. + + This adapter is suitable for stores that work with string values, + such as Redis or Valkey. 
It serializes the entire ManagedEntry + (including all metadata) to a JSON string. + """ + + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> str: # noqa: ARG002 + """Convert a ManagedEntry to a JSON string. + + Args: + key: The key (unused, for interface compatibility). + entry: The ManagedEntry to serialize. + collection: The collection (unused, for interface compatibility). + + Returns: + A JSON string containing the entry and all metadata. + """ + return entry.to_json(include_metadata=True, include_expiration=True, include_creation=True) + + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert a JSON string back to a ManagedEntry. + + Args: + data: The JSON string to deserialize. + + Returns: + A ManagedEntry reconstructed from the JSON. + + Raises: + DeserializationError: If data is not a string or cannot be parsed. + """ + if not isinstance(data, str): + msg = "Expected data to be a JSON string" + raise DeserializationError(msg) + + return ManagedEntry.from_json(json_str=data, includes_metadata=True) + + +class StringifiedDictAdapter(SerializationAdapter): + """Adapter that serializes entries as dicts with stringified values. + + This adapter is suitable for stores that prefer to store entries as + documents with the value field serialized as a JSON string. This allows + stores to index and query metadata fields while keeping the value opaque. + """ + + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: # noqa: ARG002 + """Convert a ManagedEntry to a dict with stringified value. + + Args: + key: The key associated with this entry. + entry: The ManagedEntry to serialize. + collection: The collection (unused, for interface compatibility). + + Returns: + A dict with key, stringified value, and metadata fields. + """ + return { + "key": key, + **entry.to_dict(include_metadata=True, include_expiration=True, include_creation=True, stringify_value=True), + } + + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert a dict with stringified value back to a ManagedEntry. + + Args: + data: The dict to deserialize. + + Returns: + A ManagedEntry reconstructed from the dict. + + Raises: + DeserializationError: If data is not a dict or is malformed. + """ + if not isinstance(data, dict): + msg = "Expected data to be a dict" + raise DeserializationError(msg) + + return ManagedEntry.from_dict(obj=data, expects_stringified_value=True, includes_metadata=True) + + +class MongoDBAdapter(SerializationAdapter): + """MongoDB-specific serialization adapter with native BSON datetime support. + + This adapter is optimized for MongoDB, storing: + - Native BSON datetime types for TTL indexing (created_at, expires_at) + - Values in value.object (native BSON) or value.string (JSON) fields + - Support for migration between native and string storage modes + + The native storage mode is recommended for new deployments as it allows + efficient querying of value fields, while string mode provides backward + compatibility with older data. + """ + + def __init__(self, *, native_storage: bool = True) -> None: + """Initialize the MongoDB adapter. + + Args: + native_storage: If True (default), store value as native BSON dict in value.object field. + If False, store as JSON string in value.string field for backward compatibility. 
+ """ + self.native_storage = native_storage + + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: # noqa: ARG002 + """Convert a ManagedEntry to a MongoDB document. + + Args: + key: The key associated with this entry. + entry: The ManagedEntry to serialize. + collection: The collection (unused, for interface compatibility). + + Returns: + A MongoDB document with key, value, and BSON datetime metadata. + """ + document: dict[str, Any] = {"key": key, "value": {}} + + # We convert to JSON even if we don't need to, this ensures that the value we were provided + # can be serialized to JSON which helps ensure compatibility across stores + json_str = entry.value_as_json + + # Store in appropriate field based on mode + if self.native_storage: + document["value"]["object"] = entry.value_as_dict + else: + document["value"]["string"] = json_str + + # Add metadata fields as BSON datetimes for TTL indexing + if entry.created_at: + document["created_at"] = entry.created_at + if entry.expires_at: + document["expires_at"] = entry.expires_at + + return document + + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert a MongoDB document back to a ManagedEntry. + + This method supports both native (object) and legacy (string) storage modes, + and properly handles BSON datetime types for metadata. + + Args: + data: The MongoDB document to deserialize. + + Returns: + A ManagedEntry reconstructed from the document. + + Raises: + DeserializationError: If data is not a dict or is malformed. + """ + if not isinstance(data, dict): + msg = "Expected MongoDB document to be a dict" + raise DeserializationError(msg) + + document = data + + if not (value_field := document.get("value")): + msg = "Value field not found" + raise DeserializationError(msg) + + if not isinstance(value_field, dict): + msg = "Expected `value` field to be an object" + raise DeserializationError(msg) + + value_holder: dict[str, Any] = verify_dict(obj=value_field) + + entry_data: dict[str, Any] = {} + + # Mongo stores datetimes without timezones as UTC so we mark them as UTC + # Import timezone here to avoid circular import + from key_value.shared.utils.time_to_live import timezone + + if created_at_datetime := document.get("created_at"): + if not isinstance(created_at_datetime, datetime): + msg = "Expected `created_at` field to be a datetime" + raise DeserializationError(msg) + entry_data["created_at"] = created_at_datetime.replace(tzinfo=timezone.utc) + + if expires_at_datetime := document.get("expires_at"): + if not isinstance(expires_at_datetime, datetime): + msg = "Expected `expires_at` field to be a datetime" + raise DeserializationError(msg) + entry_data["expires_at"] = expires_at_datetime.replace(tzinfo=timezone.utc) + + # Support both native (object) and legacy (string) storage + if value_object := value_holder.get("object"): + return ManagedEntry.from_dict(data={"value": value_object, **entry_data}) + + if value_string := value_holder.get("string"): + return ManagedEntry.from_dict(data={"value": value_string, **entry_data}, stringified_value=True) + + msg = "Expected `value` field to be an object with `object` or `string` subfield" + raise DeserializationError(msg) + + +class ElasticsearchAdapter(SerializationAdapter): + """Adapter for Elasticsearch with support for native and string storage modes. 
+ + This adapter supports two storage modes: + - Native mode: Stores values as flattened dicts for efficient querying + - String mode: Stores values as JSON strings for backward compatibility + + Elasticsearch-specific features: + - Stores collection name in the document for multi-tenancy + - Uses ISO format for datetime fields + - Supports migration between storage modes + """ + + def __init__(self, *, native_storage: bool = True) -> None: + """Initialize the Elasticsearch adapter. + + Args: + native_storage: If True (default), store values as flattened dicts. + If False, store values as JSON strings. + """ + self.native_storage = native_storage + + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: + """Convert a ManagedEntry to an Elasticsearch document. + + Args: + key: The key associated with this entry. + entry: The ManagedEntry to serialize. + collection: The collection name to store in the document. + + Returns: + An Elasticsearch document dict with collection, key, value, and metadata. + """ + document: dict[str, Any] = {"collection": collection or "", "key": key, "value": {}} + + # Store in appropriate field based on mode + if self.native_storage: + document["value"]["flattened"] = entry.value_as_dict + else: + document["value"]["string"] = entry.value_as_json + + if entry.created_at: + document["created_at"] = entry.created_at.isoformat() + if entry.expires_at: + document["expires_at"] = entry.expires_at.isoformat() + + return document + + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert an Elasticsearch document back to a ManagedEntry. + + This method supports both native (flattened) and string storage modes, + trying the flattened field first and falling back to the string field. + This allows for seamless migration between storage modes. + + Args: + data: The Elasticsearch document to deserialize. + + Returns: + A ManagedEntry reconstructed from the document. + + Raises: + DeserializationError: If data is not a dict or is malformed. 
+ """ + if not isinstance(data, dict): + msg = "Expected Elasticsearch document to be a dict" + raise DeserializationError(msg) + + document = data + value: dict[str, Any] = {} + + raw_value = document.get("value") + + # Try flattened field first, fall back to string field + if not raw_value or not isinstance(raw_value, dict): + msg = "Value field not found or invalid type" + raise DeserializationError(msg) + + if value_flattened := raw_value.get("flattened"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + value = verify_dict(obj=value_flattened) + elif value_str := raw_value.get("string"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + if not isinstance(value_str, str): + msg = "Value in `value` field is not a string" + raise DeserializationError(msg) + value = load_from_json(value_str) + else: + msg = "Value field not found or invalid type" + raise DeserializationError(msg) + + created_at: datetime | None = try_parse_datetime_str(value=document.get("created_at")) + expires_at: datetime | None = try_parse_datetime_str(value=document.get("expires_at")) + + return ManagedEntry( + value=value, + created_at=created_at, + expires_at=expires_at, + ) diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py index e57cc10c..1c02abda 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py @@ -252,7 +252,6 @@ def _put_managed_entries( created_at: datetime, expires_at: datetime | None, ) -> None: - """Store multiple managed entries by key in the specified collection. Args: diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py index 8ab6f618..c56b0a23 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py @@ -9,9 +9,10 @@ from elastic_transport import ObjectApiResponse from elastic_transport import SerializationError as ElasticsearchSerializationError from key_value.shared.errors import DeserializationError, SerializationError -from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict +from key_value.shared.utils.managed_entry import ManagedEntry from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, LOWERCASE_ALPHABET, NUMBERS, sanitize_string -from key_value.shared.utils.time_to_live import now_as_epoch, try_parse_datetime_str +from key_value.shared.utils.serialization import ElasticsearchAdapter, SerializationAdapter +from key_value.shared.utils.time_to_live import now_as_epoch from typing_extensions import override from key_value.sync.code_gen.stores.base import ( @@ -64,50 +65,6 @@ ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "." 
-def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry, *, native_storage: bool = False) -> dict[str, Any]: - document: dict[str, Any] = {"collection": collection, "key": key, "value": {}} - - # Store in appropriate field based on mode - if native_storage: - document["value"]["flattened"] = managed_entry.value_as_dict - else: - document["value"]["string"] = managed_entry.value_as_json - - if managed_entry.created_at: - document["created_at"] = managed_entry.created_at.isoformat() - if managed_entry.expires_at: - document["expires_at"] = managed_entry.expires_at.isoformat() - - return document - - -def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry: - value: dict[str, Any] = {} - - raw_value = source.get("value") - - # Try flattened field first, fall back to string field - if not raw_value or not isinstance(raw_value, dict): - msg = "Value field not found or invalid type" - raise DeserializationError(msg) - - if value_flattened := raw_value.get("flattened"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - value = verify_dict(obj=value_flattened) - elif value_str := raw_value.get("string"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - if not isinstance(value_str, str): - msg = "Value in `value` field is not a string" - raise DeserializationError(msg) - value = load_from_json(value_str) - else: - msg = "Value field not found or invalid type" - raise DeserializationError(msg) - - created_at: datetime | None = try_parse_datetime_str(value=source.get("created_at")) - expires_at: datetime | None = try_parse_datetime_str(value=source.get("expires_at")) - - return ManagedEntry(value=value, created_at=created_at, expires_at=expires_at) - - class ElasticsearchStore( BaseEnumerateCollectionsStore, BaseEnumerateKeysStore, BaseDestroyCollectionStore, BaseCullStore, BaseContextManagerStore, BaseStore ): @@ -121,6 +78,8 @@ class ElasticsearchStore( _native_storage: bool + _adapter: SerializationAdapter + @overload def __init__( self, *, elasticsearch_client: Elasticsearch, index_prefix: str, native_storage: bool = True, default_collection: str | None = None @@ -173,6 +132,7 @@ def __init__( self._index_prefix = index_prefix self._native_storage = native_storage self._is_serverless = False + self._adapter = ElasticsearchAdapter(native_storage=native_storage) super().__init__(default_collection=default_collection) @@ -220,7 +180,7 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non return None try: - return source_to_managed_entry(source=source) + return self._adapter.from_storage(data=source) except DeserializationError: return None @@ -253,7 +213,7 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ continue try: - entries_by_id[doc_id] = source_to_managed_entry(source=source) + entries_by_id[doc_id] = self._adapter.from_storage(data=source) except DeserializationError as e: logger.error( "Failed to deserialize Elasticsearch document in batch operation", @@ -274,9 +234,10 @@ def _put_managed_entry(self, *, key: str, collection: str, managed_entry: Manage index_name: str = self._sanitize_index_name(collection=collection) document_id: str = self._sanitize_document_id(key=key) - document: dict[str, Any] = managed_entry_to_document( - collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage - ) + document: dict[str, Any] = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) + if not 
isinstance(document, dict): + msg = "Elasticsearch adapter must return dict" + raise TypeError(msg) try: _ = self._client.index(index=index_name, id=document_id, body=document, refresh=self._should_refresh_on_put) @@ -309,9 +270,10 @@ def _put_managed_entries( index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id) - document: dict[str, Any] = managed_entry_to_document( - collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage - ) + document: dict[str, Any] = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) + if not isinstance(document, dict): + msg = "Elasticsearch adapter must return dict" + raise TypeError(msg) operations.extend([index_action, document]) try: diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py index 1d4c18e0..8c3e6c48 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py @@ -1,15 +1,14 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'store.py' # DO NOT CHANGE! Change the original file instead. -from abc import ABC, abstractmethod from collections.abc import Sequence from datetime import datetime from typing import Any, overload from key_value.shared.errors import DeserializationError -from key_value.shared.utils.managed_entry import ManagedEntry, verify_dict +from key_value.shared.utils.managed_entry import ManagedEntry from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, sanitize_string -from key_value.shared.utils.time_to_live import timezone +from key_value.shared.utils.serialization import MongoDBAdapter, SerializationAdapter from typing_extensions import Self, override from key_value.sync.code_gen.stores.base import ( @@ -43,219 +42,6 @@ COLLECTION_ALLOWED_CHARACTERS = ALPHANUMERIC_CHARACTERS + "_" -class SerializationAdapter(ABC): - """Base class for store-specific serialization adapters. - - Adapters encapsulate the logic for converting between ManagedEntry objects - and store-specific storage formats. This provides a consistent interface - while allowing each store to optimize its serialization strategy. - """ - - @abstractmethod - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any] | str: - """Convert a ManagedEntry to the store's storage format. - - Args: - key: The key associated with this entry. - entry: The ManagedEntry to serialize. - collection: Optional collection name. - - Returns: - The serialized representation (dict or str depending on store). - """ - ... - - @abstractmethod - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert stored data back to a ManagedEntry. - - Args: - data: The stored representation to deserialize. - - Returns: - A ManagedEntry reconstructed from storage. - - Raises: - DeserializationError: If the data cannot be deserialized. - """ - ... - - -class MongoDBAdapter(SerializationAdapter): - """MongoDB-specific serialization adapter. - - Stores entries with native BSON datetime types for TTL indexing, - while maintaining the value.object/value.string structure for compatibility. - """ - - def __init__(self, *, native_storage: bool = True) -> None: - """Initialize the MongoDB adapter. 
- - Args: - native_storage: If True (default), store value as native BSON dict in value.object field. - If False, store as JSON string in value.string field for backward compatibility. - """ - self.native_storage = native_storage - - @override - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: - """Convert a ManagedEntry to a MongoDB document.""" - document: dict[str, Any] = {"key": key, "value": {}} - - # We convert to JSON even if we don't need to, this ensures that the value we were provided - # can be serialized to JSON which helps ensure compatibility across stores - json_str = entry.value_as_json - - # Store in appropriate field based on mode - if self.native_storage: - document["value"]["object"] = entry.value_as_dict - else: - document["value"]["string"] = json_str - - # Add metadata fields as BSON datetimes for TTL indexing - if entry.created_at: - document["created_at"] = entry.created_at - if entry.expires_at: - document["expires_at"] = entry.expires_at - - return document - - @override - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert a MongoDB document back to a ManagedEntry.""" - if not isinstance(data, dict): - msg = "Expected MongoDB document to be a dict" - raise DeserializationError(msg) - - document = data - - if not (value_field := document.get("value")): - msg = "Value field not found" - raise DeserializationError(msg) - - if not isinstance(value_field, dict): - msg = "Expected `value` field to be an object" - raise DeserializationError(msg) - - value_holder: dict[str, Any] = verify_dict(obj=value_field) - - entry_data: dict[str, Any] = {} - - # Mongo stores datetimes without timezones as UTC so we mark them as UTC - if created_at_datetime := document.get("created_at"): - if not isinstance(created_at_datetime, datetime): - msg = "Expected `created_at` field to be a datetime" - raise DeserializationError(msg) - entry_data["created_at"] = created_at_datetime.replace(tzinfo=timezone.utc) - - if expires_at_datetime := document.get("expires_at"): - if not isinstance(expires_at_datetime, datetime): - msg = "Expected `expires_at` field to be a datetime" - raise DeserializationError(msg) - entry_data["expires_at"] = expires_at_datetime.replace(tzinfo=timezone.utc) - - # Support both native (object) and legacy (string) storage - if value_object := value_holder.get("object"): - return ManagedEntry.from_dict(data={"value": value_object, **entry_data}) - - if value_string := value_holder.get("string"): - return ManagedEntry.from_dict(data={"value": value_string, **entry_data}, stringified_value=True) - - msg = "Expected `value` field to be an object with `object` or `string` subfield" - raise DeserializationError(msg) - - -def document_to_managed_entry(document: dict[str, Any]) -> ManagedEntry: - """Convert a MongoDB document back to a ManagedEntry. - - This function deserializes a MongoDB document (created by `managed_entry_to_document`) back to a - ManagedEntry object. It supports both native BSON storage (dict in value.object field) and legacy - JSON string storage (string in value.string field) for migration support. - - Args: - document: The MongoDB document to convert. - - Returns: - A ManagedEntry object reconstructed from the document. 
- """ - if not (value_field := document.get("value")): - msg = "Value field not found" - raise DeserializationError(msg) - - if not isinstance(value_field, dict): - msg = "Expected `value` field to be an object" - raise DeserializationError(msg) - - value_holder: dict[str, Any] = verify_dict(obj=value_field) - - data: dict[str, Any] = {} - - # The Value field is an object with two possible fields: `object` and `string` - # - `object`: The value is a native BSON dict - # - `string`: The value is a JSON string - # Mongo stores datetimes without timezones as UTC so we mark them as UTC - - if created_at_datetime := document.get("created_at"): - if not isinstance(created_at_datetime, datetime): - msg = "Expected `created_at` field to be a datetime" - raise DeserializationError(msg) - data["created_at"] = created_at_datetime.replace(tzinfo=timezone.utc) - - if expires_at_datetime := document.get("expires_at"): - if not isinstance(expires_at_datetime, datetime): - msg = "Expected `expires_at` field to be a datetime" - raise DeserializationError(msg) - data["expires_at"] = expires_at_datetime.replace(tzinfo=timezone.utc) - - if value_object := value_holder.get("object"): - return ManagedEntry.from_dict(data={"value": value_object, **data}) - - if value_string := value_holder.get("string"): - return ManagedEntry.from_dict(data={"value": value_string, **data}, stringified_value=True) - - msg = "Expected `value` field to be an object with `object` or `string` subfield" - raise DeserializationError(msg) - - -def managed_entry_to_document(key: str, managed_entry: ManagedEntry, *, native_storage: bool = True) -> dict[str, Any]: - """Convert a ManagedEntry to a MongoDB document for storage. - - This function serializes a ManagedEntry to a MongoDB document format, including the key and all - metadata (TTL, creation, and expiration timestamps). The value storage format depends on the - native_storage parameter. - - Args: - key: The key associated with this entry. - managed_entry: The ManagedEntry to serialize. - native_storage: If True (default), store value as native BSON dict in value.object field. - If False, store as JSON string in value.string field for backward compatibility. - - Returns: - A MongoDB document dict containing the key, value, and all metadata. - """ - document: dict[str, Any] = {"key": key, "value": {}} - - # We convert to JSON even if we don't need to, this ensures that the value we were provided - # can be serialized to JSON which helps ensure compatibility across stores. For example, - # Mongo can natively handle datetime objects which other stores cannot, if we don't convert to JSON, - # then using py-key-value with Mongo will return different values than if we used another store. 
- json_str = managed_entry.value_as_json - - # Store in appropriate field based on mode - if native_storage: - document["value"]["object"] = managed_entry.value_as_dict - else: - document["value"]["string"] = json_str - - # Add metadata fields - if managed_entry.created_at: - document["created_at"] = managed_entry.created_at - if managed_entry.expires_at: - document["expires_at"] = managed_entry.expires_at - - return document - - class MongoDBStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, BaseContextManagerStore, BaseStore): """MongoDB-based key-value store using pymongo.""" diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py index cd302601..76238fa5 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py @@ -1,7 +1,6 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'store.py' # DO NOT CHANGE! Change the original file instead. -from abc import ABC, abstractmethod from collections.abc import Sequence from datetime import datetime from typing import Any, overload @@ -11,6 +10,7 @@ from key_value.shared.type_checking.bear_spray import bear_spray from key_value.shared.utils.compound import compound_key, get_keys_from_compound_keys from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.serialization import FullJsonAdapter, SerializationAdapter from typing_extensions import override from key_value.sync.code_gen.stores.base import BaseContextManagerStore, BaseDestroyStore, BaseEnumerateKeysStore, BaseStore @@ -25,95 +25,6 @@ PAGE_LIMIT = 10000 -class SerializationAdapter(ABC): - """Base class for store-specific serialization adapters. - - Adapters encapsulate the logic for converting between ManagedEntry objects - and store-specific storage formats. - """ - - @abstractmethod - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any] | str: - """Convert a ManagedEntry to the store's storage format. - - Args: - key: The key associated with this entry. - entry: The ManagedEntry to serialize. - collection: Optional collection name. - - Returns: - The serialized representation (dict or str depending on store). - """ - ... - - @abstractmethod - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert stored data back to a ManagedEntry. - - Args: - data: The stored representation to deserialize. - - Returns: - A ManagedEntry reconstructed from storage. - - Raises: - DeserializationError: If the data cannot be deserialized. - """ - ... - - -class RedisAdapter(SerializationAdapter): - """Redis-specific serialization adapter. - - Stores entries as JSON strings in Redis. 
- """ - - @override - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> str: - """Convert a ManagedEntry to a JSON string for Redis storage.""" - return entry.to_json(include_metadata=True, include_expiration=True, include_creation=True) - - @override - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert a JSON string from Redis storage back to a ManagedEntry.""" - if not isinstance(data, str): - msg = "Expected Redis data to be a string" - raise DeserializationError(msg) - - return ManagedEntry.from_json(json_str=data, includes_metadata=True) - - -def managed_entry_to_json(managed_entry: ManagedEntry) -> str: - """Convert a ManagedEntry to a JSON string for Redis storage. - - This function serializes a ManagedEntry to JSON format including all metadata (TTL, creation, - and expiration timestamps). The serialization is designed to preserve all entry information - for round-trip conversion back to a ManagedEntry. - - Args: - managed_entry: The ManagedEntry to serialize. - - Returns: - A JSON string representation of the ManagedEntry with full metadata. - """ - return managed_entry.to_json(include_metadata=True, include_expiration=True, include_creation=True) - - -def json_to_managed_entry(json_str: str) -> ManagedEntry: - """Convert a JSON string from Redis storage back to a ManagedEntry. - - This function deserializes a JSON string (created by `managed_entry_to_json`) back to a - ManagedEntry object, preserving all metadata including TTL, creation, and expiration timestamps. - - Args: - json_str: The JSON string to deserialize. - - Returns: - A ManagedEntry object reconstructed from the JSON string. - """ - return ManagedEntry.from_json(json_str=json_str, includes_metadata=True) - - class RedisStore(BaseDestroyStore, BaseEnumerateKeysStore, BaseContextManagerStore, BaseStore): """Redis-based key-value store.""" @@ -169,7 +80,7 @@ def __init__( self._client = Redis(host=host, port=port, db=db, password=password, decode_responses=True) self._stable_api = True - self._adapter = RedisAdapter() + self._adapter = FullJsonAdapter() super().__init__(default_collection=default_collection) diff --git a/key-value/key-value-sync/tests/code_gen/adapters/test_dataclass.py b/key-value/key-value-sync/tests/code_gen/adapters/test_dataclass.py index 949f27c4..ec618f3d 100644 --- a/key-value/key-value-sync/tests/code_gen/adapters/test_dataclass.py +++ b/key-value/key-value-sync/tests/code_gen/adapters/test_dataclass.py @@ -57,95 +57,90 @@ class Order: product: Product paid: bool = False + FIXED_CREATED_AT: datetime = datetime(year=2021, month=1, day=1, hour=12, minute=0, second=0, tzinfo=timezone.utc) FIXED_UPDATED_AT: datetime = datetime(year=2021, month=1, day=1, hour=15, minute=0, second=0, tzinfo=timezone.utc) -SAMPLE_USER: User = User(name='John Doe', email='john.doe@example.com', age=30) -SAMPLE_USER_2: User = User(name='Jane Doe', email='jane.doe@example.com', age=25) -SAMPLE_PRODUCT: Product = Product(name='Widget', price=29.99, quantity=10) -SAMPLE_ADDRESS: Address = Address(street='123 Main St', city='Springfield', zip_code='12345') -SAMPLE_USER_WITH_ADDRESS: UserWithAddress = UserWithAddress(name='John Doe', age=30, address=SAMPLE_ADDRESS) +SAMPLE_USER: User = User(name="John Doe", email="john.doe@example.com", age=30) +SAMPLE_USER_2: User = User(name="Jane Doe", email="jane.doe@example.com", age=25) +SAMPLE_PRODUCT: Product = Product(name="Widget", price=29.99, quantity=10) +SAMPLE_ADDRESS: Address = Address(street="123 Main St", 
city="Springfield", zip_code="12345") +SAMPLE_USER_WITH_ADDRESS: UserWithAddress = UserWithAddress(name="John Doe", age=30, address=SAMPLE_ADDRESS) SAMPLE_ORDER: Order = Order(created_at=FIXED_CREATED_AT, updated_at=FIXED_UPDATED_AT, user=SAMPLE_USER, product=SAMPLE_PRODUCT, paid=False) -TEST_COLLECTION: str = 'test_collection' -TEST_KEY: str = 'test_key' -TEST_KEY_2: str = 'test_key_2' +TEST_COLLECTION: str = "test_collection" +TEST_KEY: str = "test_key" +TEST_KEY_2: str = "test_key_2" class TestDataclassAdapter: - @pytest.fixture def store(self) -> MemoryStore: return MemoryStore() - @pytest.fixture def user_adapter(self, store: MemoryStore) -> DataclassAdapter[User]: return DataclassAdapter[User](key_value=store, dataclass_type=User) - @pytest.fixture def updated_user_adapter(self, store: MemoryStore) -> DataclassAdapter[UpdatedUser]: return DataclassAdapter[UpdatedUser](key_value=store, dataclass_type=UpdatedUser) - @pytest.fixture def product_adapter(self, store: MemoryStore) -> DataclassAdapter[Product]: return DataclassAdapter[Product](key_value=store, dataclass_type=Product) - @pytest.fixture def product_list_adapter(self, store: MemoryStore) -> DataclassAdapter[list[Product]]: return DataclassAdapter[list[Product]](key_value=store, dataclass_type=list[Product]) - @pytest.fixture def user_with_address_adapter(self, store: MemoryStore) -> DataclassAdapter[UserWithAddress]: return DataclassAdapter[UserWithAddress](key_value=store, dataclass_type=UserWithAddress) - @pytest.fixture def order_adapter(self, store: MemoryStore) -> DataclassAdapter[Order]: return DataclassAdapter[Order](key_value=store, dataclass_type=Order) - def test_simple_adapter(self, user_adapter: DataclassAdapter[User]): """Test basic put/get/delete operations with a simple dataclass.""" user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) cached_user: User | None = user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) assert cached_user == SAMPLE_USER - + assert user_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) - + assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None - def test_simple_adapter_with_default(self, user_adapter: DataclassAdapter[User]): """Test default value handling.""" assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER - + user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER_2) assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER_2 - - assert user_adapter.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2], default=SAMPLE_USER) == snapshot([SAMPLE_USER_2, SAMPLE_USER]) - - def test_simple_adapter_with_validation_error_ignore(self, user_adapter: DataclassAdapter[User], updated_user_adapter: DataclassAdapter[UpdatedUser]): + assert user_adapter.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2], default=SAMPLE_USER) == snapshot( + [SAMPLE_USER_2, SAMPLE_USER] + ) + + def test_simple_adapter_with_validation_error_ignore( + self, user_adapter: DataclassAdapter[User], updated_user_adapter: DataclassAdapter[UpdatedUser] + ): """Test that validation errors return None when raise_on_validation_error is False.""" user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) - + # UpdatedUser requires is_admin field which doesn't exist in stored User updated_user = updated_user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) assert updated_user is None - - def 
test_simple_adapter_with_validation_error_raise(self, user_adapter: DataclassAdapter[User], updated_user_adapter: DataclassAdapter[UpdatedUser]): + def test_simple_adapter_with_validation_error_raise( + self, user_adapter: DataclassAdapter[User], updated_user_adapter: DataclassAdapter[UpdatedUser] + ): """Test that validation errors raise DeserializationError when raise_on_validation_error is True.""" user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) updated_user_adapter._raise_on_validation_error = True # pyright: ignore[reportPrivateUsage] with pytest.raises(DeserializationError): updated_user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) - def test_nested_dataclass(self, user_with_address_adapter: DataclassAdapter[UserWithAddress]): """Test that nested dataclasses are properly serialized and deserialized.""" @@ -153,58 +148,56 @@ def test_nested_dataclass(self, user_with_address_adapter: DataclassAdapter[User cached_user: UserWithAddress | None = user_with_address_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) assert cached_user == SAMPLE_USER_WITH_ADDRESS assert cached_user is not None - assert cached_user.address.street == '123 Main St' - + assert cached_user.address.street == "123 Main St" def test_complex_adapter(self, order_adapter: DataclassAdapter[Order]): """Test complex dataclass with nested objects and TTL.""" order_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_ORDER, ttl=10) cached_order: Order | None = order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) assert cached_order == SAMPLE_ORDER - + assert order_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) assert order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None - def test_complex_adapter_with_list(self, product_list_adapter: DataclassAdapter[list[Product]], store: MemoryStore): """Test list dataclass serialization with proper wrapping.""" product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[SAMPLE_PRODUCT, SAMPLE_PRODUCT], ttl=10) cached_products: list[Product] | None = product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) assert cached_products == [SAMPLE_PRODUCT, SAMPLE_PRODUCT] - + # We need to ensure our memory store doesn't hold an entry with an array raw_collection = store._cache.get(TEST_COLLECTION) # pyright: ignore[reportPrivateUsage] assert raw_collection is not None - + raw_entry = raw_collection.get(TEST_KEY) assert raw_entry is not None assert isinstance(raw_entry.value, dict) - assert raw_entry.value == snapshot({'items': [{'name': 'Widget', 'price': 29.99, 'quantity': 10}, {'name': 'Widget', 'price': 29.99, 'quantity': 10}]}) - + assert raw_entry.value == snapshot( + {"items": [{"name": "Widget", "price": 29.99, "quantity": 10}, {"name": "Widget", "price": 29.99, "quantity": 10}]} + ) + assert product_list_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) assert product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None - def test_batch_operations(self, user_adapter: DataclassAdapter[User]): """Test batch put/get/delete operations.""" keys = [TEST_KEY, TEST_KEY_2] users = [SAMPLE_USER, SAMPLE_USER_2] - + # Test put_many user_adapter.put_many(collection=TEST_COLLECTION, keys=keys, values=users) - + # Test get_many cached_users = user_adapter.get_many(collection=TEST_COLLECTION, keys=keys) assert cached_users == users - + # Test delete_many deleted_count = user_adapter.delete_many(collection=TEST_COLLECTION, keys=keys) assert deleted_count == 2 - + # Verify 
deletion cached_users_after_delete = user_adapter.get_many(collection=TEST_COLLECTION, keys=keys) assert cached_users_after_delete == [None, None] - def test_ttl_operations(self, user_adapter: DataclassAdapter[User]): """Test TTL-related operations.""" @@ -214,31 +207,28 @@ def test_ttl_operations(self, user_adapter: DataclassAdapter[User]): assert user == SAMPLE_USER assert ttl is not None assert ttl > 0 - + # Test ttl_many user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY_2, value=SAMPLE_USER_2, ttl=20) ttl_results = user_adapter.ttl_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2]) assert len(ttl_results) == 2 assert ttl_results[0][0] == SAMPLE_USER assert ttl_results[1][0] == SAMPLE_USER_2 - def test_dataclass_validation_on_init(self, store: MemoryStore): """Test that non-dataclass types are rejected.""" - with pytest.raises(TypeError, match='is not a dataclass'): + with pytest.raises(TypeError, match="is not a dataclass"): DataclassAdapter[str](key_value=store, dataclass_type=str) # type: ignore[type-var] - def test_default_collection(self, store: MemoryStore): """Test that default collection is used when not specified.""" adapter = DataclassAdapter[User](key_value=store, dataclass_type=User, default_collection=TEST_COLLECTION) - + adapter.put(key=TEST_KEY, value=SAMPLE_USER) cached_user = adapter.get(key=TEST_KEY) assert cached_user == SAMPLE_USER - + assert adapter.delete(key=TEST_KEY) - def test_ttl_with_empty_list(self, product_list_adapter: DataclassAdapter[list[Product]]): """Test that TTL with empty list returns correctly (not None).""" @@ -247,22 +237,20 @@ def test_ttl_with_empty_list(self, product_list_adapter: DataclassAdapter[list[P assert value == [] assert ttl is not None assert ttl > 0 - def test_list_payload_missing_items_returns_none(self, product_list_adapter: DataclassAdapter[list[Product]], store: MemoryStore): """Test that list payload without 'items' wrapper returns None when raise_on_validation_error is False.""" # Manually insert malformed payload without the 'items' wrapper # The payload is a dict but without the expected 'items' key for list models - malformed_payload: dict[str, Any] = {'wrong': []} + malformed_payload: dict[str, Any] = {"wrong": []} store.put(collection=TEST_COLLECTION, key=TEST_KEY, value=malformed_payload) assert product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None - def test_list_payload_missing_items_raises(self, product_list_adapter: DataclassAdapter[list[Product]], store: MemoryStore): """Test that list payload without 'items' wrapper raises DeserializationError when configured.""" product_list_adapter._raise_on_validation_error = True # pyright: ignore[reportPrivateUsage] # Manually insert malformed payload without the 'items' wrapper - malformed_payload: dict[str, Any] = {'wrong': []} + malformed_payload: dict[str, Any] = {"wrong": []} store.put(collection=TEST_COLLECTION, key=TEST_KEY, value=malformed_payload) with pytest.raises(DeserializationError, match="missing 'items'"): product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) diff --git a/key-value/key-value-sync/tests/code_gen/adapters/test_pydantic.py b/key-value/key-value-sync/tests/code_gen/adapters/test_pydantic.py index 58906c21..e52f9e8f 100644 --- a/key-value/key-value-sync/tests/code_gen/adapters/test_pydantic.py +++ b/key-value/key-value-sync/tests/code_gen/adapters/test_pydantic.py @@ -37,183 +37,186 @@ class Order(BaseModel): product: Product paid: bool + FIXED_CREATED_AT: datetime = datetime(year=2021, month=1, 
day=1, hour=12, minute=0, second=0, tzinfo=timezone.utc) FIXED_UPDATED_AT: datetime = datetime(year=2021, month=1, day=1, hour=15, minute=0, second=0, tzinfo=timezone.utc) -SAMPLE_USER: User = User(name='John Doe', email='john.doe@example.com', age=30) -SAMPLE_USER_2: User = User(name='Jane Doe', email='jane.doe@example.com', age=25) -SAMPLE_PRODUCT: Product = Product(name='Widget', price=29.99, quantity=10, url=AnyHttpUrl(url='https://example.com')) +SAMPLE_USER: User = User(name="John Doe", email="john.doe@example.com", age=30) +SAMPLE_USER_2: User = User(name="Jane Doe", email="jane.doe@example.com", age=25) +SAMPLE_PRODUCT: Product = Product(name="Widget", price=29.99, quantity=10, url=AnyHttpUrl(url="https://example.com")) SAMPLE_ORDER: Order = Order(created_at=datetime.now(), updated_at=datetime.now(), user=SAMPLE_USER, product=SAMPLE_PRODUCT, paid=False) -TEST_COLLECTION: str = 'test_collection' -TEST_KEY: str = 'test_key' -TEST_KEY_2: str = 'test_key_2' +TEST_COLLECTION: str = "test_collection" +TEST_KEY: str = "test_key" +TEST_KEY_2: str = "test_key_2" def model_type_from_log_record(record: LogRecord) -> str: - if not hasattr(record, 'model_type'): - msg = 'Log record does not have a model_type attribute' + if not hasattr(record, "model_type"): + msg = "Log record does not have a model_type attribute" raise ValueError(msg) return record.model_type # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType, reportAttributeAccessIssue] def error_from_log_record(record: LogRecord) -> str: - if not hasattr(record, 'error'): - msg = 'Log record does not have an error attribute' + if not hasattr(record, "error"): + msg = "Log record does not have an error attribute" raise ValueError(msg) return record.error # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType, reportAttributeAccessIssue] def errors_from_log_record(record: LogRecord) -> list[str]: - if not hasattr(record, 'errors'): - msg = 'Log record does not have an errors attribute' + if not hasattr(record, "errors"): + msg = "Log record does not have an errors attribute" raise ValueError(msg) return record.errors # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType, reportAttributeAccessIssue] class TestPydanticAdapter: - @pytest.fixture def store(self) -> MemoryStore: return MemoryStore() - @pytest.fixture def user_adapter(self, store: MemoryStore) -> PydanticAdapter[User]: return PydanticAdapter[User](key_value=store, pydantic_model=User) - @pytest.fixture def updated_user_adapter(self, store: MemoryStore) -> PydanticAdapter[UpdatedUser]: return PydanticAdapter[UpdatedUser](key_value=store, pydantic_model=UpdatedUser) - @pytest.fixture def product_adapter(self, store: MemoryStore) -> PydanticAdapter[Product]: return PydanticAdapter[Product](key_value=store, pydantic_model=Product) - @pytest.fixture def product_list_adapter(self, store: MemoryStore) -> PydanticAdapter[list[Product]]: return PydanticAdapter[list[Product]](key_value=store, pydantic_model=list[Product]) - @pytest.fixture def order_adapter(self, store: MemoryStore) -> PydanticAdapter[Order]: return PydanticAdapter[Order](key_value=store, pydantic_model=Order) - def test_simple_adapter(self, user_adapter: PydanticAdapter[User]): user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) cached_user: User | None = user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) assert cached_user == SAMPLE_USER - + assert user_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) - + assert 
user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None - def test_simple_adapter_with_default(self, user_adapter: PydanticAdapter[User]): assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER - + user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER_2) assert user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER_2 - - assert user_adapter.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2], default=SAMPLE_USER) == snapshot([SAMPLE_USER_2, SAMPLE_USER]) - + + assert user_adapter.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2], default=SAMPLE_USER) == snapshot( + [SAMPLE_USER_2, SAMPLE_USER] + ) def test_simple_adapter_with_list(self, product_list_adapter: PydanticAdapter[list[Product]]): product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[SAMPLE_PRODUCT, SAMPLE_PRODUCT]) cached_products: list[Product] | None = product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) assert cached_products == [SAMPLE_PRODUCT, SAMPLE_PRODUCT] - + assert product_list_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) assert product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None - - def test_simple_adapter_with_validation_error_ignore(self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser]): + def test_simple_adapter_with_validation_error_ignore( + self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser] + ): user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) - + updated_user = updated_user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) assert updated_user is None - - def test_simple_adapter_with_validation_error_raise(self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser]): + def test_simple_adapter_with_validation_error_raise( + self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser] + ): user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) updated_user_adapter._raise_on_validation_error = True # pyright: ignore[reportPrivateUsage] with pytest.raises(DeserializationError): updated_user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) - def test_complex_adapter(self, order_adapter: PydanticAdapter[Order]): order_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_ORDER, ttl=10) assert order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) == SAMPLE_ORDER - + assert order_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) assert order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None - def test_complex_adapter_with_list(self, product_list_adapter: PydanticAdapter[list[Product]], store: MemoryStore): product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[SAMPLE_PRODUCT, SAMPLE_PRODUCT], ttl=10) cached_products: list[Product] | None = product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) assert cached_products == [SAMPLE_PRODUCT, SAMPLE_PRODUCT] - + # We need to ensure our memory store doesnt hold an entry with an array raw_collection = store._cache.get(TEST_COLLECTION) # pyright: ignore[reportPrivateUsage] assert raw_collection is not None - + raw_entry = raw_collection.get(TEST_KEY) assert raw_entry is not None assert isinstance(raw_entry.value, dict) - assert raw_entry.value == snapshot({'items': [{'name': 'Widget', 
'price': 29.99, 'quantity': 10, 'url': 'https://example.com/'}, {'name': 'Widget', 'price': 29.99, 'quantity': 10, 'url': 'https://example.com/'}]}) - + assert raw_entry.value == snapshot( + { + "items": [ + {"name": "Widget", "price": 29.99, "quantity": 10, "url": "https://example.com/"}, + {"name": "Widget", "price": 29.99, "quantity": 10, "url": "https://example.com/"}, + ] + } + ) + assert product_list_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY) assert product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None - - def test_validation_error_logging(self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser], caplog: pytest.LogCaptureFixture): + def test_validation_error_logging( + self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser], caplog: pytest.LogCaptureFixture + ): """Test that validation errors are logged when raise_on_validation_error=False.""" import logging - + # Store a User, then try to retrieve as UpdatedUser (missing is_admin field) user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER) - + with caplog.at_level(logging.ERROR): updated_user = updated_user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) - + # Should return None due to validation failure assert updated_user is None - + # Check that an error was logged assert len(caplog.records) == 1 record = caplog.records[0] - assert record.levelname == 'ERROR' - assert 'Validation failed' in record.message - assert model_type_from_log_record(record) == 'Pydantic model' - + assert record.levelname == "ERROR" + assert "Validation failed" in record.message + assert model_type_from_log_record(record) == "Pydantic model" + errors = errors_from_log_record(record) assert len(errors) == 1 - assert 'is_admin' in str(errors[0]) - + assert "is_admin" in str(errors[0]) - def test_list_validation_error_logging(self, product_list_adapter: PydanticAdapter[list[Product]], store: MemoryStore, caplog: pytest.LogCaptureFixture): + def test_list_validation_error_logging( + self, product_list_adapter: PydanticAdapter[list[Product]], store: MemoryStore, caplog: pytest.LogCaptureFixture + ): """Test that missing 'items' wrapper is logged for list models.""" import logging - + # Manually store invalid data (missing 'items' wrapper) - store.put(collection=TEST_COLLECTION, key=TEST_KEY, value={'invalid': 'data'}) - + store.put(collection=TEST_COLLECTION, key=TEST_KEY, value={"invalid": "data"}) + with caplog.at_level(logging.ERROR): result = product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) - + # Should return None due to missing 'items' wrapper assert result is None - + # Check that an error was logged assert len(caplog.records) == 1 record = caplog.records[0] - assert record.levelname == 'ERROR' + assert record.levelname == "ERROR" assert "Missing 'items' wrapper" in record.message - assert model_type_from_log_record(record) == 'Pydantic model' + assert model_type_from_log_record(record) == "Pydantic model" error = error_from_log_record(record) assert "missing 'items' wrapper" in str(error) diff --git a/key-value/key-value-sync/tests/code_gen/adapters/test_raise.py b/key-value/key-value-sync/tests/code_gen/adapters/test_raise.py index 1712dafc..1653f1e4 100644 --- a/key-value/key-value-sync/tests/code_gen/adapters/test_raise.py +++ b/key-value/key-value-sync/tests/code_gen/adapters/test_raise.py @@ -19,50 +19,50 @@ def adapter(store: MemoryStore) -> RaiseOnMissingAdapter: def test_get(adapter: 
RaiseOnMissingAdapter): - adapter.put(collection='test', key='test', value={'test': 'test'}) - assert adapter.get(collection='test', key='test') == {'test': 'test'} + adapter.put(collection="test", key="test", value={"test": "test"}) + assert adapter.get(collection="test", key="test") == {"test": "test"} def test_get_missing(adapter: RaiseOnMissingAdapter): with pytest.raises(MissingKeyError): - _ = adapter.get(collection='test', key='test', raise_on_missing=True) + _ = adapter.get(collection="test", key="test", raise_on_missing=True) def test_get_many(adapter: RaiseOnMissingAdapter): - adapter.put(collection='test', key='test', value={'test': 'test'}) - adapter.put(collection='test', key='test_2', value={'test': 'test_2'}) - assert adapter.get_many(collection='test', keys=['test', 'test_2']) == [{'test': 'test'}, {'test': 'test_2'}] + adapter.put(collection="test", key="test", value={"test": "test"}) + adapter.put(collection="test", key="test_2", value={"test": "test_2"}) + assert adapter.get_many(collection="test", keys=["test", "test_2"]) == [{"test": "test"}, {"test": "test_2"}] def test_get_many_missing(adapter: RaiseOnMissingAdapter): - adapter.put(collection='test', key='test', value={'test': 'test'}) + adapter.put(collection="test", key="test", value={"test": "test"}) with pytest.raises(MissingKeyError): - _ = adapter.get_many(collection='test', keys=['test', 'test_2'], raise_on_missing=True) + _ = adapter.get_many(collection="test", keys=["test", "test_2"], raise_on_missing=True) def test_ttl(adapter: RaiseOnMissingAdapter): - adapter.put(collection='test', key='test', value={'test': 'test'}, ttl=60) - (value, ttl) = adapter.ttl(collection='test', key='test') - assert value == {'test': 'test'} + adapter.put(collection="test", key="test", value={"test": "test"}, ttl=60) + (value, ttl) = adapter.ttl(collection="test", key="test") + assert value == {"test": "test"} assert ttl is not None def test_ttl_missing(adapter: RaiseOnMissingAdapter): with pytest.raises(MissingKeyError): - _ = adapter.ttl(collection='test', key='test', raise_on_missing=True) + _ = adapter.ttl(collection="test", key="test", raise_on_missing=True) def test_ttl_many(adapter: RaiseOnMissingAdapter): - adapter.put(collection='test', key='test', value={'test': 'test'}, ttl=60) - adapter.put(collection='test', key='test_2', value={'test': 'test_2'}, ttl=120) - results = adapter.ttl_many(collection='test', keys=['test', 'test_2']) - assert results[0][0] == {'test': 'test'} + adapter.put(collection="test", key="test", value={"test": "test"}, ttl=60) + adapter.put(collection="test", key="test_2", value={"test": "test_2"}, ttl=120) + results = adapter.ttl_many(collection="test", keys=["test", "test_2"]) + assert results[0][0] == {"test": "test"} assert results[0][1] is not None - assert results[1][0] == {'test': 'test_2'} + assert results[1][0] == {"test": "test_2"} assert results[1][1] is not None def test_ttl_many_missing(adapter: RaiseOnMissingAdapter): - adapter.put(collection='test', key='test', value={'test': 'test'}, ttl=60) + adapter.put(collection="test", key="test", value={"test": "test"}, ttl=60) with pytest.raises(MissingKeyError): - _ = adapter.ttl_many(collection='test', keys=['test', 'test_2'], raise_on_missing=True) + _ = adapter.ttl_many(collection="test", keys=["test", "test_2"], raise_on_missing=True) diff --git a/key-value/key-value-sync/tests/code_gen/cases.py b/key-value/key-value-sync/tests/code_gen/cases.py index a1aa3250..0fffeebe 100644 --- a/key-value/key-value-sync/tests/code_gen/cases.py +++ 
b/key-value/key-value-sync/tests/code_gen/cases.py @@ -7,19 +7,57 @@ FIXED_DATETIME = datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc) FIXED_TIME = FIXED_DATETIME.time() -LARGE_STRING: str = 'a' * 10000 # 10KB -LARGE_INT: int = 1 * 10 ** 18 # 18 digits -LARGE_FLOAT: float = 1.0 * 10 ** 63 # 63 digits +LARGE_STRING: str = "a" * 10000 # 10KB +LARGE_INT: int = 1 * 10**18 # 18 digits +LARGE_FLOAT: float = 1.0 * 10**63 # 63 digits -SIMPLE_CASE: dict[str, Any] = {'key_1': 'value_1', 'key_2': 1, 'key_3': 1.0, 'key_4': [1, 2, 3], 'key_5': {'nested': 'value'}, 'key_6': True, 'key_7': False, 'key_8': None} +SIMPLE_CASE: dict[str, Any] = { + "key_1": "value_1", + "key_2": 1, + "key_3": 1.0, + "key_4": [1, 2, 3], + "key_5": {"nested": "value"}, + "key_6": True, + "key_7": False, + "key_8": None, +} SIMPLE_CASE_JSON: str = '{"key_1": "value_1", "key_2": 1, "key_3": 1.0, "key_4": [1, 2, 3], "key_5": {"nested": "value"}, "key_6": true, "key_7": false, "key_8": null}' # ({"key": (1, 2, 3)}, '{"key": [1, 2, 3]}'), -DICTIONARY_TO_JSON_TEST_CASES: list[tuple[dict[str, Any], str]] = [({'key': 'value'}, '{"key": "value"}'), ({'key': 1}, '{"key": 1}'), ({'key': 1.0}, '{"key": 1.0}'), ({'key': [1, 2, 3]}, '{"key": [1, 2, 3]}'), ({'key': {'nested': 'value'}}, '{"key": {"nested": "value"}}'), ({'key': True}, '{"key": true}'), ({'key': False}, '{"key": false}'), ({'key': None}, '{"key": null}'), ({'key': {'int': 1, 'float': 1.0, 'list': [1, 2, 3], 'dict': {'nested': 'value'}, 'bool': True, 'null': None}}, '{"key": {"int": 1, "float": 1.0, "list": [1, 2, 3], "dict": {"nested": "value"}, "bool": true, "null": null}}'), ({'key': LARGE_STRING}, f'{{"key": "{LARGE_STRING}"}}'), ({'key': LARGE_INT}, f'{{"key": {LARGE_INT}}}'), ({'key': LARGE_FLOAT}, f'{{"key": {LARGE_FLOAT}}}')] +DICTIONARY_TO_JSON_TEST_CASES: list[tuple[dict[str, Any], str]] = [ + ({"key": "value"}, '{"key": "value"}'), + ({"key": 1}, '{"key": 1}'), + ({"key": 1.0}, '{"key": 1.0}'), + ({"key": [1, 2, 3]}, '{"key": [1, 2, 3]}'), + ({"key": {"nested": "value"}}, '{"key": {"nested": "value"}}'), + ({"key": True}, '{"key": true}'), + ({"key": False}, '{"key": false}'), + ({"key": None}, '{"key": null}'), + ( + {"key": {"int": 1, "float": 1.0, "list": [1, 2, 3], "dict": {"nested": "value"}, "bool": True, "null": None}}, + '{"key": {"int": 1, "float": 1.0, "list": [1, 2, 3], "dict": {"nested": "value"}, "bool": true, "null": null}}', + ), + ({"key": LARGE_STRING}, f'{{"key": "{LARGE_STRING}"}}'), + ({"key": LARGE_INT}, f'{{"key": {LARGE_INT}}}'), + ({"key": LARGE_FLOAT}, f'{{"key": {LARGE_FLOAT}}}'), +] # "tuple", -DICTIONARY_TO_JSON_TEST_CASES_NAMES: list[str] = ['string', 'int', 'float', 'list', 'dict', 'bool-true', 'bool-false', 'null', 'dict-nested', 'large-string', 'large-int', 'large-float'] +DICTIONARY_TO_JSON_TEST_CASES_NAMES: list[str] = [ + "string", + "int", + "float", + "list", + "dict", + "bool-true", + "bool-false", + "null", + "dict-nested", + "large-string", + "large-int", + "large-float", +] OBJECT_TEST_CASES: list[dict[str, Any]] = [test_case[0] for test_case in DICTIONARY_TO_JSON_TEST_CASES] diff --git a/key-value/key-value-sync/tests/code_gen/conftest.py b/key-value/key-value-sync/tests/code_gen/conftest.py index f2b9505b..3d7b2051 100644 --- a/key-value/key-value-sync/tests/code_gen/conftest.py +++ b/key-value/key-value-sync/tests/code_gen/conftest.py @@ -22,11 +22,10 @@ @contextmanager def try_import() -> Iterator[Callable[[], bool]]: import_success = False - def check_import() -> bool: return import_success - + try: yield check_import 
except ImportError: @@ -44,70 +43,70 @@ def docker_client() -> DockerClient: return get_docker_client() -def docker_logs(name: str, print_logs: bool=False, raise_on_error: bool=False, log_level: int=logging.INFO) -> list[str]: +def docker_logs(name: str, print_logs: bool = False, raise_on_error: bool = False, log_level: int = logging.INFO) -> list[str]: client = get_docker_client() try: - logs: list[str] = client.containers.get(name).logs().decode('utf-8').splitlines() + logs: list[str] = client.containers.get(name).logs().decode("utf-8").splitlines() except Exception: - logger.info(f'Container {name} failed to get logs') + logger.info(f"Container {name} failed to get logs") if raise_on_error: raise return [] - + if print_logs: - logger.info(f'Container {name} logs:') + logger.info(f"Container {name} logs:") for log in logs: logger.log(log_level, log) - + return logs -def docker_get(name: str, raise_on_not_found: bool=False) -> Container | None: +def docker_get(name: str, raise_on_not_found: bool = False) -> Container | None: from docker.errors import NotFound - + client = get_docker_client() try: return client.containers.get(name) except NotFound: - logger.info(f'Container {name} failed to get') + logger.info(f"Container {name} failed to get") if raise_on_not_found: raise return None -def docker_pull(image: str, raise_on_error: bool=False) -> bool: - logger.info(f'Pulling image {image}') +def docker_pull(image: str, raise_on_error: bool = False) -> bool: + logger.info(f"Pulling image {image}") client = get_docker_client() try: client.images.pull(image) except Exception: - logger.exception(f'Image {image} failed to pull') + logger.exception(f"Image {image} failed to pull") if raise_on_error: raise return False return True -def docker_stop(name: str, raise_on_error: bool=False) -> bool: - logger.info(f'Stopping container {name}') - +def docker_stop(name: str, raise_on_error: bool = False) -> bool: + logger.info(f"Stopping container {name}") + if not (container := docker_get(name=name, raise_on_not_found=False)): return False - + try: container.stop() except Exception: - logger.info(f'Container {name} failed to stop') + logger.info(f"Container {name} failed to stop") if raise_on_error: raise return False - - logger.info(f'Container {name} stopped') + + logger.info(f"Container {name} stopped") return True -def docker_wait_container_gone(name: str, max_tries: int=10, wait_time: float=1.0) -> bool: - logger.info(f'Waiting for container {name} to be gone') +def docker_wait_container_gone(name: str, max_tries: int = 10, wait_time: float = 1.0) -> bool: + logger.info(f"Waiting for container {name} to be gone") count = 0 while count < max_tries: if not docker_get(name=name, raise_on_not_found=False): @@ -117,51 +116,53 @@ def docker_wait_container_gone(name: str, max_tries: int=10, wait_time: float=1. 
return False -def docker_rm(name: str, raise_on_error: bool=False) -> bool: - logger.info(f'Removing container {name}') - +def docker_rm(name: str, raise_on_error: bool = False) -> bool: + logger.info(f"Removing container {name}") + if not (container := docker_get(name=name, raise_on_not_found=False)): return False - + try: container.remove() except Exception: - logger.info(f'Container {name} failed to remove') + logger.info(f"Container {name} failed to remove") if raise_on_error: raise return False - logger.info(f'Container {name} removed') + logger.info(f"Container {name} removed") return True -def docker_run(name: str, image: str, ports: dict[str, int], environment: dict[str, str], raise_on_error: bool=False) -> bool: - logger.info(f'Running container {name} with image {image} and ports {ports}') +def docker_run(name: str, image: str, ports: dict[str, int], environment: dict[str, str], raise_on_error: bool = False) -> bool: + logger.info(f"Running container {name} with image {image} and ports {ports}") client = get_docker_client() try: client.containers.run(name=name, image=image, ports=ports, environment=environment, detach=True) except Exception: - logger.exception(f'Container {name} failed to run') + logger.exception(f"Container {name} failed to run") if raise_on_error: raise return False - logger.info(f'Container {name} running') + logger.info(f"Container {name} running") return True @contextmanager -def docker_container(name: str, image: str, ports: dict[str, int], environment: dict[str, str] | None=None, raise_on_error: bool=True) -> Iterator[None]: - logger.info(f'Creating container {name} with image {image} and ports {ports}') +def docker_container( + name: str, image: str, ports: dict[str, int], environment: dict[str, str] | None = None, raise_on_error: bool = True +) -> Iterator[None]: + logger.info(f"Creating container {name} with image {image} and ports {ports}") try: docker_pull(image=image, raise_on_error=True) docker_stop(name=name, raise_on_error=False) docker_rm(name=name, raise_on_error=False) docker_wait_container_gone(name=name, max_tries=10, wait_time=1.0) docker_run(name=name, image=image, ports=ports, environment=environment or {}, raise_on_error=True) - logger.info(f'Container {name} created') + logger.info(f"Container {name} created") yield docker_logs(name, print_logs=True, raise_on_error=False) except Exception: - logger.info(f'Creating container {name} failed') + logger.info(f"Creating container {name} failed") docker_logs(name, print_logs=True, raise_on_error=False, log_level=logging.ERROR) if raise_on_error: raise @@ -170,8 +171,8 @@ def docker_container(name: str, image: str, ports: dict[str, int], environment: docker_stop(name, raise_on_error=False) docker_rm(name, raise_on_error=False) docker_wait_container_gone(name=name, max_tries=10, wait_time=1.0) - - logger.info(f'Container {name} stopped and removed') + + logger.info(f"Container {name} stopped and removed") return @@ -189,7 +190,7 @@ def running_in_event_loop() -> bool: def detect_docker() -> bool: try: - result = subprocess.run(['docker', 'ps'], check=False, capture_output=True, text=True) # noqa: S607 + result = subprocess.run(["docker", "ps"], check=False, capture_output=True, text=True) # noqa: S607 except Exception: return False else: @@ -197,19 +198,19 @@ def detect_docker() -> bool: def detect_on_ci() -> bool: - return os.getenv('CI', 'false') == 'true' + return os.getenv("CI", "false") == "true" def detect_on_windows() -> bool: - return platform.system() == 'Windows' + return 
platform.system() == "Windows" def detect_on_macos() -> bool: - return platform.system() == 'Darwin' + return platform.system() == "Darwin" def detect_on_linux() -> bool: - return platform.system() == 'Linux' + return platform.system() == "Linux" def should_run_docker_tests() -> bool: diff --git a/key-value/key-value-sync/tests/code_gen/protocols/test_types.py b/key-value/key-value-sync/tests/code_gen/protocols/test_types.py index b883d85a..2d4abd61 100644 --- a/key-value/key-value-sync/tests/code_gen/protocols/test_types.py +++ b/key-value/key-value-sync/tests/code_gen/protocols/test_types.py @@ -6,16 +6,15 @@ def test_key_value_protocol(): - def test_protocol(key_value: KeyValue): - assert key_value.get(collection='test', key='test') is None - key_value.put(collection='test', key='test', value={'test': 'test'}) - assert key_value.delete(collection='test', key='test') - key_value.put(collection='test', key='test_2', value={'test': 'test'}) - + assert key_value.get(collection="test", key="test") is None + key_value.put(collection="test", key="test", value={"test": "test"}) + assert key_value.delete(collection="test", key="test") + key_value.put(collection="test", key="test_2", value={"test": "test"}) + memory_store = MemoryStore() - + test_protocol(key_value=memory_store) - - assert memory_store.get(collection='test', key='test') is None - assert memory_store.get(collection='test', key='test_2') == {'test': 'test'} + + assert memory_store.get(collection="test", key="test") is None + assert memory_store.get(collection="test", key="test_2") == {"test": "test"} diff --git a/key-value/key-value-sync/tests/code_gen/stores/base.py b/key-value/key-value-sync/tests/code_gen/stores/base.py index 23fc9bb6..8f191c55 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/base.py +++ b/key-value/key-value-sync/tests/code_gen/stores/base.py @@ -20,292 +20,253 @@ class BaseStoreTests(ABC): - def eventually_consistent(self) -> None: # noqa: B027 - 'Subclasses can override this to wait for eventually consistent operations.' - + "Subclasses can override this to wait for eventually consistent operations." @pytest.fixture @abstractmethod - def store(self) -> BaseStore | Generator[BaseStore, None, None]: - ... - + def store(self) -> BaseStore | Generator[BaseStore, None, None]: ... 
@pytest.mark.timeout(60) def test_store(self, store: BaseStore): """Tests that the store is a valid KeyValueProtocol.""" assert isinstance(store, KeyValueProtocol) is True - def test_empty_get(self, store: BaseStore): """Tests that the get method returns None from an empty store.""" - assert store.get(collection='test', key='test') is None - + assert store.get(collection="test", key="test") is None def test_empty_put(self, store: BaseStore): """Tests that the put method does not raise an exception when called on a new store.""" - store.put(collection='test', key='test', value={'test': 'test'}) - + store.put(collection="test", key="test", value={"test": "test"}) def test_empty_ttl(self, store: BaseStore): """Tests that the ttl method returns None from an empty store.""" - ttl = store.ttl(collection='test', key='test') + ttl = store.ttl(collection="test", key="test") assert ttl == (None, None) - def test_put_serialization_errors(self, store: BaseStore): """Tests that the put method raises SerializationError for non-JSON-serializable Pydantic types.""" with pytest.raises(SerializationError): - store.put(collection='test', key='test', value={'test': AnyHttpUrl('https://test.com')}) - + store.put(collection="test", key="test", value={"test": AnyHttpUrl("https://test.com")}) def test_get_put_get(self, store: BaseStore): - assert store.get(collection='test', key='test') is None - store.put(collection='test', key='test', value={'test': 'test'}) - assert store.get(collection='test', key='test') == {'test': 'test'} - + assert store.get(collection="test", key="test") is None + store.put(collection="test", key="test", value={"test": "test"}) + assert store.get(collection="test", key="test") == {"test": "test"} @PositiveCases.parametrize(cases=SIMPLE_CASES) def test_models_put_get(self, store: BaseStore, data: dict[str, Any], json: str, round_trip: dict[str, Any]): - store.put(collection='test', key='test', value=data) - retrieved_data = store.get(collection='test', key='test') + store.put(collection="test", key="test", value=data) + retrieved_data = store.get(collection="test", key="test") assert retrieved_data is not None assert retrieved_data == round_trip - @NegativeCases.parametrize(cases=NEGATIVE_SIMPLE_CASES) def test_negative_models_put_get(self, store: BaseStore, data: dict[str, Any], error: type[Exception]): with pytest.raises(error): - store.put(collection='test', key='test', value=data) - + store.put(collection="test", key="test", value=data) @PositiveCases.parametrize(cases=[LARGE_DATA_CASES]) def test_get_large_put_get(self, store: BaseStore, data: dict[str, Any], json: str, round_trip: dict[str, Any]): - store.put(collection='test', key='test', value=data) - assert store.get(collection='test', key='test') == round_trip - + store.put(collection="test", key="test", value=data) + assert store.get(collection="test", key="test") == round_trip def test_put_many_get(self, store: BaseStore): - store.put_many(collection='test', keys=['test', 'test_2'], values=[{'test': 'test'}, {'test': 'test_2'}]) - assert store.get(collection='test', key='test') == {'test': 'test'} - assert store.get(collection='test', key='test_2') == {'test': 'test_2'} - + store.put_many(collection="test", keys=["test", "test_2"], values=[{"test": "test"}, {"test": "test_2"}]) + assert store.get(collection="test", key="test") == {"test": "test"} + assert store.get(collection="test", key="test_2") == {"test": "test_2"} def test_put_many_get_many(self, store: BaseStore): - store.put_many(collection='test', keys=['test', 
'test_2'], values=[{'test': 'test'}, {'test': 'test_2'}]) - assert store.get_many(collection='test', keys=['test', 'test_2']) == [{'test': 'test'}, {'test': 'test_2'}] - + store.put_many(collection="test", keys=["test", "test_2"], values=[{"test": "test"}, {"test": "test_2"}]) + assert store.get_many(collection="test", keys=["test", "test_2"]) == [{"test": "test"}, {"test": "test_2"}] def test_put_put_get_many(self, store: BaseStore): - store.put(collection='test', key='test', value={'test': 'test'}) - store.put(collection='test', key='test_2', value={'test': 'test_2'}) - assert store.get_many(collection='test', keys=['test', 'test_2']) == [{'test': 'test'}, {'test': 'test_2'}] - + store.put(collection="test", key="test", value={"test": "test"}) + store.put(collection="test", key="test_2", value={"test": "test_2"}) + assert store.get_many(collection="test", keys=["test", "test_2"]) == [{"test": "test"}, {"test": "test_2"}] def test_put_put_get_many_missing_one(self, store: BaseStore): - store.put(collection='test', key='test', value={'test': 'test'}) - store.put(collection='test', key='test_2', value={'test': 'test_2'}) - assert store.get_many(collection='test', keys=['test', 'test_2', 'test_3']) == [{'test': 'test'}, {'test': 'test_2'}, None] - + store.put(collection="test", key="test", value={"test": "test"}) + store.put(collection="test", key="test_2", value={"test": "test_2"}) + assert store.get_many(collection="test", keys=["test", "test_2", "test_3"]) == [{"test": "test"}, {"test": "test_2"}, None] def test_put_get_delete_get(self, store: BaseStore): - store.put(collection='test', key='test', value={'test': 'test'}) - assert store.get(collection='test', key='test') == {'test': 'test'} - assert store.delete(collection='test', key='test') - assert store.get(collection='test', key='test') is None - + store.put(collection="test", key="test", value={"test": "test"}) + assert store.get(collection="test", key="test") == {"test": "test"} + assert store.delete(collection="test", key="test") + assert store.get(collection="test", key="test") is None def test_put_many_get_get_delete_many_get_many(self, store: BaseStore): - store.put_many(collection='test', keys=['test', 'test_2'], values=[{'test': 'test'}, {'test': 'test_2'}]) - assert store.get_many(collection='test', keys=['test', 'test_2']) == [{'test': 'test'}, {'test': 'test_2'}] - assert store.delete_many(collection='test', keys=['test', 'test_2']) == 2 - assert store.get_many(collection='test', keys=['test', 'test_2']) == [None, None] - + store.put_many(collection="test", keys=["test", "test_2"], values=[{"test": "test"}, {"test": "test_2"}]) + assert store.get_many(collection="test", keys=["test", "test_2"]) == [{"test": "test"}, {"test": "test_2"}] + assert store.delete_many(collection="test", keys=["test", "test_2"]) == 2 + assert store.get_many(collection="test", keys=["test", "test_2"]) == [None, None] def test_put_many_get_many_delete_many_get_many(self, store: BaseStore): - store.put_many(collection='test', keys=['test', 'test_2'], values=[{'test': 'test'}, {'test': 'test_2'}]) - assert store.get_many(collection='test', keys=['test', 'test_2']) == [{'test': 'test'}, {'test': 'test_2'}] - assert store.delete_many(collection='test', keys=['test', 'test_2']) == 2 - assert store.get_many(collection='test', keys=['test', 'test_2']) == [None, None] - + store.put_many(collection="test", keys=["test", "test_2"], values=[{"test": "test"}, {"test": "test_2"}]) + assert store.get_many(collection="test", keys=["test", "test_2"]) == [{"test": 
"test"}, {"test": "test_2"}] + assert store.delete_many(collection="test", keys=["test", "test_2"]) == 2 + assert store.get_many(collection="test", keys=["test", "test_2"]) == [None, None] def test_put_many_tuple_get_many(self, store: BaseStore): - store.put_many(collection='test', keys=['test', 'test_2'], values=({'test': 'test'}, {'test': 'test_2'})) - assert store.get_many(collection='test', keys=['test', 'test_2']) == [{'test': 'test'}, {'test': 'test_2'}] - + store.put_many(collection="test", keys=["test", "test_2"], values=({"test": "test"}, {"test": "test_2"})) + assert store.get_many(collection="test", keys=["test", "test_2"]) == [{"test": "test"}, {"test": "test_2"}] def test_delete(self, store: BaseStore): - assert store.delete(collection='test', key='test') is False - + assert store.delete(collection="test", key="test") is False def test_put_delete_delete(self, store: BaseStore): - store.put(collection='test', key='test', value={'test': 'test'}) - assert store.delete(collection='test', key='test') - assert store.delete(collection='test', key='test') is False - + store.put(collection="test", key="test", value={"test": "test"}) + assert store.delete(collection="test", key="test") + assert store.delete(collection="test", key="test") is False def test_delete_many(self, store: BaseStore): - assert store.delete_many(collection='test', keys=['test', 'test_2']) == 0 - + assert store.delete_many(collection="test", keys=["test", "test_2"]) == 0 def test_put_delete_many(self, store: BaseStore): - store.put(collection='test', key='test', value={'test': 'test'}) - assert store.delete_many(collection='test', keys=['test', 'test_2']) == 1 - + store.put(collection="test", key="test", value={"test": "test"}) + assert store.delete_many(collection="test", keys=["test", "test_2"]) == 1 def test_delete_many_delete_many(self, store: BaseStore): - store.put(collection='test', key='test', value={'test': 'test'}) - assert store.delete_many(collection='test', keys=['test', 'test_2']) == 1 - assert store.delete_many(collection='test', keys=['test', 'test_2']) == 0 - + store.put(collection="test", key="test", value={"test": "test"}) + assert store.delete_many(collection="test", keys=["test", "test_2"]) == 1 + assert store.delete_many(collection="test", keys=["test", "test_2"]) == 0 def test_get_put_get_delete_get(self, store: BaseStore): """Tests that the get, put, delete, and get methods work together to store and retrieve a value from an empty store.""" - - assert store.get(collection='test', key='test') is None - - store.put(collection='test', key='test', value={'test': 'test'}) - - assert store.get(collection='test', key='test') == {'test': 'test'} - - assert store.delete(collection='test', key='test') - - assert store.get(collection='test', key='test') is None - + + assert store.get(collection="test", key="test") is None + + store.put(collection="test", key="test", value={"test": "test"}) + + assert store.get(collection="test", key="test") == {"test": "test"} + + assert store.delete(collection="test", key="test") + + assert store.get(collection="test", key="test") is None def test_get_put_get_put_delete_get(self, store: BaseStore): """Tests that the get, put, get, put, delete, and get methods work together to store and retrieve a value from an empty store.""" - store.put(collection='test', key='test', value={'test': 'test'}) - assert store.get(collection='test', key='test') == {'test': 'test'} - - store.put(collection='test', key='test', value={'test': 'test_2'}) - - assert 
store.get(collection='test', key='test') == {'test': 'test_2'} - assert store.delete(collection='test', key='test') - assert store.get(collection='test', key='test') is None - + store.put(collection="test", key="test", value={"test": "test"}) + assert store.get(collection="test", key="test") == {"test": "test"} + + store.put(collection="test", key="test", value={"test": "test_2"}) + + assert store.get(collection="test", key="test") == {"test": "test_2"} + assert store.delete(collection="test", key="test") + assert store.get(collection="test", key="test") is None def test_put_many_delete_delete_get_many(self, store: BaseStore): - store.put_many(collection='test', keys=['test', 'test_2'], values=[{'test': 'test'}, {'test': 'test_2'}]) - assert store.get_many(collection='test', keys=['test', 'test_2']) == [{'test': 'test'}, {'test': 'test_2'}] - assert store.delete(collection='test', key='test') - assert store.delete(collection='test', key='test_2') - assert store.get_many(collection='test', keys=['test', 'test_2']) == [None, None] - + store.put_many(collection="test", keys=["test", "test_2"], values=[{"test": "test"}, {"test": "test_2"}]) + assert store.get_many(collection="test", keys=["test", "test_2"]) == [{"test": "test"}, {"test": "test_2"}] + assert store.delete(collection="test", key="test") + assert store.delete(collection="test", key="test_2") + assert store.get_many(collection="test", keys=["test", "test_2"]) == [None, None] def test_put_ttl_get_ttl(self, store: BaseStore): """Tests that the put and get ttl methods work together to store and retrieve a ttl from an empty store.""" - store.put(collection='test', key='test', value={'test': 'test'}, ttl=100) - (value, ttl) = store.ttl(collection='test', key='test') - - assert value == {'test': 'test'} + store.put(collection="test", key="test", value={"test": "test"}, ttl=100) + (value, ttl) = store.ttl(collection="test", key="test") + + assert value == {"test": "test"} assert ttl is not None assert ttl == IsFloat(approx=100) - def test_negative_ttl(self, store: BaseStore): """Tests that putting a value with a negative ttl raises InvalidTTLError.""" with pytest.raises(InvalidTTLError): - store.put(collection='test', key='test', value={'test': 'test'}, ttl=-100) - + store.put(collection="test", key="test", value={"test": "test"}, ttl=-100) @pytest.mark.timeout(10) def test_put_expired_get_none(self, store: BaseStore): """Tests that a put call with a short ttl will return None when getting the key after the entry expires.""" - store.put(collection='test_collection', key='test_key', value={'test': 'test'}, ttl=2) - assert store.get(collection='test_collection', key='test_key') is not None + store.put(collection="test_collection", key="test_key", value={"test": "test"}, ttl=2) + assert store.get(collection="test_collection", key="test_key") is not None sleep(seconds=1) - + for _ in range(8): sleep(seconds=0.25) - if store.get(collection='test_collection', key='test_key') is None: + if store.get(collection="test_collection", key="test_key") is None: # pass the test return - - pytest.fail('put_expired_get_none test failed, entry did not expire') - + + pytest.fail("put_expired_get_none test failed, entry did not expire") def test_long_collection_name(self, store: BaseStore): """Tests that a long collection name will not raise an error.""" - store.put(collection='test_collection' * 100, key='test_key', value={'test': 'test'}) - assert store.get(collection='test_collection' * 100, key='test_key') == {'test': 'test'} - + store.put(collection="test_collection" * 100,
key="test_key", value={"test": "test"}) + assert store.get(collection="test_collection" * 100, key="test_key") == {"test": "test"} def test_special_characters_in_collection_name(self, store: BaseStore): """Tests that a special characters in the collection name will not raise an error.""" - store.put(collection='test_collection!@#$%^&*()', key='test_key', value={'test': 'test'}) - assert store.get(collection='test_collection!@#$%^&*()', key='test_key') == {'test': 'test'} - + store.put(collection="test_collection!@#$%^&*()", key="test_key", value={"test": "test"}) + assert store.get(collection="test_collection!@#$%^&*()", key="test_key") == {"test": "test"} def test_long_key_name(self, store: BaseStore): """Tests that a long key name will not raise an error.""" - store.put(collection='test_collection', key='test_key' * 100, value={'test': 'test'}) - assert store.get(collection='test_collection', key='test_key' * 100) == {'test': 'test'} - + store.put(collection="test_collection", key="test_key" * 100, value={"test": "test"}) + assert store.get(collection="test_collection", key="test_key" * 100) == {"test": "test"} def test_special_characters_in_key_name(self, store: BaseStore): """Tests that a special characters in the key name will not raise an error.""" - store.put(collection='test_collection', key='test_key!@#$%^&*()', value={'test': 'test'}) - assert store.get(collection='test_collection', key='test_key!@#$%^&*()') == {'test': 'test'} - + store.put(collection="test_collection", key="test_key!@#$%^&*()", value={"test": "test"}) + assert store.get(collection="test_collection", key="test_key!@#$%^&*()") == {"test": "test"} @pytest.mark.timeout(20) def test_not_unbounded(self, store: BaseStore): """Tests that the store is not unbounded.""" - + for i in range(1000): - value = hashlib.sha256(f'test_{i}'.encode()).hexdigest() - store.put(collection='test_collection', key=f'test_key_{i}', value={'test': value}) - - assert store.get(collection='test_collection', key='test_key_0') is None - assert store.get(collection='test_collection', key='test_key_999') is not None - - - @pytest.mark.skipif(condition=not running_in_event_loop(), reason='Cannot run concurrent operations outside of event loop') + value = hashlib.sha256(f"test_{i}".encode()).hexdigest() + store.put(collection="test_collection", key=f"test_key_{i}", value={"test": value}) + + assert store.get(collection="test_collection", key="test_key_0") is None + assert store.get(collection="test_collection", key="test_key_999") is not None + + @pytest.mark.skipif(condition=not running_in_event_loop(), reason="Cannot run concurrent operations outside of event loop") def test_concurrent_operations(self, store: BaseStore): """Tests that the store can handle concurrent operations.""" - def worker(store: BaseStore, worker_id: int): for i in range(10): - assert store.get(collection='test_collection', key=f'test_{worker_id}_{i}') is None - - store.put(collection='test_collection', key=f'test_{worker_id}_{i}', value={'test': f'test_{i}'}) - assert store.get(collection='test_collection', key=f'test_{worker_id}_{i}') == {'test': f'test_{i}'} - - store.put(collection='test_collection', key=f'test_{worker_id}_{i}', value={'test': f'test_{i}_2'}) - assert store.get(collection='test_collection', key=f'test_{worker_id}_{i}') == {'test': f'test_{i}_2'} - - assert store.delete(collection='test_collection', key=f'test_{worker_id}_{i}') - assert store.get(collection='test_collection', key=f'test_{worker_id}_{i}') is None - + assert 
store.get(collection="test_collection", key=f"test_{worker_id}_{i}") is None + + store.put(collection="test_collection", key=f"test_{worker_id}_{i}", value={"test": f"test_{i}"}) + assert store.get(collection="test_collection", key=f"test_{worker_id}_{i}") == {"test": f"test_{i}"} + + store.put(collection="test_collection", key=f"test_{worker_id}_{i}", value={"test": f"test_{i}_2"}) + assert store.get(collection="test_collection", key=f"test_{worker_id}_{i}") == {"test": f"test_{i}_2"} + + assert store.delete(collection="test_collection", key=f"test_{worker_id}_{i}") + assert store.get(collection="test_collection", key=f"test_{worker_id}_{i}") is None + _ = gather(*[worker(store, worker_id) for worker_id in range(5)]) - @pytest.mark.timeout(15) def test_minimum_put_many_get_many_performance(self, store: BaseStore): """Tests that the store meets minimum performance requirements.""" - keys = [f'test_{i}' for i in range(10)] - values = [{'test': f'test_{i}'} for i in range(10)] - store.put_many(collection='test_collection', keys=keys, values=values) - assert store.get_many(collection='test_collection', keys=keys) == values - + keys = [f"test_{i}" for i in range(10)] + values = [{"test": f"test_{i}"} for i in range(10)] + store.put_many(collection="test_collection", keys=keys, values=values) + assert store.get_many(collection="test_collection", keys=keys) == values @pytest.mark.timeout(15) def test_minimum_put_many_delete_many_performance(self, store: BaseStore): """Tests that the store meets minimum performance requirements.""" - keys = [f'test_{i}' for i in range(10)] - values = [{'test': f'test_{i}'} for i in range(10)] - store.put_many(collection='test_collection', keys=keys, values=values) - assert store.delete_many(collection='test_collection', keys=keys) == 10 + keys = [f"test_{i}" for i in range(10)] + values = [{"test": f"test_{i}"} for i in range(10)] + store.put_many(collection="test_collection", keys=keys, values=values) + assert store.delete_many(collection="test_collection", keys=keys) == 10 class ContextManagerStoreTestMixin: - - @pytest.fixture(params=[True, False], ids=['with_ctx_manager', 'no_ctx_manager'], autouse=True) - def enter_exit_store(self, request: pytest.FixtureRequest, store: BaseContextManagerStore) -> Generator[BaseContextManagerStore, None, None]: + @pytest.fixture(params=[True, False], ids=["with_ctx_manager", "no_ctx_manager"], autouse=True) + def enter_exit_store( + self, request: pytest.FixtureRequest, store: BaseContextManagerStore + ) -> Generator[BaseContextManagerStore, None, None]: context_manager = request.param # pyright: ignore[reportAny] - + if context_manager: with store: yield store diff --git a/key-value/key-value-sync/tests/code_gen/stores/conftest.py b/key-value/key-value-sync/tests/code_gen/stores/conftest.py index 322fdaa7..114e8c57 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/conftest.py +++ b/key-value/key-value-sync/tests/code_gen/stores/conftest.py @@ -21,5 +21,5 @@ def now_plus(seconds: int) -> datetime: return now() + timedelta(seconds=seconds) -def is_around(value: float, delta: float=1) -> bool: +def is_around(value: float, delta: float = 1) -> bool: return value - delta < value < value + delta diff --git a/key-value/key-value-sync/tests/code_gen/stores/disk/test_disk.py b/key-value/key-value-sync/tests/code_gen/stores/disk/test_disk.py index 89abc72a..994738a2 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/disk/test_disk.py +++ b/key-value/key-value-sync/tests/code_gen/stores/disk/test_disk.py @@ -14,16 
+14,14 @@ class TestDiskStore(ContextManagerStoreTestMixin, BaseStoreTests): - - @pytest.fixture(scope='session') + @pytest.fixture(scope="session") def disk_store(self) -> Generator[DiskStore, None, None]: with tempfile.TemporaryDirectory() as temp_dir: yield DiskStore(directory=temp_dir, max_size=TEST_SIZE_LIMIT) - @override @pytest.fixture def store(self, disk_store: DiskStore) -> DiskStore: disk_store._cache.clear() # pyright: ignore[reportPrivateUsage] - + return disk_store diff --git a/key-value/key-value-sync/tests/code_gen/stores/disk/test_multi_disk.py b/key-value/key-value-sync/tests/code_gen/stores/disk/test_multi_disk.py index b3944761..e6341075 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/disk/test_multi_disk.py +++ b/key-value/key-value-sync/tests/code_gen/stores/disk/test_multi_disk.py @@ -15,17 +15,15 @@ class TestMultiDiskStore(ContextManagerStoreTestMixin, BaseStoreTests): - - @pytest.fixture(scope='session') + @pytest.fixture(scope="session") def multi_disk_store(self) -> Generator[MultiDiskStore, None, None]: with tempfile.TemporaryDirectory() as temp_dir: yield MultiDiskStore(base_directory=Path(temp_dir), max_size=TEST_SIZE_LIMIT) - @override @pytest.fixture def store(self, multi_disk_store: MultiDiskStore) -> MultiDiskStore: for collection in multi_disk_store._cache: # pyright: ignore[reportPrivateUsage] multi_disk_store._cache[collection].clear() # pyright: ignore[reportPrivateUsage] - + return multi_disk_store diff --git a/key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py b/key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py index 033ef522..69851a03 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py +++ b/key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py @@ -19,15 +19,15 @@ from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin TEST_SIZE_LIMIT = 1 * 1024 * 1024 # 1MB -ES_HOST = 'localhost' +ES_HOST = "localhost" ES_PORT = 9200 -ES_URL = f'http://{ES_HOST}:{ES_PORT}' +ES_URL = f"http://{ES_HOST}:{ES_PORT}" ES_CONTAINER_PORT = 9200 WAIT_FOR_ELASTICSEARCH_TIMEOUT = 30 - # Released Apr 2025 +# Released Apr 2025 # Released Oct 2025 -ELASTICSEARCH_VERSIONS_TO_TEST = ['9.0.0', '9.2.0'] +ELASTICSEARCH_VERSIONS_TO_TEST = ["9.0.0", "9.2.0"] def get_elasticsearch_client() -> Elasticsearch: @@ -36,13 +36,13 @@ def get_elasticsearch_client() -> Elasticsearch: def ping_elasticsearch() -> bool: es_client: Elasticsearch = get_elasticsearch_client() - + with es_client: return es_client.ping() def cleanup_elasticsearch_indices(elasticsearch_client: Elasticsearch): - indices = elasticsearch_client.options(ignore_status=404).indices.get(index='kv-store-e2e-test-*') + indices = elasticsearch_client.options(ignore_status=404).indices.get(index="kv-store-e2e-test-*") for index in indices: _ = elasticsearch_client.options(ignore_status=404).indices.delete(index=index) @@ -54,14 +54,22 @@ class ElasticsearchFailedToStartError(Exception): def test_managed_entry_document_conversion(): created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) expires_at = created_at + timedelta(seconds=10) - - managed_entry = ManagedEntry(value={'test': 'test'}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(collection='test_collection', key='test_key', managed_entry=managed_entry) - - assert document == snapshot({'collection': 'test_collection', 
'key': 'test_key', 'value': {'string': '{"test": "test"}'}, 'created_at': '2025-01-01T00:00:00+00:00', 'expires_at': '2025-01-01T00:00:10+00:00'}) - + + managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) + document = managed_entry_to_document(collection="test_collection", key="test_key", managed_entry=managed_entry) + + assert document == snapshot( + { + "collection": "test_collection", + "key": "test_key", + "value": {"string": '{"test": "test"}'}, + "created_at": "2025-01-01T00:00:00+00:00", + "expires_at": "2025-01-01T00:00:10+00:00", + } + ) + round_trip_managed_entry = source_to_managed_entry(source=document) - + assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at assert round_trip_managed_entry.ttl == IsFloat(lt=0) @@ -71,14 +79,22 @@ def test_managed_entry_document_conversion(): def test_managed_entry_document_conversion_native_storage(): created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) expires_at = created_at + timedelta(seconds=10) - - managed_entry = ManagedEntry(value={'test': 'test'}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(collection='test_collection', key='test_key', managed_entry=managed_entry, native_storage=True) - - assert document == snapshot({'collection': 'test_collection', 'key': 'test_key', 'value': {'flattened': {'test': 'test'}}, 'created_at': '2025-01-01T00:00:00+00:00', 'expires_at': '2025-01-01T00:00:10+00:00'}) - + + managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) + document = managed_entry_to_document(collection="test_collection", key="test_key", managed_entry=managed_entry, native_storage=True) + + assert document == snapshot( + { + "collection": "test_collection", + "key": "test_key", + "value": {"flattened": {"test": "test"}}, + "created_at": "2025-01-01T00:00:00+00:00", + "expires_at": "2025-01-01T00:00:10+00:00", + } + ) + round_trip_managed_entry = source_to_managed_entry(source=document) - + assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at assert round_trip_managed_entry.ttl == IsFloat(lt=0) @@ -86,133 +102,160 @@ def test_managed_entry_document_conversion_native_storage(): class BaseTestElasticsearchStore(ContextManagerStoreTestMixin, BaseStoreTests): - - @pytest.fixture(autouse=True, scope='session', params=ELASTICSEARCH_VERSIONS_TO_TEST) + @pytest.fixture(autouse=True, scope="session", params=ELASTICSEARCH_VERSIONS_TO_TEST) def setup_elasticsearch(self, request: pytest.FixtureRequest) -> Generator[None, None, None]: version = request.param - es_image = f'docker.elastic.co/elasticsearch/elasticsearch:{version}' - - with docker_container(f'elasticsearch-test-{version}', es_image, {str(ES_CONTAINER_PORT): ES_PORT}, {'discovery.type': 'single-node', 'xpack.security.enabled': 'false'}): + es_image = f"docker.elastic.co/elasticsearch/elasticsearch:{version}" + + with docker_container( + f"elasticsearch-test-{version}", + es_image, + {str(ES_CONTAINER_PORT): ES_PORT}, + {"discovery.type": "single-node", "xpack.security.enabled": "false"}, + ): if not wait_for_true(bool_fn=ping_elasticsearch, tries=WAIT_FOR_ELASTICSEARCH_TIMEOUT, wait_time=2): - msg = f'Elasticsearch {version} failed to start' + msg = f"Elasticsearch {version} failed to start" raise ElasticsearchFailedToStartError(msg) - + yield - @pytest.fixture def es_client(self) -> 
Generator[Elasticsearch, None, None]: with Elasticsearch(hosts=[ES_URL]) as es_client: yield es_client - @pytest.fixture(autouse=True) def cleanup_elasticsearch_indices(self, es_client: Elasticsearch): cleanup_elasticsearch_indices(elasticsearch_client=es_client) yield cleanup_elasticsearch_indices(elasticsearch_client=es_client) - - @pytest.mark.skip(reason='Distributed Caches are unbounded') + @pytest.mark.skip(reason="Distributed Caches are unbounded") @override - def test_not_unbounded(self, store: BaseStore): - ... - + def test_not_unbounded(self, store: BaseStore): ... - @pytest.mark.skip(reason='Skip concurrent tests on distributed caches') + @pytest.mark.skip(reason="Skip concurrent tests on distributed caches") @override - def test_concurrent_operations(self, store: BaseStore): - ... - + def test_concurrent_operations(self, store: BaseStore): ... def test_put_put_two_indices(self, store: ElasticsearchStore, es_client: Elasticsearch): - store.put(collection='test_collection', key='test_key', value={'test': 'test'}) - store.put(collection='test_collection_2', key='test_key', value={'test': 'test'}) - assert store.get(collection='test_collection', key='test_key') == {'test': 'test'} - assert store.get(collection='test_collection_2', key='test_key') == {'test': 'test'} - - indices = es_client.options(ignore_status=404).indices.get(index='kv-store-e2e-test-*') + store.put(collection="test_collection", key="test_key", value={"test": "test"}) + store.put(collection="test_collection_2", key="test_key", value={"test": "test"}) + assert store.get(collection="test_collection", key="test_key") == {"test": "test"} + assert store.get(collection="test_collection_2", key="test_key") == {"test": "test"} + + indices = es_client.options(ignore_status=404).indices.get(index="kv-store-e2e-test-*") assert len(indices.body) == 2 - assert 'kv-store-e2e-test-test_collection' in indices - assert 'kv-store-e2e-test-test_collection_2' in indices + assert "kv-store-e2e-test-test_collection" in indices + assert "kv-store-e2e-test-test_collection_2" in indices -@pytest.mark.skipif(should_skip_docker_tests(), reason='Docker is not running') +@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running") class TestElasticsearchStoreNativeMode(BaseTestElasticsearchStore): """Test Elasticsearch store in native mode (i.e. 
it stores flattened objects)""" - @override @pytest.fixture def store(self) -> ElasticsearchStore: - return ElasticsearchStore(url=ES_URL, index_prefix='kv-store-e2e-test', native_storage=True) - + return ElasticsearchStore(url=ES_URL, index_prefix="kv-store-e2e-test", native_storage=True) def test_value_stored_as_flattened_object(self, store: ElasticsearchStore, es_client: Elasticsearch): """Verify values are stored as flattened objects, not JSON strings""" - store.put(collection='test', key='test_key', value={'name': 'Alice', 'age': 30}) - + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + # Check raw Elasticsearch document using public sanitization methods # Note: We need to access these internal methods for testing the storage format - index_name = store._sanitize_index_name(collection='test') # pyright: ignore[reportPrivateUsage] - doc_id = store._sanitize_document_id(key='test_key') # pyright: ignore[reportPrivateUsage] - + index_name = store._sanitize_index_name(collection="test") # pyright: ignore[reportPrivateUsage] + doc_id = store._sanitize_document_id(key="test_key") # pyright: ignore[reportPrivateUsage] + response = es_client.get(index=index_name, id=doc_id) - assert response.body['_source'] == snapshot({'collection': 'test', 'key': 'test_key', 'value': {'flattened': {'name': 'Alice', 'age': 30}}, 'created_at': IsStr(min_length=20, max_length=40)}) - + assert response.body["_source"] == snapshot( + { + "collection": "test", + "key": "test_key", + "value": {"flattened": {"name": "Alice", "age": 30}}, + "created_at": IsStr(min_length=20, max_length=40), + } + ) + # Test with TTL - store.put(collection='test', key='test_key', value={'name': 'Bob', 'age': 25}, ttl=10) + store.put(collection="test", key="test_key", value={"name": "Bob", "age": 25}, ttl=10) response = es_client.get(index=index_name, id=doc_id) - assert response.body['_source'] == snapshot({'collection': 'test', 'key': 'test_key', 'value': {'flattened': {'name': 'Bob', 'age': 25}}, 'created_at': IsStr(min_length=20, max_length=40), 'expires_at': IsStr(min_length=20, max_length=40)}) - + assert response.body["_source"] == snapshot( + { + "collection": "test", + "key": "test_key", + "value": {"flattened": {"name": "Bob", "age": 25}}, + "created_at": IsStr(min_length=20, max_length=40), + "expires_at": IsStr(min_length=20, max_length=40), + } + ) def test_migration_from_non_native_mode(self, store: ElasticsearchStore, es_client: Elasticsearch): """Verify native mode can read a document with stringified data""" - index_name = store._sanitize_index_name(collection='test') # pyright: ignore[reportPrivateUsage] - doc_id = store._sanitize_document_id(key='legacy_key') # pyright: ignore[reportPrivateUsage] - - es_client.index(index=index_name, id=doc_id, body={'collection': 'test', 'key': 'legacy_key', 'value': {'string': '{"legacy": "data"}'}}) + index_name = store._sanitize_index_name(collection="test") # pyright: ignore[reportPrivateUsage] + doc_id = store._sanitize_document_id(key="legacy_key") # pyright: ignore[reportPrivateUsage] + + es_client.index( + index=index_name, id=doc_id, body={"collection": "test", "key": "legacy_key", "value": {"string": '{"legacy": "data"}'}} + ) es_client.indices.refresh(index=index_name) - - result = store.get(collection='test', key='legacy_key') - assert result == snapshot({'legacy': 'data'}) + + result = store.get(collection="test", key="legacy_key") + assert result == snapshot({"legacy": "data"}) -@pytest.mark.skipif(should_skip_docker_tests(), 
reason='Docker is not running') +@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running") class TestElasticsearchStoreNonNativeMode(BaseTestElasticsearchStore): """Test Elasticsearch store in non-native mode (i.e. it stores stringified JSON values)""" - @override @pytest.fixture def store(self) -> ElasticsearchStore: - return ElasticsearchStore(url=ES_URL, index_prefix='kv-store-e2e-test', native_storage=False) - + return ElasticsearchStore(url=ES_URL, index_prefix="kv-store-e2e-test", native_storage=False) def test_value_stored_as_json_string(self, store: ElasticsearchStore, es_client: Elasticsearch): """Verify values are stored as JSON strings""" - store.put(collection='test', key='test_key', value={'name': 'Alice', 'age': 30}) - - index_name = store._sanitize_index_name(collection='test') # pyright: ignore[reportPrivateUsage] - doc_id = store._sanitize_document_id(key='test_key') # pyright: ignore[reportPrivateUsage] - + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + index_name = store._sanitize_index_name(collection="test") # pyright: ignore[reportPrivateUsage] + doc_id = store._sanitize_document_id(key="test_key") # pyright: ignore[reportPrivateUsage] + response = es_client.get(index=index_name, id=doc_id) - assert response.body['_source'] == snapshot({'collection': 'test', 'key': 'test_key', 'value': {'string': '{"age": 30, "name": "Alice"}'}, 'created_at': IsStr(min_length=20, max_length=40)}) - + assert response.body["_source"] == snapshot( + { + "collection": "test", + "key": "test_key", + "value": {"string": '{"age": 30, "name": "Alice"}'}, + "created_at": IsStr(min_length=20, max_length=40), + } + ) + # Test with TTL - store.put(collection='test', key='test_key', value={'name': 'Bob', 'age': 25}, ttl=10) + store.put(collection="test", key="test_key", value={"name": "Bob", "age": 25}, ttl=10) response = es_client.get(index=index_name, id=doc_id) - assert response.body['_source'] == snapshot({'collection': 'test', 'key': 'test_key', 'value': {'string': '{"age": 25, "name": "Bob"}'}, 'created_at': IsStr(min_length=20, max_length=40), 'expires_at': IsStr(min_length=20, max_length=40)}) - + assert response.body["_source"] == snapshot( + { + "collection": "test", + "key": "test_key", + "value": {"string": '{"age": 25, "name": "Bob"}'}, + "created_at": IsStr(min_length=20, max_length=40), + "expires_at": IsStr(min_length=20, max_length=40), + } + ) def test_migration_from_native_mode(self, store: ElasticsearchStore, es_client: Elasticsearch): """Verify non-native mode can read native mode data""" - index_name = store._sanitize_index_name(collection='test') # pyright: ignore[reportPrivateUsage] - doc_id = store._sanitize_document_id(key='legacy_key') # pyright: ignore[reportPrivateUsage] - - es_client.index(index=index_name, id=doc_id, body={'collection': 'test', 'key': 'legacy_key', 'value': {'flattened': {'name': 'Alice', 'age': 30}}}) - + index_name = store._sanitize_index_name(collection="test") # pyright: ignore[reportPrivateUsage] + doc_id = store._sanitize_document_id(key="legacy_key") # pyright: ignore[reportPrivateUsage] + + es_client.index( + index=index_name, + id=doc_id, + body={"collection": "test", "key": "legacy_key", "value": {"flattened": {"name": "Alice", "age": 30}}}, + ) + es_client.indices.refresh(index=index_name) - - result = store.get(collection='test', key='legacy_key') - assert result == snapshot({'name': 'Alice', 'age': 30}) + + result = store.get(collection="test", key="legacy_key") + assert result 
== snapshot({"name": "Alice", "age": 30}) diff --git a/key-value/key-value-sync/tests/code_gen/stores/keyring/test_keyring.py b/key-value/key-value-sync/tests/code_gen/stores/keyring/test_keyring.py index 1c586f2c..c53c4f2b 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/keyring/test_keyring.py +++ b/key-value/key-value-sync/tests/code_gen/stores/keyring/test_keyring.py @@ -13,28 +13,24 @@ from tests.code_gen.stores.base import BaseStoreTests -@pytest.mark.skipif(condition=detect_on_linux(), reason='KeyringStore is not available on Linux CI') +@pytest.mark.skipif(condition=detect_on_linux(), reason="KeyringStore is not available on Linux CI") class TestKeychainStore(BaseStoreTests): - @override @pytest.fixture def store(self) -> KeyringStore: # Use a test-specific service name to avoid conflicts - store = KeyringStore(service_name='py-key-value-test') - store.delete_many(collection='test', keys=['test', 'test_2']) - store.delete_many(collection='test_collection', keys=['test_key']) - + store = KeyringStore(service_name="py-key-value-test") + store.delete_many(collection="test", keys=["test", "test_2"]) + store.delete_many(collection="test_collection", keys=["test_key"]) + return store - @override - @pytest.mark.skip(reason='We do not test boundedness of keyring stores') - def test_not_unbounded(self, store: BaseStore): - ... - + @pytest.mark.skip(reason="We do not test boundedness of keyring stores") + def test_not_unbounded(self, store: BaseStore): ... @override - @pytest.mark.skipif(condition=detect_on_windows(), reason='Keyrings do not support large values on Windows') + @pytest.mark.skipif(condition=detect_on_windows(), reason="Keyrings do not support large values on Windows") @PositiveCases.parametrize(cases=[LARGE_DATA_CASES]) def test_get_large_put_get(self, store: BaseStore, data: dict[str, Any], json: str, round_trip: dict[str, Any]): super().test_get_large_put_get(store, data, json, round_trip=round_trip) diff --git a/key-value/key-value-sync/tests/code_gen/stores/memory/test_memory.py b/key-value/key-value-sync/tests/code_gen/stores/memory/test_memory.py index 1783124e..631cf255 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/memory/test_memory.py +++ b/key-value/key-value-sync/tests/code_gen/stores/memory/test_memory.py @@ -9,13 +9,11 @@ class TestMemoryStore(BaseStoreTests): - @override @pytest.fixture def store(self) -> MemoryStore: return MemoryStore(max_entries_per_collection=500) - def test_seed(self): - store = MemoryStore(max_entries_per_collection=500, seed={'test_collection': {'test_key': {'obj_key': 'obj_value'}}}) - assert store.get(key='test_key', collection='test_collection') == {'obj_key': 'obj_value'} + store = MemoryStore(max_entries_per_collection=500, seed={"test_collection": {"test_key": {"obj_key": "obj_value"}}}) + assert store.get(key="test_key", collection="test_collection") == {"obj_key": "obj_value"} diff --git a/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py b/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py index 7c7bc119..a9a3a9a3 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py +++ b/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py @@ -22,14 +22,14 @@ from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin # MongoDB test configuration -MONGODB_HOST = 'localhost' +MONGODB_HOST = "localhost" MONGODB_HOST_PORT = 27017 -MONGODB_TEST_DB = 'kv-store-adapter-tests' +MONGODB_TEST_DB = 
"kv-store-adapter-tests" WAIT_FOR_MONGODB_TIMEOUT = 30 - # Older supported version +# Older supported version # Latest stable version -MONGODB_VERSIONS_TO_TEST = ['5.0', '8.0'] +MONGODB_VERSIONS_TO_TEST = ["5.0", "8.0"] def ping_mongodb() -> bool: @@ -38,7 +38,7 @@ def ping_mongodb() -> bool: _ = client.list_database_names() except Exception: return False - + return True @@ -49,14 +49,21 @@ class MongoDBFailedToStartError(Exception): def test_managed_entry_document_conversion_native_mode(): created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) expires_at = created_at + timedelta(seconds=10) - - managed_entry = ManagedEntry(value={'test': 'test'}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(key='test', managed_entry=managed_entry, native_storage=True) - - assert document == snapshot({'key': 'test', 'value': {'object': {'test': 'test'}}, 'created_at': datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), 'expires_at': datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc)}) - + + managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) + document = managed_entry_to_document(key="test", managed_entry=managed_entry, native_storage=True) + + assert document == snapshot( + { + "key": "test", + "value": {"object": {"test": "test"}}, + "created_at": datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), + "expires_at": datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc), + } + ) + round_trip_managed_entry = document_to_managed_entry(document=document) - + assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at assert round_trip_managed_entry.ttl == IsFloat(lt=0) @@ -66,14 +73,21 @@ def test_managed_entry_document_conversion_native_mode(): def test_managed_entry_document_conversion_legacy_mode(): created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) expires_at = created_at + timedelta(seconds=10) - - managed_entry = ManagedEntry(value={'test': 'test'}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(key='test', managed_entry=managed_entry, native_storage=False) - - assert document == snapshot({'key': 'test', 'value': {'string': '{"test": "test"}'}, 'created_at': datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), 'expires_at': datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc)}) - + + managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) + document = managed_entry_to_document(key="test", managed_entry=managed_entry, native_storage=False) + + assert document == snapshot( + { + "key": "test", + "value": {"string": '{"test": "test"}'}, + "created_at": datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), + "expires_at": datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc), + } + ) + round_trip_managed_entry = document_to_managed_entry(document=document) - + assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at assert round_trip_managed_entry.ttl == IsFloat(lt=0) @@ -87,110 +101,114 @@ def clean_mongodb_database(store: MongoDBStore) -> None: class BaseMongoDBStoreTests(ContextManagerStoreTestMixin, BaseStoreTests): """Base class for MongoDB store tests.""" - - @pytest.fixture(autouse=True, scope='session', params=MONGODB_VERSIONS_TO_TEST) + @pytest.fixture(autouse=True, scope="session", params=MONGODB_VERSIONS_TO_TEST) def setup_mongodb(self, request: 
pytest.FixtureRequest) -> Generator[None, None, None]: version = request.param - - with docker_container(f'mongodb-test-{version}', f'mongo:{version}', {str(MONGODB_HOST_PORT): MONGODB_HOST_PORT}): + + with docker_container(f"mongodb-test-{version}", f"mongo:{version}", {str(MONGODB_HOST_PORT): MONGODB_HOST_PORT}): if not wait_for_true(bool_fn=ping_mongodb, tries=WAIT_FOR_MONGODB_TIMEOUT, wait_time=1): - msg = f'MongoDB {version} failed to start' + msg = f"MongoDB {version} failed to start" raise MongoDBFailedToStartError(msg) - + yield - - @pytest.mark.skip(reason='Distributed Caches are unbounded') + @pytest.mark.skip(reason="Distributed Caches are unbounded") @override - def test_not_unbounded(self, store: BaseStore): - ... - + def test_not_unbounded(self, store: BaseStore): ... def test_mongodb_collection_name_sanitization(self, store: MongoDBStore): """Tests that special characters in the collection name will not raise an error.""" - store.put(collection='test_collection!@#$%^&*()', key='test_key', value={'test': 'test'}) - assert store.get(collection='test_collection!@#$%^&*()', key='test_key') == {'test': 'test'} - + store.put(collection="test_collection!@#$%^&*()", key="test_key", value={"test": "test"}) + assert store.get(collection="test_collection!@#$%^&*()", key="test_key") == {"test": "test"} + collections = store.collections() - assert collections == snapshot(['test_collection_-daf4a2ec']) + assert collections == snapshot(["test_collection_-daf4a2ec"]) -@pytest.mark.skipif(should_skip_docker_tests(), reason='Docker is not available') +@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not available") class TestMongoDBStoreNativeMode(BaseMongoDBStoreTests): """Test MongoDBStore with native_storage=True (default).""" - @override @pytest.fixture def store(self, setup_mongodb: None) -> MongoDBStore: - store = MongoDBStore(url=f'mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}', db_name=f'{MONGODB_TEST_DB}-native', native_storage=True) - + store = MongoDBStore(url=f"mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}", db_name=f"{MONGODB_TEST_DB}-native", native_storage=True) + + clean_mongodb_database(store=store) + + return store - def test_value_stored_as_bson_dict(self, store: MongoDBStore): """Verify values are stored as BSON dicts, not JSON strings.""" - store.put(collection='test', key='test_key', value={'name': 'Alice', 'age': 30}) - + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + # Get the raw MongoDB document - store._setup_collection(collection='test') # pyright: ignore[reportPrivateUsage] - sanitized_collection = store._sanitize_collection_name(collection='test') # pyright: ignore[reportPrivateUsage] + store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage] + sanitized_collection = store._sanitize_collection_name(collection="test") # pyright: ignore[reportPrivateUsage] collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage] - doc = collection.find_one({'key': 'test_key'}) - - assert doc == snapshot({'_id': IsInstance(expected_type=ObjectId), 'key': 'test_key', 'created_at': IsDatetime(), 'value': {'object': {'name': 'Alice', 'age': 30}}}) - + doc = collection.find_one({"key": "test_key"}) + + assert doc == snapshot( + { + "_id": IsInstance(expected_type=ObjectId), + "key": "test_key", + "created_at": IsDatetime(), + "value": {"object": {"name": "Alice", "age": 30}}, + } + ) def test_migration_from_legacy_mode(self, store: MongoDBStore): """Verify native mode 
can read legacy JSON string data.""" - store._setup_collection(collection='test') # pyright: ignore[reportPrivateUsage] - sanitized_collection = store._sanitize_collection_name(collection='test') # pyright: ignore[reportPrivateUsage] + store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage] + sanitized_collection = store._sanitize_collection_name(collection="test") # pyright: ignore[reportPrivateUsage] collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage] - - collection.insert_one({'key': 'legacy_key', 'value': {'string': '{"legacy": "data"}'}}) - - result = store.get(collection='test', key='legacy_key') - assert result == {'legacy': 'data'} + + collection.insert_one({"key": "legacy_key", "value": {"string": '{"legacy": "data"}'}}) + + result = store.get(collection="test", key="legacy_key") + assert result == {"legacy": "data"} -@pytest.mark.skipif(should_skip_docker_tests(), reason='Docker is not available') +@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not available") class TestMongoDBStoreNonNativeMode(BaseMongoDBStoreTests): """Test MongoDBStore with native_storage=False (legacy mode) for backward compatibility.""" - @override @pytest.fixture def store(self, setup_mongodb: None) -> MongoDBStore: - store = MongoDBStore(url=f'mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}', db_name=MONGODB_TEST_DB, native_storage=False) - + store = MongoDBStore(url=f"mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}", db_name=MONGODB_TEST_DB, native_storage=False) + clean_mongodb_database(store=store) - + return store - def test_value_stored_as_json(self, store: MongoDBStore): """Verify values are stored as JSON strings.""" - store.put(collection='test', key='test_key', value={'name': 'Alice', 'age': 30}) - + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + # Get the raw MongoDB document - store._setup_collection(collection='test') # pyright: ignore[reportPrivateUsage] - sanitized_collection = store._sanitize_collection_name(collection='test') # pyright: ignore[reportPrivateUsage] + store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage] + sanitized_collection = store._sanitize_collection_name(collection="test") # pyright: ignore[reportPrivateUsage] collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage] - doc = collection.find_one({'key': 'test_key'}) - - assert doc == snapshot({'_id': IsInstance(expected_type=ObjectId), 'key': 'test_key', 'created_at': IsDatetime(), 'value': {'string': '{"age": 30, "name": "Alice"}'}}) - + doc = collection.find_one({"key": "test_key"}) + + assert doc == snapshot( + { + "_id": IsInstance(expected_type=ObjectId), + "key": "test_key", + "created_at": IsDatetime(), + "value": {"string": '{"age": 30, "name": "Alice"}'}, + } + ) def test_migration_from_native_mode(self, store: MongoDBStore): """Verify non-native mode can read native mode data.""" - store._setup_collection(collection='test') # pyright: ignore[reportPrivateUsage] - sanitized_collection = store._sanitize_collection_name(collection='test') # pyright: ignore[reportPrivateUsage] + store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage] + sanitized_collection = store._sanitize_collection_name(collection="test") # pyright: ignore[reportPrivateUsage] collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage] - - collection.insert_one({'key': 'legacy_key', 'value': 
{'object': {'name': 'Alice', 'age': 30}}}) - - result = store.get(collection='test', key='legacy_key') - assert result == {'name': 'Alice', 'age': 30} + + collection.insert_one({"key": "legacy_key", "value": {"object": {"name": "Alice", "age": 30}}}) + + result = store.get(collection="test", key="legacy_key") + assert result == {"name": "Alice", "age": 30} diff --git a/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py b/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py index a7dd100e..4620fba8 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py +++ b/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py @@ -19,26 +19,28 @@ from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin # Redis test configuration -REDIS_HOST = 'localhost' +REDIS_HOST = "localhost" REDIS_PORT = 6379 REDIS_DB = 15 # Use a separate database for tests WAIT_FOR_REDIS_TIMEOUT = 30 -REDIS_VERSIONS_TO_TEST = ['4.0.0', '7.0.0'] +REDIS_VERSIONS_TO_TEST = ["4.0.0", "7.0.0"] def test_managed_entry_document_conversion(): created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) expires_at = created_at + timedelta(seconds=10) - - managed_entry = ManagedEntry(value={'test': 'test'}, created_at=created_at, expires_at=expires_at) + + managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) document = managed_entry_to_json(managed_entry=managed_entry) - - assert document == snapshot('{"created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00", "value": {"test": "test"}}') - + + assert document == snapshot( + '{"created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00", "value": {"test": "test"}}' + ) + round_trip_managed_entry = json_to_managed_entry(json_str=document) - + assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at assert round_trip_managed_entry.ttl == IsFloat(lt=0) @@ -57,20 +59,18 @@ class RedisFailedToStartError(Exception): pass -@pytest.mark.skipif(should_skip_docker_tests(), reason='Docker is not running') +@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running") class TestRedisStore(ContextManagerStoreTestMixin, BaseStoreTests): - - @pytest.fixture(autouse=True, scope='session', params=REDIS_VERSIONS_TO_TEST) + @pytest.fixture(autouse=True, scope="session", params=REDIS_VERSIONS_TO_TEST) def setup_redis(self, request: pytest.FixtureRequest) -> Generator[None, None, None]: version = request.param - - with docker_container('redis-test', f'redis:{version}', {'6379': REDIS_PORT}): + + with docker_container("redis-test", f"redis:{version}", {"6379": REDIS_PORT}): if not wait_for_true(bool_fn=ping_redis, tries=30, wait_time=1): - msg = 'Redis failed to start' + msg = "Redis failed to start" raise RedisFailedToStartError(msg) - + yield - @override @pytest.fixture @@ -80,32 +80,28 @@ def store(self, setup_redis: RedisStore) -> RedisStore: redis_store = RedisStore(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB) _ = redis_store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] return redis_store - def test_redis_url_connection(self): """Test Redis store creation with URL.""" - redis_url = f'redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}' + redis_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}" store = RedisStore(url=redis_url) _ = store._client.flushdb() # 
pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] - store.put(collection='test', key='url_test', value={'test': 'value'}) - result = store.get(collection='test', key='url_test') - assert result == {'test': 'value'} - + store.put(collection="test", key="url_test", value={"test": "value"}) + result = store.get(collection="test", key="url_test") + assert result == {"test": "value"} def test_redis_client_connection(self): """Test Redis store creation with existing client.""" from redis import Redis - + client = Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) store = RedisStore(client=client) - + _ = store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] - store.put(collection='test', key='client_test', value={'test': 'value'}) - result = store.get(collection='test', key='client_test') - assert result == {'test': 'value'} - + store.put(collection="test", key="client_test", value={"test": "value"}) + result = store.get(collection="test", key="client_test") + assert result == {"test": "value"} - @pytest.mark.skip(reason='Distributed Caches are unbounded') + @pytest.mark.skip(reason="Distributed Caches are unbounded") @override - def test_not_unbounded(self, store: BaseStore): - ... + def test_not_unbounded(self, store: BaseStore): ... diff --git a/key-value/key-value-sync/tests/code_gen/stores/rocksdb/test_rocksdb.py b/key-value/key-value-sync/tests/code_gen/stores/rocksdb/test_rocksdb.py index 29ab08bd..2e04cd08 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/rocksdb/test_rocksdb.py +++ b/key-value/key-value-sync/tests/code_gen/stores/rocksdb/test_rocksdb.py @@ -14,56 +14,51 @@ class TestRocksDBStore(ContextManagerStoreTestMixin, BaseStoreTests): - @override @pytest.fixture def store(self) -> Generator[RocksDBStore, None, None]: """Create a RocksDB store for testing.""" # Create a temporary directory for the RocksDB database with TemporaryDirectory() as temp_dir: - db_path = Path(temp_dir) / 'test_db' + db_path = Path(temp_dir) / "test_db" rocksdb_store = RocksDBStore(path=db_path) yield rocksdb_store - def test_rocksdb_path_connection(self): """Test RocksDB store creation with path.""" temp_dir = TemporaryDirectory() - db_path = Path(temp_dir.name) / 'path_test_db' - + db_path = Path(temp_dir.name) / "path_test_db" + store = RocksDBStore(path=db_path) - - store.put(collection='test', key='path_test', value={'test': 'value'}) - result = store.get(collection='test', key='path_test') - assert result == {'test': 'value'} - + + store.put(collection="test", key="path_test", value={"test": "value"}) + result = store.get(collection="test", key="path_test") + assert result == {"test": "value"} + store.close() temp_dir.cleanup() - def test_rocksdb_db_connection(self): """Test RocksDB store creation with existing DB instance.""" from rocksdict import Options, Rdict - + temp_dir = TemporaryDirectory() - db_path = Path(temp_dir.name) / 'db_test_db' + db_path = Path(temp_dir.name) / "db_test_db" db_path.mkdir(parents=True, exist_ok=True) - + opts = Options() opts.create_if_missing(True) db = Rdict(str(db_path), options=opts) - + store = RocksDBStore(db=db) - - store.put(collection='test', key='db_test', value={'test': 'value'}) - result = store.get(collection='test', key='db_test') - assert result == {'test': 'value'} - + + store.put(collection="test", key="db_test", value={"test": "value"}) + result = store.get(collection="test", key="db_test") + assert result == {"test": "value"} + store.close() 
temp_dir.cleanup() - - @pytest.mark.skip(reason='Local disk stores are unbounded') + @pytest.mark.skip(reason="Local disk stores are unbounded") @override - def test_not_unbounded(self, store: BaseStore): - ... + def test_not_unbounded(self, store: BaseStore): ... diff --git a/key-value/key-value-sync/tests/code_gen/stores/simple/test_store.py b/key-value/key-value-sync/tests/code_gen/stores/simple/test_store.py index 724cb462..1ee92614 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/simple/test_store.py +++ b/key-value/key-value-sync/tests/code_gen/stores/simple/test_store.py @@ -9,7 +9,6 @@ class TestSimpleStore(BaseStoreTests): - @override @pytest.fixture def store(self) -> SimpleStore: diff --git a/key-value/key-value-sync/tests/code_gen/stores/valkey/test_valkey.py b/key-value/key-value-sync/tests/code_gen/stores/valkey/test_valkey.py index 720d0293..69349194 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/valkey/test_valkey.py +++ b/key-value/key-value-sync/tests/code_gen/stores/valkey/test_valkey.py @@ -13,33 +13,33 @@ from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin # Valkey test configuration -VALKEY_HOST = 'localhost' +VALKEY_HOST = "localhost" VALKEY_PORT = 6380 # normally 6379, avoid clashing with Redis tests VALKEY_DB = 15 VALKEY_CONTAINER_PORT = 6379 WAIT_FOR_VALKEY_TIMEOUT = 30 - # Released Apr 2024 +# Released Apr 2024 # Released Sep 2024 # Released Oct 2025 -VALKEY_VERSIONS_TO_TEST = ['7.2.5', '8.0.0', '9.0.0'] +VALKEY_VERSIONS_TO_TEST = ["7.2.5", "8.0.0", "9.0.0"] class ValkeyFailedToStartError(Exception): pass -@pytest.mark.skipif(should_skip_docker_tests(), reason='Docker is not running') -@pytest.mark.skipif(detect_on_windows(), reason='Valkey is not supported on Windows') +@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running") +@pytest.mark.skipif(detect_on_windows(), reason="Valkey is not supported on Windows") class TestValkeyStore(ContextManagerStoreTestMixin, BaseStoreTests): - def get_valkey_client(self): - from glide_sync.glide_client import GlideClient from glide_sync.config import GlideClientConfiguration, NodeAddress - - client_config: GlideClientConfiguration = GlideClientConfiguration(addresses=[NodeAddress(host=VALKEY_HOST, port=VALKEY_PORT)], database_id=VALKEY_DB) + from glide_sync.glide_client import GlideClient + + client_config: GlideClientConfiguration = GlideClientConfiguration( + addresses=[NodeAddress(host=VALKEY_HOST, port=VALKEY_PORT)], database_id=VALKEY_DB + ) return GlideClient.create(config=client_config) - def ping_valkey(self) -> bool: client = None @@ -54,35 +54,31 @@ def ping_valkey(self) -> bool: if client is not None: with contextlib.suppress(Exception): client.close() - - @pytest.fixture(scope='session', params=VALKEY_VERSIONS_TO_TEST) + @pytest.fixture(scope="session", params=VALKEY_VERSIONS_TO_TEST) def setup_valkey(self, request: pytest.FixtureRequest) -> Generator[None, None, None]: version = request.param - - with docker_container(f'valkey-test-{version}', f'valkey/valkey:{version}', {str(VALKEY_CONTAINER_PORT): VALKEY_PORT}): + + with docker_container(f"valkey-test-{version}", f"valkey/valkey:{version}", {str(VALKEY_CONTAINER_PORT): VALKEY_PORT}): if not wait_for_true(bool_fn=self.ping_valkey, tries=WAIT_FOR_VALKEY_TIMEOUT, wait_time=1): - msg = f'Valkey {version} failed to start' + msg = f"Valkey {version} failed to start" raise ValkeyFailedToStartError(msg) - + yield - @override @pytest.fixture def store(self, setup_valkey: None): from 
key_value.sync.code_gen.stores.valkey import ValkeyStore - + store: ValkeyStore = ValkeyStore(host=VALKEY_HOST, port=VALKEY_PORT, db=VALKEY_DB) - + # This is a synchronous client client = self.get_valkey_client() _ = client.flushdb() - + return store - - @pytest.mark.skip(reason='Distributed Caches are unbounded') + @pytest.mark.skip(reason="Distributed Caches are unbounded") @override - def test_not_unbounded(self, store: BaseStore): - ... + def test_not_unbounded(self, store: BaseStore): ... diff --git a/key-value/key-value-sync/tests/code_gen/stores/vault/test_vault.py b/key-value/key-value-sync/tests/code_gen/stores/vault/test_vault.py index 7d145ea6..c7a5fd74 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/vault/test_vault.py +++ b/key-value/key-value-sync/tests/code_gen/stores/vault/test_vault.py @@ -12,30 +12,28 @@ from tests.code_gen.stores.base import BaseStoreTests # Vault test configuration -VAULT_HOST = 'localhost' +VAULT_HOST = "localhost" VAULT_PORT = 8200 -VAULT_TOKEN = 'dev-root-token' # noqa: S105 -VAULT_MOUNT_POINT = 'secret' +VAULT_TOKEN = "dev-root-token" # noqa: S105 +VAULT_MOUNT_POINT = "secret" VAULT_CONTAINER_PORT = 8200 WAIT_FOR_VAULT_TIMEOUT = 30 - # Released Oct 2022 +# Released Oct 2022 # Released Oct 2025 -VAULT_VERSIONS_TO_TEST = ['1.12.0', '1.21.0'] +VAULT_VERSIONS_TO_TEST = ["1.12.0", "1.21.0"] class VaultFailedToStartError(Exception): pass -@pytest.mark.skipif(should_skip_docker_tests(), reason='Docker is not running') +@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running") class TestVaultStore(BaseStoreTests): - def get_vault_client(self): import hvac - - return hvac.Client(url=f'http://{VAULT_HOST}:{VAULT_PORT}', token=VAULT_TOKEN) - + + return hvac.Client(url=f"http://{VAULT_HOST}:{VAULT_PORT}", token=VAULT_TOKEN) def ping_vault(self) -> bool: try: @@ -43,44 +41,45 @@ def ping_vault(self) -> bool: return client.sys.is_initialized() # pyright: ignore[reportUnknownMemberType,reportUnknownReturnType,reportUnknownVariableType] except Exception: return False - - @pytest.fixture(scope='session', params=VAULT_VERSIONS_TO_TEST) + @pytest.fixture(scope="session", params=VAULT_VERSIONS_TO_TEST) def setup_vault(self, request: pytest.FixtureRequest) -> Generator[None, None, None]: version = request.param - - with docker_container(f'vault-test-{version}', f'hashicorp/vault:{version}', {str(VAULT_CONTAINER_PORT): VAULT_PORT}, environment={'VAULT_DEV_ROOT_TOKEN_ID': VAULT_TOKEN, 'VAULT_DEV_LISTEN_ADDRESS': '0.0.0.0:8200'}): + + with docker_container( + f"vault-test-{version}", + f"hashicorp/vault:{version}", + {str(VAULT_CONTAINER_PORT): VAULT_PORT}, + environment={"VAULT_DEV_ROOT_TOKEN_ID": VAULT_TOKEN, "VAULT_DEV_LISTEN_ADDRESS": "0.0.0.0:8200"}, + ): if not wait_for_true(bool_fn=self.ping_vault, tries=WAIT_FOR_VAULT_TIMEOUT, wait_time=1): - msg = f'Vault {version} failed to start' + msg = f"Vault {version} failed to start" raise VaultFailedToStartError(msg) - + yield - @override @pytest.fixture def store(self, setup_vault: None): from key_value.sync.code_gen.stores.vault import VaultStore - - store: VaultStore = VaultStore(url=f'http://{VAULT_HOST}:{VAULT_PORT}', token=VAULT_TOKEN, mount_point=VAULT_MOUNT_POINT) - + + store: VaultStore = VaultStore(url=f"http://{VAULT_HOST}:{VAULT_PORT}", token=VAULT_TOKEN, mount_point=VAULT_MOUNT_POINT) + + # Clean up any existing data - best effort, ignore errors client = self.get_vault_client() try: # List all secrets and delete them - secrets_list = client.secrets.kv.v2.list_secrets(path='', 
mount_point=VAULT_MOUNT_POINT) # pyright: ignore[reportUnknownMemberType,reportUnknownReturnType,reportUnknownVariableType] - if secrets_list and 'data' in secrets_list and ('keys' in secrets_list['data']): - for key in secrets_list['data']['keys']: # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] + secrets_list = client.secrets.kv.v2.list_secrets(path="", mount_point=VAULT_MOUNT_POINT) # pyright: ignore[reportUnknownMemberType,reportUnknownReturnType,reportUnknownVariableType] + if secrets_list and "data" in secrets_list and ("keys" in secrets_list["data"]): + for key in secrets_list["data"]["keys"]: # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] # Best effort cleanup - ignore individual deletion failures - client.secrets.kv.v2.delete_metadata_and_all_versions(path=key.rstrip('/'), mount_point=VAULT_MOUNT_POINT) # pyright: ignore[reportUnknownMemberType,reportUnknownReturnType,reportUnknownVariableType] + client.secrets.kv.v2.delete_metadata_and_all_versions(path=key.rstrip("/"), mount_point=VAULT_MOUNT_POINT) # pyright: ignore[reportUnknownMemberType,reportUnknownReturnType,reportUnknownVariableType] except Exception: # noqa: S110 # Cleanup is best-effort, ignore all errors pass - + return store - - @pytest.mark.skip(reason='Distributed Caches are unbounded') + @pytest.mark.skip(reason="Distributed Caches are unbounded") @override - def test_not_unbounded(self, store: BaseStore): - ... + def test_not_unbounded(self, store: BaseStore): ... diff --git a/key-value/key-value-sync/tests/code_gen/stores/windows_registry/test_windows_registry.py b/key-value/key-value-sync/tests/code_gen/stores/windows_registry/test_windows_registry.py index f90e948c..65105b16 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/windows_registry/test_windows_registry.py +++ b/key-value/key-value-sync/tests/code_gen/stores/windows_registry/test_windows_registry.py @@ -13,31 +13,27 @@ if TYPE_CHECKING: from key_value.sync.code_gen.stores.windows_registry.store import WindowsRegistryStore -TEST_REGISTRY_PATH = 'software\\py-key-value-test' +TEST_REGISTRY_PATH = "software\\py-key-value-test" -@pytest.mark.skipif(condition=not detect_on_windows(), reason='WindowsRegistryStore is only available on Windows') +@pytest.mark.skipif(condition=not detect_on_windows(), reason="WindowsRegistryStore is only available on Windows") class TestWindowsRegistryStore(BaseStoreTests): - def cleanup(self): from winreg import HKEY_CURRENT_USER - + from key_value.sync.code_gen.stores.windows_registry.utils import delete_sub_keys - + delete_sub_keys(hive=HKEY_CURRENT_USER, sub_key=TEST_REGISTRY_PATH) - @override @pytest.fixture - def store(self) -> 'WindowsRegistryStore': + def store(self) -> "WindowsRegistryStore": from key_value.sync.code_gen.stores.windows_registry.store import WindowsRegistryStore - + self.cleanup() - - return WindowsRegistryStore(registry_path=TEST_REGISTRY_PATH, hive='HKEY_CURRENT_USER') - + + return WindowsRegistryStore(registry_path=TEST_REGISTRY_PATH, hive="HKEY_CURRENT_USER") @override - @pytest.mark.skip(reason='We do not test boundedness of registry stores') - def test_not_unbounded(self, store: BaseStore): - ... + @pytest.mark.skip(reason="We do not test boundedness of registry stores") + def test_not_unbounded(self, store: BaseStore): ... 
diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_compression.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_compression.py index 72efb239..1def5384 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_compression.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_compression.py @@ -10,118 +10,114 @@ class TestCompressionWrapper(BaseStoreTests): - @override @pytest.fixture def store(self, memory_store: MemoryStore) -> CompressionWrapper: # Set min_size to 0 so all values get compressed for testing return CompressionWrapper(key_value=memory_store, min_size_to_compress=0) - def test_compression_small_value_not_compressed(self, memory_store: MemoryStore): # With default min_size (1024), small values shouldn't be compressed compression_store = CompressionWrapper(key_value=memory_store, min_size_to_compress=1024) - - small_value = {'test': 'value'} - compression_store.put(collection='test', key='test', value=small_value) - + + small_value = {"test": "value"} + compression_store.put(collection="test", key="test", value=small_value) + # Check the underlying store - should NOT be compressed - raw_value = memory_store.get(collection='test', key='test') + raw_value = memory_store.get(collection="test", key="test") assert raw_value is not None assert raw_value == small_value - assert '__compressed_data__' not in raw_value - + assert "__compressed_data__" not in raw_value + # Retrieve through wrapper - result = compression_store.get(collection='test', key='test') + result = compression_store.get(collection="test", key="test") assert result == small_value - def test_compression_large_value_compressed(self, memory_store: MemoryStore): compression_store = CompressionWrapper(key_value=memory_store, min_size_to_compress=100) - + # Create a large value - large_value = {'data': 'x' * 1000, 'more_data': 'y' * 1000} - compression_store.put(collection='test', key='test', value=large_value) - + large_value = {"data": "x" * 1000, "more_data": "y" * 1000} + compression_store.put(collection="test", key="test", value=large_value) + # Check the underlying store - should be compressed - raw_value = memory_store.get(collection='test', key='test') + raw_value = memory_store.get(collection="test", key="test") assert raw_value is not None - assert '__compressed_data__' in raw_value - assert '__compression_version__' in raw_value - assert isinstance(raw_value['__compressed_data__'], str) - + assert "__compressed_data__" in raw_value + assert "__compression_version__" in raw_value + assert isinstance(raw_value["__compressed_data__"], str) + # Retrieve through wrapper - should decompress automatically - result = compression_store.get(collection='test', key='test') + result = compression_store.get(collection="test", key="test") assert result == large_value - def test_compression_many_operations(self, memory_store: MemoryStore): compression_store = CompressionWrapper(key_value=memory_store, min_size_to_compress=0) - - keys = ['k1', 'k2', 'k3'] - values = [{'data': 'value1'}, {'data': 'value2'}, {'data': 'value3'}] - - compression_store.put_many(collection='test', keys=keys, values=values) - + + keys = ["k1", "k2", "k3"] + values = [{"data": "value1"}, {"data": "value2"}, {"data": "value3"}] + + compression_store.put_many(collection="test", keys=keys, values=values) + # Check underlying store - all should be compressed for key in keys: - raw_value = memory_store.get(collection='test', key=key) + raw_value = memory_store.get(collection="test", 
key=key) assert raw_value is not None - assert '__compressed_data__' in raw_value - + assert "__compressed_data__" in raw_value + # Retrieve through wrapper - results = compression_store.get_many(collection='test', keys=keys) + results = compression_store.get_many(collection="test", keys=keys) assert results == values - def test_compression_already_compressed_not_recompressed(self, memory_store: MemoryStore): compression_store = CompressionWrapper(key_value=memory_store, min_size_to_compress=0) - + # Manually create a compressed value - compressed_value = {'__compressed_data__': 'H4sIAAAAAAAAA6tWKkktLlGyUlAqS8wpTtVRKi1OLUpVslIqLU4tUqoFAJRxMHkfAAAA', '__compression_version__': 1, '__compression_algorithm__': 'gzip'} - + compressed_value = { + "__compressed_data__": "H4sIAAAAAAAAA6tWKkktLlGyUlAqS8wpTtVRKi1OLUpVslIqLU4tUqoFAJRxMHkfAAAA", + "__compression_version__": 1, + "__compression_algorithm__": "gzip", + } + # Should not try to compress again result = compression_store._compress_value(value=compressed_value) # pyright: ignore[reportPrivateUsage] assert result == compressed_value - def test_decompression_handles_uncompressed_data(self, memory_store: MemoryStore): compression_store = CompressionWrapper(key_value=memory_store, min_size_to_compress=0) - + # Store uncompressed data directly in underlying store - uncompressed_value = {'test': 'value'} - memory_store.put(collection='test', key='test', value=uncompressed_value) - + uncompressed_value = {"test": "value"} + memory_store.put(collection="test", key="test", value=uncompressed_value) + # Should return as-is when retrieved through compression wrapper - result = compression_store.get(collection='test', key='test') + result = compression_store.get(collection="test", key="test") assert result == uncompressed_value - def test_decompression_handles_corrupted_data(self, memory_store: MemoryStore): compression_store = CompressionWrapper(key_value=memory_store, min_size_to_compress=0) - + # Store corrupted compressed data - corrupted_value = {'__compressed_data__': 'invalid-base64-data!!!', '__compression_version__': 1} - memory_store.put(collection='test', key='test', value=corrupted_value) - + corrupted_value = {"__compressed_data__": "invalid-base64-data!!!", "__compression_version__": 1} + memory_store.put(collection="test", key="test", value=corrupted_value) + # Should return the corrupted value as-is rather than crashing - result = compression_store.get(collection='test', key='test') + result = compression_store.get(collection="test", key="test") assert result == corrupted_value - def test_compression_size_reduction(self, memory_store: MemoryStore): compression_store = CompressionWrapper(key_value=memory_store, min_size_to_compress=0) - + # Create a highly compressible value (repeated data) - large_value = {'data': 'x' * 10000} - - compression_store.put(collection='test', key='test', value=large_value) - + large_value = {"data": "x" * 10000} + + compression_store.put(collection="test", key="test", value=large_value) + # Check the compressed size - raw_value = memory_store.get(collection='test', key='test') + raw_value = memory_store.get(collection="test", key="test") assert raw_value is not None - compressed_data = raw_value['__compressed_data__'] - + compressed_data = raw_value["__compressed_data__"] + # Compressed data should be significantly smaller than original # Original is ~10KB, compressed should be much smaller due to repetition assert len(compressed_data) < 1000 # Should be less than 1KB diff --git 
a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_default_value.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_default_value.py index e4ee1018..40323f9f 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_default_value.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_default_value.py @@ -10,107 +10,85 @@ from key_value.sync.code_gen.wrappers.default_value import DefaultValueWrapper from tests.code_gen.stores.base import BaseStoreTests -TEST_KEY_1 = 'test_key_1' -TEST_KEY_2 = 'test_key_2' -TEST_COLLECTION = 'test_collection' -DEFAULT_VALUE = {'obj_key': 'obj_value'} +TEST_KEY_1 = "test_key_1" +TEST_KEY_2 = "test_key_2" +TEST_COLLECTION = "test_collection" +DEFAULT_VALUE = {"obj_key": "obj_value"} DEFAULT_TTL = 100 class TestDefaultValueWrapper(BaseStoreTests): - @override @pytest.fixture def store(self, memory_store: MemoryStore) -> DefaultValueWrapper: return DefaultValueWrapper(key_value=memory_store, default_value=DEFAULT_VALUE, default_ttl=DEFAULT_TTL) - def test_default_value(self, store: BaseStore): assert store.get(collection=TEST_COLLECTION, key=TEST_KEY_1) == DEFAULT_VALUE assert store.ttl(collection=TEST_COLLECTION, key=TEST_KEY_1) == (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)) assert store.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [DEFAULT_VALUE, DEFAULT_VALUE] - assert store.ttl_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [(DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)), (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL))] - - store.put(collection=TEST_COLLECTION, key=TEST_KEY_2, value={'key_2': 'value_2'}, ttl=200) - assert store.get(collection=TEST_COLLECTION, key=TEST_KEY_2) == {'key_2': 'value_2'} - assert store.ttl(collection=TEST_COLLECTION, key=TEST_KEY_2) == ({'key_2': 'value_2'}, IsFloat(approx=200)) - assert store.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [DEFAULT_VALUE, {'key_2': 'value_2'}] - assert store.ttl_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [(DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)), ({'key_2': 'value_2'}, IsFloat(approx=200))] - + assert store.ttl_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [ + (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)), + (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)), + ] + + store.put(collection=TEST_COLLECTION, key=TEST_KEY_2, value={"key_2": "value_2"}, ttl=200) + assert store.get(collection=TEST_COLLECTION, key=TEST_KEY_2) == {"key_2": "value_2"} + assert store.ttl(collection=TEST_COLLECTION, key=TEST_KEY_2) == ({"key_2": "value_2"}, IsFloat(approx=200)) + assert store.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [DEFAULT_VALUE, {"key_2": "value_2"}] + assert store.ttl_many(collection=TEST_COLLECTION, keys=[TEST_KEY_1, TEST_KEY_2]) == [ + (DEFAULT_VALUE, IsFloat(approx=DEFAULT_TTL)), + ({"key_2": "value_2"}, IsFloat(approx=200)), + ] @override @pytest.mark.skip - def test_empty_get(self, store: BaseStore): - ... - + def test_empty_get(self, store: BaseStore): ... @override @pytest.mark.skip - def test_put_put_get_many_missing_one(self, store: BaseStore): - ... - + def test_put_put_get_many_missing_one(self, store: BaseStore): ... @override @pytest.mark.skip - def test_empty_ttl(self, store: BaseStore): - ... - + def test_empty_ttl(self, store: BaseStore): ... @override @pytest.mark.skip - def test_get_put_get(self, store: BaseStore): - ... - + def test_get_put_get(self, store: BaseStore): ... 
@override @pytest.mark.skip - def test_get_put_get_delete_get(self, store: BaseStore): - ... - + def test_get_put_get_delete_get(self, store: BaseStore): ... @override @pytest.mark.skip - def test_put_get_delete_get(self, store: BaseStore): - ... - + def test_put_get_delete_get(self, store: BaseStore): ... @override @pytest.mark.skip - def test_put_many_get_get_delete_many_get_many(self, store: BaseStore): - ... - + def test_put_many_get_get_delete_many_get_many(self, store: BaseStore): ... @override @pytest.mark.skip - def test_put_many_get_many_delete_many_get_many(self, store: BaseStore): - ... - + def test_put_many_get_many_delete_many_get_many(self, store: BaseStore): ... @override @pytest.mark.skip - def test_get_put_get_put_delete_get(self, store: BaseStore): - ... - + def test_get_put_get_put_delete_get(self, store: BaseStore): ... @override @pytest.mark.skip - def test_put_many_delete_delete_get_many(self, store: BaseStore): - ... - + def test_put_many_delete_delete_get_many(self, store: BaseStore): ... @override @pytest.mark.skip - def test_put_expired_get_none(self, store: BaseStore): - ... - + def test_put_expired_get_none(self, store: BaseStore): ... @override @pytest.mark.skip - def test_not_unbounded(self, store: BaseStore): - ... - + def test_not_unbounded(self, store: BaseStore): ... @override @pytest.mark.skip - def test_concurrent_operations(self, store: BaseStore): - ... + def test_concurrent_operations(self, store: BaseStore): ... diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_encryption.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_encryption.py index a64d2500..e20cb7c8 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_encryption.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_encryption.py @@ -20,160 +20,150 @@ def fernet() -> Fernet: class TestFernetEncryptionWrapper(BaseStoreTests): - @override @pytest.fixture def store(self, memory_store: MemoryStore, fernet: Fernet) -> FernetEncryptionWrapper: return FernetEncryptionWrapper(key_value=memory_store, fernet=fernet) - def test_encryption_encrypts_value(self, store: FernetEncryptionWrapper, memory_store: MemoryStore): """Test that values are actually encrypted in the underlying store.""" - original_value = {'test': 'value', 'number': 123} - store.put(collection='test', key='test', value=original_value) - + original_value = {"test": "value", "number": 123} + store.put(collection="test", key="test", value=original_value) + # Check the underlying store - should be encrypted - raw_value = memory_store.get(collection='test', key='test') + raw_value = memory_store.get(collection="test", key="test") assert raw_value is not None - assert '__encrypted_data__' in raw_value - assert '__encryption_version__' in raw_value - assert isinstance(raw_value['__encrypted_data__'], str) - + assert "__encrypted_data__" in raw_value + assert "__encryption_version__" in raw_value + assert isinstance(raw_value["__encrypted_data__"], str) + # The encrypted data should not contain the original value - assert 'test' not in str(raw_value) - assert 'value' not in str(raw_value) - + assert "test" not in str(raw_value) + assert "value" not in str(raw_value) + # Retrieve through wrapper - should decrypt automatically - result = store.get(collection='test', key='test') + result = store.get(collection="test", key="test") assert result == original_value - def test_encryption_with_wrong_encryption_version(self, store: FernetEncryptionWrapper): """Test that 
encryption fails with the wrong encryption version.""" store.encryption_version = 2 - original_value = {'test': 'value'} - store.put(collection='test', key='test', value=original_value) - - assert store.get(collection='test', key='test') is not None + original_value = {"test": "value"} + store.put(collection="test", key="test", value=original_value) + + assert store.get(collection="test", key="test") is not None store.encryption_version = 1 - + with pytest.raises(DecryptionError): - store.get(collection='test', key='test') - + store.get(collection="test", key="test") def test_encryption_with_string_key(self, store: FernetEncryptionWrapper, memory_store: MemoryStore): """Test that encryption works with a string key.""" - original_value = {'test': 'value'} - store.put(collection='test', key='test', value=original_value) - - round_trip_value = store.get(collection='test', key='test') + original_value = {"test": "value"} + store.put(collection="test", key="test", value=original_value) + + round_trip_value = store.get(collection="test", key="test") assert round_trip_value == original_value - - raw_result = memory_store.get(collection='test', key='test') - assert raw_result == snapshot({'__encrypted_data__': IsStr(min_length=32), '__encryption_version__': 1}) - + + raw_result = memory_store.get(collection="test", key="test") + assert raw_result == snapshot({"__encrypted_data__": IsStr(min_length=32), "__encryption_version__": 1}) def test_encryption_many_operations(self, store: FernetEncryptionWrapper, memory_store: MemoryStore): """Test that encryption works with put_many and get_many.""" - keys = ['k1', 'k2', 'k3'] - values = [{'data': 'value1'}, {'data': 'value2'}, {'data': 'value3'}] - - store.put_many(collection='test', keys=keys, values=values) - + keys = ["k1", "k2", "k3"] + values = [{"data": "value1"}, {"data": "value2"}, {"data": "value3"}] + + store.put_many(collection="test", keys=keys, values=values) + # Check underlying store - all should be encrypted for key in keys: - raw_value = memory_store.get(collection='test', key=key) + raw_value = memory_store.get(collection="test", key=key) assert raw_value is not None - assert '__encrypted_data__' in raw_value - + assert "__encrypted_data__" in raw_value + # Retrieve through wrapper - results = store.get_many(collection='test', keys=keys) + results = store.get_many(collection="test", keys=keys) assert results == values - def test_decryption_handles_unencrypted_data(self, store: FernetEncryptionWrapper, memory_store: MemoryStore): """Test that unencrypted data is returned as-is.""" # Store unencrypted data directly in underlying store - unencrypted_value = {'test': 'value'} - memory_store.put(collection='test', key='test', value=unencrypted_value) - + unencrypted_value = {"test": "value"} + memory_store.put(collection="test", key="test", value=unencrypted_value) + # Should return as-is when retrieved through encryption wrapper - result = store.get(collection='test', key='test') + result = store.get(collection="test", key="test") assert result == unencrypted_value - def test_decryption_handles_corrupted_data(self, store: FernetEncryptionWrapper, memory_store: MemoryStore): """Test that corrupted encrypted data is handled gracefully.""" - + # Store corrupted encrypted data - corrupted_value = {'__encrypted_data__': 'invalid-encrypted-data!!!', '__encryption_version__': 1} - memory_store.put(collection='test', key='test', value=corrupted_value) - + corrupted_value = {"__encrypted_data__": "invalid-encrypted-data!!!", 
"__encryption_version__": 1} + memory_store.put(collection="test", key="test", value=corrupted_value) + with pytest.raises(DecryptionError): - store.get(collection='test', key='test') - + store.get(collection="test", key="test") def test_decryption_ignores_corrupted_data(self, memory_store: MemoryStore, fernet: Fernet): """Test that corrupted encrypted data is ignored.""" store = FernetEncryptionWrapper(key_value=memory_store, fernet=fernet, raise_on_decryption_error=False) - + # Store corrupted encrypted data - corrupted_value = {'__encrypted_data__': 'invalid-encrypted-data!!!', '__encryption_version__': 1} - memory_store.put(collection='test', key='test', value=corrupted_value) - - assert store.get(collection='test', key='test') is None - + corrupted_value = {"__encrypted_data__": "invalid-encrypted-data!!!", "__encryption_version__": 1} + memory_store.put(collection="test", key="test", value=corrupted_value) + + assert store.get(collection="test", key="test") is None def test_decryption_with_multi_fernet(self, memory_store: MemoryStore): """Test that decryption works with a MultiFernet.""" first_fernet = Fernet(key=Fernet.generate_key()) first_fernet_store = FernetEncryptionWrapper(key_value=memory_store, fernet=first_fernet) - original_value = {'test': 'value'} - first_fernet_store.put(collection='test', key='test', value=original_value) - assert first_fernet_store.get(collection='test', key='test') == original_value - + original_value = {"test": "value"} + first_fernet_store.put(collection="test", key="test", value=original_value) + assert first_fernet_store.get(collection="test", key="test") == original_value + second_fernet = Fernet(key=Fernet.generate_key()) multi_fernet = MultiFernet([second_fernet, first_fernet]) multi_fernet_store = FernetEncryptionWrapper(key_value=memory_store, fernet=multi_fernet) - assert multi_fernet_store.get(collection='test', key='test') == original_value - + assert multi_fernet_store.get(collection="test", key="test") == original_value def test_decryption_with_wrong_key_raises_error(self, memory_store: MemoryStore): """Test that decryption with the wrong key raises an error.""" fernet1 = Fernet(key=Fernet.generate_key()) fernet2 = Fernet(key=Fernet.generate_key()) - + store1 = FernetEncryptionWrapper(key_value=memory_store, fernet=fernet1) store2 = FernetEncryptionWrapper(key_value=memory_store, fernet=fernet2) - - original_value = {'test': 'value'} - store1.put(collection='test', key='test', value=original_value) - + + original_value = {"test": "value"} + store1.put(collection="test", key="test", value=original_value) + with pytest.raises(DecryptionError): - store2.get(collection='test', key='test') + store2.get(collection="test", key="test") def test_key_generation(): """Test that key generation works with a source material and salt and that different source materials and salts produce different keys.""" - - source_material = 'test-source-material' - salt = 'test-salt' + + source_material = "test-source-material" + salt = "test-salt" key = _generate_encryption_key(source_material=source_material, salt=salt) key_str_one = key.decode() - assert key_str_one == snapshot('znx7rVYt4roVgu3ymt5sIYFmfMNGEPbm8AShXQv6CY4=') - - source_material = 'different-source-material' - salt = 'test-salt' + assert key_str_one == snapshot("znx7rVYt4roVgu3ymt5sIYFmfMNGEPbm8AShXQv6CY4=") + + source_material = "different-source-material" + salt = "test-salt" key = _generate_encryption_key(source_material=source_material, salt=salt) key_str_two = key.decode() - assert 
key_str_two == snapshot('1TLRpjxQm4Op699i9hAXFVfyz6PqPXbuvwKaWB48tS8=') - - source_material = 'test-source-material' - salt = 'different-salt' + assert key_str_two == snapshot("1TLRpjxQm4Op699i9hAXFVfyz6PqPXbuvwKaWB48tS8=") + + source_material = "test-source-material" + salt = "different-salt" key = _generate_encryption_key(source_material=source_material, salt=salt) key_str_three = key.decode() - assert key_str_three == snapshot('oLz_g5NoLhANNh2_-ZwbgchDZ1q23VFx90kUQDjracc=') - + assert key_str_three == snapshot("oLz_g5NoLhANNh2_-ZwbgchDZ1q23VFx90kUQDjracc=") + assert key_str_one != key_str_two assert key_str_one != key_str_three assert key_str_two != key_str_three diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_fallback.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_fallback.py index 6a9f4b27..e35c1804 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_fallback.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_fallback.py @@ -14,76 +14,69 @@ class FailingStore(MemoryStore): """A store that always fails.""" - @override - def get(self, key: str, *, collection: str | None=None) -> dict[str, Any] | None: - msg = 'Primary store unavailable' + def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None: + msg = "Primary store unavailable" raise ConnectionError(msg) - @override - def put(self, key: str, value: Mapping[str, Any], *, collection: str | None=None, ttl: SupportsFloat | None=None): - msg = 'Primary store unavailable' + def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None): + msg = "Primary store unavailable" raise ConnectionError(msg) class TestFallbackWrapper(BaseStoreTests): - @override @pytest.fixture def store(self, memory_store: MemoryStore) -> FallbackWrapper: fallback_store = MemoryStore() return FallbackWrapper(primary_key_value=memory_store, fallback_key_value=fallback_store) - def test_fallback_on_primary_failure(self): primary_store = FailingStore() fallback_store = MemoryStore() wrapper = FallbackWrapper(primary_key_value=primary_store, fallback_key_value=fallback_store) - + # Put data in fallback store directly - fallback_store.put(collection='test', key='test', value={'test': 'fallback_value'}) - + fallback_store.put(collection="test", key="test", value={"test": "fallback_value"}) + # Should fall back to secondary store - result = wrapper.get(collection='test', key='test') - assert result == {'test': 'fallback_value'} - + result = wrapper.get(collection="test", key="test") + assert result == {"test": "fallback_value"} def test_primary_success_no_fallback(self): primary_store = MemoryStore() fallback_store = MemoryStore() wrapper = FallbackWrapper(primary_key_value=primary_store, fallback_key_value=fallback_store) - + # Put data in primary store - primary_store.put(collection='test', key='test', value={'test': 'primary_value'}) - + primary_store.put(collection="test", key="test", value={"test": "primary_value"}) + # Put different data in fallback store - fallback_store.put(collection='test', key='test', value={'test': 'fallback_value'}) - + fallback_store.put(collection="test", key="test", value={"test": "fallback_value"}) + # Should use primary store - result = wrapper.get(collection='test', key='test') - assert result == {'test': 'primary_value'} - + result = wrapper.get(collection="test", key="test") + assert result == {"test": "primary_value"} def test_write_to_fallback_disabled(self): 
primary_store = FailingStore() fallback_store = MemoryStore() wrapper = FallbackWrapper(primary_key_value=primary_store, fallback_key_value=fallback_store, write_to_fallback=False) - + # Writes should fail without falling back with pytest.raises(ConnectionError): - wrapper.put(collection='test', key='test', value={'test': 'value'}) - + wrapper.put(collection="test", key="test", value={"test": "value"}) def test_write_to_fallback_enabled(self): primary_store = FailingStore() fallback_store = MemoryStore() wrapper = FallbackWrapper(primary_key_value=primary_store, fallback_key_value=fallback_store, write_to_fallback=True) - + # Writes should fall back to secondary - wrapper.put(collection='test', key='test', value={'test': 'value'}) - + wrapper.put(collection="test", key="test", value={"test": "value"}) + # Verify it was written to fallback - result = fallback_store.get(collection='test', key='test') - assert result == {'test': 'value'} + result = fallback_store.get(collection="test", key="test") + assert result == {"test": "value"} diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_limit_size.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_limit_size.py index 0cd0d77f..e404249a 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_limit_size.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_limit_size.py @@ -11,160 +11,153 @@ class TestLimitSizeWrapper(BaseStoreTests): - @override @pytest.fixture def store(self, memory_store: MemoryStore) -> LimitSizeWrapper: # Set a reasonable max size for normal test operations (20KB to handle large test strings) return LimitSizeWrapper(key_value=memory_store, max_size=20 * 1024, raise_on_too_large=False, raise_on_too_small=False) - def test_put_within_limit(self, memory_store: MemoryStore): - limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=1024, raise_on_too_large=True, raise_on_too_small=False) - + limit_size_store: LimitSizeWrapper = LimitSizeWrapper( + key_value=memory_store, max_size=1024, raise_on_too_large=True, raise_on_too_small=False + ) + # Small value should succeed - limit_size_store.put(collection='test', key='test', value={'test': 'test'}) - result = limit_size_store.get(collection='test', key='test') + limit_size_store.put(collection="test", key="test", value={"test": "test"}) + result = limit_size_store.get(collection="test", key="test") assert result is not None - assert result['test'] == 'test' - + assert result["test"] == "test" def test_put_exceeds_limit_with_raise(self, memory_store: MemoryStore): limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=100, raise_on_too_large=True) - + # Large value should raise an error - large_value = {'data': 'x' * 1000} + large_value = {"data": "x" * 1000} with pytest.raises(EntryTooLargeError): - limit_size_store.put(collection='test', key='test', value=large_value) - + limit_size_store.put(collection="test", key="test", value=large_value) + # Verify nothing was stored - result = limit_size_store.get(collection='test', key='test') + result = limit_size_store.get(collection="test", key="test") assert result is None - def test_put_exceeds_limit_without_raise(self, memory_store: MemoryStore): limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=100, raise_on_too_large=False) - + # Large value should be silently ignored - large_value = {'data': 'x' * 1000} - limit_size_store.put(collection='test', key='test', 
value=large_value) - + large_value = {"data": "x" * 1000} + limit_size_store.put(collection="test", key="test", value=large_value) + # Verify nothing was stored - result = limit_size_store.get(collection='test', key='test') + result = limit_size_store.get(collection="test", key="test") assert result is None - def test_put_below_min_size_with_raise_on_too_small(self, memory_store: MemoryStore): limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, min_size=100, raise_on_too_small=True) - + # Small value should raise an error - small_value = {'data': 'x'} + small_value = {"data": "x"} with pytest.raises(EntryTooSmallError): - limit_size_store.put(collection='test', key='test', value=small_value) - + limit_size_store.put(collection="test", key="test", value=small_value) + # Verify nothing was stored - result = limit_size_store.get(collection='test', key='test') + result = limit_size_store.get(collection="test", key="test") assert result is None - def test_put_below_min_size_without_raise_on_too_small(self, memory_store: MemoryStore): limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, min_size=100, raise_on_too_small=False) - + # Small value should be silently ignored - small_value = {'data': 'x'} - limit_size_store.put(collection='test', key='test', value=small_value) - + small_value = {"data": "x"} + limit_size_store.put(collection="test", key="test", value=small_value) + # Verify nothing was stored - result = limit_size_store.get(collection='test', key='test') + result = limit_size_store.get(collection="test", key="test") assert result is None - def test_put_many_mixed_sizes_with_raise(self, memory_store: MemoryStore): limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=100, raise_on_too_large=True) - + # Mix of small and large values - keys = ['small1', 'large1', 'small2'] - values = [{'data': 'x'}, {'data': 'x' * 1000}, {'data': 'y'}] - + keys = ["small1", "large1", "small2"] + values = [{"data": "x"}, {"data": "x" * 1000}, {"data": "y"}] + # Should raise on the large value with pytest.raises(EntryTooLargeError): - limit_size_store.put_many(collection='test', keys=keys, values=values) - + limit_size_store.put_many(collection="test", keys=keys, values=values) + # Verify nothing was stored due to the error - results = limit_size_store.get_many(collection='test', keys=keys) + results = limit_size_store.get_many(collection="test", keys=keys) assert results[0] is None assert results[1] is None assert results[2] is None - def test_put_many_mixed_sizes_without_raise(self, memory_store: MemoryStore): limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=100, raise_on_too_large=False) - + # Mix of small and large values - keys = ['small1', 'large1', 'small2'] - values = [{'data': 'x'}, {'data': 'x' * 1000}, {'data': 'y'}] - + keys = ["small1", "large1", "small2"] + values = [{"data": "x"}, {"data": "x" * 1000}, {"data": "y"}] + # Should silently filter out large value - limit_size_store.put_many(collection='test', keys=keys, values=values) - + limit_size_store.put_many(collection="test", keys=keys, values=values) + # Verify only small values were stored - results = limit_size_store.get_many(collection='test', keys=keys) - assert results[0] == {'data': 'x'} + results = limit_size_store.get_many(collection="test", keys=keys) + assert results[0] == {"data": "x"} assert results[1] is None # Large value was filtered out - assert results[2] == {'data': 'y'} - + assert results[2] == 
{"data": "y"} def test_put_many_with_ttl_sequence(self, memory_store: MemoryStore): limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=100, raise_on_too_large=False) - + # Mix of small and large values with single TTL - keys = ['small1', 'large1', 'small2'] - values = [{'data': 'x'}, {'data': 'x' * 1000}, {'data': 'y'}] - + keys = ["small1", "large1", "small2"] + values = [{"data": "x"}, {"data": "x" * 1000}, {"data": "y"}] + # Should filter out large value - limit_size_store.put_many(collection='test', keys=keys, values=values, ttl=100) - + limit_size_store.put_many(collection="test", keys=keys, values=values, ttl=100) + # Verify only small values were stored - results = limit_size_store.get_many(collection='test', keys=keys) - assert results[0] == {'data': 'x'} + results = limit_size_store.get_many(collection="test", keys=keys) + assert results[0] == {"data": "x"} assert results[1] is None # Large value was filtered out - assert results[2] == {'data': 'y'} - + assert results[2] == {"data": "y"} def test_put_many_all_too_large_without_raise(self, memory_store: MemoryStore): limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=10, raise_on_too_large=False) - + # All values too large - keys = ['key1', 'key2'] - values = [{'data': 'x' * 1000}, {'data': 'y' * 1000}] - + keys = ["key1", "key2"] + values = [{"data": "x" * 1000}, {"data": "y" * 1000}] + # Should not raise, but nothing should be stored - limit_size_store.put_many(collection='test', keys=keys, values=values) - + limit_size_store.put_many(collection="test", keys=keys, values=values) + # Verify nothing was stored - results = limit_size_store.get_many(collection='test', keys=keys) + results = limit_size_store.get_many(collection="test", keys=keys) assert results[0] is None assert results[1] is None - def test_exact_size_limit(self, memory_store: MemoryStore): # First, determine the exact size of a small value from key_value.shared.utils.managed_entry import ManagedEntry - - test_value = {'test': 'value'} + + test_value = {"test": "value"} managed_entry = ManagedEntry(value=test_value) json_str = managed_entry.to_json() - exact_size = len(json_str.encode('utf-8')) - + exact_size = len(json_str.encode("utf-8")) + # Create a store with exact size limit limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=exact_size, raise_on_too_large=True) - + # Should succeed at exact limit - limit_size_store.put(collection='test', key='test', value=test_value) - result = limit_size_store.get(collection='test', key='test') + limit_size_store.put(collection="test", key="test", value=test_value) + result = limit_size_store.get(collection="test", key="test") assert result == test_value - + # Should fail if one byte over - limit_size_store_under: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=exact_size - 1, raise_on_too_large=True) + limit_size_store_under: LimitSizeWrapper = LimitSizeWrapper( + key_value=memory_store, max_size=exact_size - 1, raise_on_too_large=True + ) with pytest.raises(EntryTooLargeError): - limit_size_store_under.put(collection='test', key='test2', value=test_value) + limit_size_store_under.put(collection="test", key="test2", value=test_value) diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_logging.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_logging.py index 74961355..2acf806a 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_logging.py 
+++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_logging.py @@ -20,102 +20,184 @@ def get_messages_from_caplog(caplog: pytest.LogCaptureFixture) -> list[str]: class TestLoggingWrapper(BaseStoreTests): - @override @pytest.fixture def store(self) -> LoggingWrapper: return LoggingWrapper(key_value=MemoryStore(max_entries_per_collection=500), log_level=logging.INFO) - @override @pytest.fixture def structured_logs_store(self) -> LoggingWrapper: return LoggingWrapper(key_value=MemoryStore(max_entries_per_collection=500), log_level=logging.INFO, structured_logs=True) - @pytest.fixture def capture_logs(self, caplog: pytest.LogCaptureFixture) -> Generator[LogCaptureFixture, Any, Any]: with caplog.at_level(logging.INFO): yield caplog - def test_logging_get_operations(self, store: LoggingWrapper, structured_logs_store: LoggingWrapper, capture_logs: LogCaptureFixture): - store.get(collection='test', key='test') - assert get_messages_from_caplog(capture_logs) == snapshot(["Start GET collection='test' keys='test'", "Finish GET collection='test' keys='test' ({'hit': False})"]) - + store.get(collection="test", key="test") + assert get_messages_from_caplog(capture_logs) == snapshot( + ["Start GET collection='test' keys='test'", "Finish GET collection='test' keys='test' ({'hit': False})"] + ) + capture_logs.clear() - - structured_logs_store.get(collection='test', key='test') - assert get_messages_from_caplog(capture_logs) == snapshot(['{"status": "start", "action": "GET", "collection": "test", "keys": "test"}', '{"status": "finish", "action": "GET", "collection": "test", "keys": "test", "extra": {"hit": false}}']) - + + structured_logs_store.get(collection="test", key="test") + assert get_messages_from_caplog(capture_logs) == snapshot( + [ + '{"status": "start", "action": "GET", "collection": "test", "keys": "test"}', + '{"status": "finish", "action": "GET", "collection": "test", "keys": "test", "extra": {"hit": false}}', + ] + ) + capture_logs.clear() - - store.get_many(collection='test', keys=['test', 'test_2']) - assert get_messages_from_caplog(capture_logs) == snapshot(["Start GET_MANY collection='test' keys='['test', 'test_2']' ({'keys': ['test', 'test_2']})", "Finish GET_MANY collection='test' keys='['test', 'test_2']' ({'hits': 0, 'misses': 2})"]) - + + store.get_many(collection="test", keys=["test", "test_2"]) + assert get_messages_from_caplog(capture_logs) == snapshot( + [ + "Start GET_MANY collection='test' keys='['test', 'test_2']' ({'keys': ['test', 'test_2']})", + "Finish GET_MANY collection='test' keys='['test', 'test_2']' ({'hits': 0, 'misses': 2})", + ] + ) + capture_logs.clear() - - structured_logs_store.get_many(collection='test', keys=['test', 'test_2']) - assert get_messages_from_caplog(capture_logs) == snapshot(['{"status": "start", "action": "GET_MANY", "collection": "test", "keys": ["test", "test_2"], "extra": {"keys": ["test", "test_2"]}}', '{"status": "finish", "action": "GET_MANY", "collection": "test", "keys": ["test", "test_2"], "extra": {"hits": 0, "misses": 2}}']) - + + structured_logs_store.get_many(collection="test", keys=["test", "test_2"]) + assert get_messages_from_caplog(capture_logs) == snapshot( + [ + '{"status": "start", "action": "GET_MANY", "collection": "test", "keys": ["test", "test_2"], "extra": {"keys": ["test", "test_2"]}}', + '{"status": "finish", "action": "GET_MANY", "collection": "test", "keys": ["test", "test_2"], "extra": {"hits": 0, "misses": 2}}', + ] + ) def test_logging_put_operations(self, store: LoggingWrapper, 
structured_logs_store: LoggingWrapper, capture_logs: LogCaptureFixture): logging_store = LoggingWrapper(key_value=store, log_level=logging.INFO) - - logging_store.put(collection='test', key='test', value={'test': 'value'}) - assert get_messages_from_caplog(capture_logs) == snapshot(["Start PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", "Start PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", "Finish PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", "Finish PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})"]) - + + logging_store.put(collection="test", key="test", value={"test": "value"}) + assert get_messages_from_caplog(capture_logs) == snapshot( + [ + "Start PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", + "Start PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", + "Finish PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", + "Finish PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", + ] + ) + capture_logs.clear() - - structured_logs_store.put(collection='test', key='test', value={'test': 'value'}) - assert get_messages_from_caplog(capture_logs) == snapshot(['{"status": "start", "action": "PUT", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"ttl": null}}', '{"status": "finish", "action": "PUT", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"ttl": null}}']) - + + structured_logs_store.put(collection="test", key="test", value={"test": "value"}) + assert get_messages_from_caplog(capture_logs) == snapshot( + [ + '{"status": "start", "action": "PUT", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"ttl": null}}', + '{"status": "finish", "action": "PUT", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"ttl": null}}', + ] + ) + capture_logs.clear() - - logging_store.put_many(collection='test', keys=['test', 'test_2'], values=[{'test': 'value'}, {'test': 'value_2'}]) - assert get_messages_from_caplog(capture_logs) == snapshot(["Start PUT_MANY collection='test' keys='['test', 'test_2']' value=[{'test': 'value'}, {'test': 'value_2'}] ({'ttl': None})", "Start PUT_MANY collection='test' keys='['test', 'test_2']' value=[{'test': 'value'}, {'test': 'value_2'}] ({'ttl': None})", "Finish PUT_MANY collection='test' keys='['test', 'test_2']' value=[{'test': 'value'}, {'test': 'value_2'}] ({'ttl': None})", "Finish PUT_MANY collection='test' keys='['test', 'test_2']' value=[{'test': 'value'}, {'test': 'value_2'}] ({'ttl': None})"]) - + + logging_store.put_many(collection="test", keys=["test", "test_2"], values=[{"test": "value"}, {"test": "value_2"}]) + assert get_messages_from_caplog(capture_logs) == snapshot( + [ + "Start PUT_MANY collection='test' keys='['test', 'test_2']' value=[{'test': 'value'}, {'test': 'value_2'}] ({'ttl': None})", + "Start PUT_MANY collection='test' keys='['test', 'test_2']' value=[{'test': 'value'}, {'test': 'value_2'}] ({'ttl': None})", + "Finish PUT_MANY collection='test' keys='['test', 'test_2']' value=[{'test': 'value'}, {'test': 'value_2'}] ({'ttl': None})", + "Finish PUT_MANY collection='test' keys='['test', 'test_2']' value=[{'test': 'value'}, {'test': 'value_2'}] ({'ttl': None})", + ] + ) + capture_logs.clear() - - structured_logs_store.put_many(collection='test', keys=['test', 'test_2'], values=[{'test': 'value'}, {'test': 'value_2'}]) - assert 
get_messages_from_caplog(capture_logs) == snapshot(['{"status": "start", "action": "PUT_MANY", "collection": "test", "keys": ["test", "test_2"], "value": [{"test": "value"}, {"test": "value_2"}], "extra": {"ttl": null}}', '{"status": "finish", "action": "PUT_MANY", "collection": "test", "keys": ["test", "test_2"], "value": [{"test": "value"}, {"test": "value_2"}], "extra": {"ttl": null}}']) - + + structured_logs_store.put_many(collection="test", keys=["test", "test_2"], values=[{"test": "value"}, {"test": "value_2"}]) + assert get_messages_from_caplog(capture_logs) == snapshot( + [ + '{"status": "start", "action": "PUT_MANY", "collection": "test", "keys": ["test", "test_2"], "value": [{"test": "value"}, {"test": "value_2"}], "extra": {"ttl": null}}', + '{"status": "finish", "action": "PUT_MANY", "collection": "test", "keys": ["test", "test_2"], "value": [{"test": "value"}, {"test": "value_2"}], "extra": {"ttl": null}}', + ] + ) def test_logging_delete_operations(self, store: LoggingWrapper, structured_logs_store: LoggingWrapper, capture_logs: LogCaptureFixture): logging_store = LoggingWrapper(key_value=store, log_level=logging.INFO) - - logging_store.delete(collection='test', key='test') - assert get_messages_from_caplog(capture_logs) == snapshot(["Start DELETE collection='test' keys='test'", "Start DELETE collection='test' keys='test'", "Finish DELETE collection='test' keys='test' ({'deleted': False})", "Finish DELETE collection='test' keys='test' ({'deleted': False})"]) - + + logging_store.delete(collection="test", key="test") + assert get_messages_from_caplog(capture_logs) == snapshot( + [ + "Start DELETE collection='test' keys='test'", + "Start DELETE collection='test' keys='test'", + "Finish DELETE collection='test' keys='test' ({'deleted': False})", + "Finish DELETE collection='test' keys='test' ({'deleted': False})", + ] + ) + capture_logs.clear() - - structured_logs_store.delete(collection='test', key='test') - assert get_messages_from_caplog(capture_logs) == snapshot(['{"status": "start", "action": "DELETE", "collection": "test", "keys": "test"}', '{"status": "finish", "action": "DELETE", "collection": "test", "keys": "test", "extra": {"deleted": false}}']) - + + structured_logs_store.delete(collection="test", key="test") + assert get_messages_from_caplog(capture_logs) == snapshot( + [ + '{"status": "start", "action": "DELETE", "collection": "test", "keys": "test"}', + '{"status": "finish", "action": "DELETE", "collection": "test", "keys": "test", "extra": {"deleted": false}}', + ] + ) + capture_logs.clear() - - logging_store.delete_many(collection='test', keys=['test', 'test_2']) - assert get_messages_from_caplog(capture_logs) == snapshot(["Start DELETE_MANY collection='test' keys='['test', 'test_2']' ({'keys': ['test', 'test_2']})", "Start DELETE_MANY collection='test' keys='['test', 'test_2']' ({'keys': ['test', 'test_2']})", "Finish DELETE_MANY collection='test' keys='['test', 'test_2']' ({'deleted': 0})", "Finish DELETE_MANY collection='test' keys='['test', 'test_2']' ({'deleted': 0})"]) - + + logging_store.delete_many(collection="test", keys=["test", "test_2"]) + assert get_messages_from_caplog(capture_logs) == snapshot( + [ + "Start DELETE_MANY collection='test' keys='['test', 'test_2']' ({'keys': ['test', 'test_2']})", + "Start DELETE_MANY collection='test' keys='['test', 'test_2']' ({'keys': ['test', 'test_2']})", + "Finish DELETE_MANY collection='test' keys='['test', 'test_2']' ({'deleted': 0})", + "Finish DELETE_MANY collection='test' keys='['test', 'test_2']' 
({'deleted': 0})", + ] + ) + capture_logs.clear() - - structured_logs_store.delete_many(collection='test', keys=['test', 'test_2']) - assert get_messages_from_caplog(capture_logs) == snapshot(['{"status": "start", "action": "DELETE_MANY", "collection": "test", "keys": ["test", "test_2"], "extra": {"keys": ["test", "test_2"]}}', '{"status": "finish", "action": "DELETE_MANY", "collection": "test", "keys": ["test", "test_2"], "extra": {"deleted": 0}}']) - - - def test_put_get_delete_get_logging(self, store: LoggingWrapper, structured_logs_store: LoggingWrapper, capture_logs: LogCaptureFixture): - store.put(collection='test', key='test', value={'test': 'value'}) - assert store.get(collection='test', key='test') == {'test': 'value'} - assert store.delete(collection='test', key='test') - assert store.get(collection='test', key='test') is None - - assert get_messages_from_caplog(capture_logs) == snapshot(["Start PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", "Finish PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", "Start GET collection='test' keys='test'", "Finish GET collection='test' keys='test' value={'test': 'value'} ({'hit': True})", "Start DELETE collection='test' keys='test'", "Finish DELETE collection='test' keys='test' ({'deleted': True})", "Start GET collection='test' keys='test'", "Finish GET collection='test' keys='test' ({'hit': False})"]) - + + structured_logs_store.delete_many(collection="test", keys=["test", "test_2"]) + assert get_messages_from_caplog(capture_logs) == snapshot( + [ + '{"status": "start", "action": "DELETE_MANY", "collection": "test", "keys": ["test", "test_2"], "extra": {"keys": ["test", "test_2"]}}', + '{"status": "finish", "action": "DELETE_MANY", "collection": "test", "keys": ["test", "test_2"], "extra": {"deleted": 0}}', + ] + ) + + def test_put_get_delete_get_logging( + self, store: LoggingWrapper, structured_logs_store: LoggingWrapper, capture_logs: LogCaptureFixture + ): + store.put(collection="test", key="test", value={"test": "value"}) + assert store.get(collection="test", key="test") == {"test": "value"} + assert store.delete(collection="test", key="test") + assert store.get(collection="test", key="test") is None + + assert get_messages_from_caplog(capture_logs) == snapshot( + [ + "Start PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", + "Finish PUT collection='test' keys='test' value={'test': 'value'} ({'ttl': None})", + "Start GET collection='test' keys='test'", + "Finish GET collection='test' keys='test' value={'test': 'value'} ({'hit': True})", + "Start DELETE collection='test' keys='test'", + "Finish DELETE collection='test' keys='test' ({'deleted': True})", + "Start GET collection='test' keys='test'", + "Finish GET collection='test' keys='test' ({'hit': False})", + ] + ) + capture_logs.clear() - - structured_logs_store.put(collection='test', key='test', value={'test': 'value'}) - assert structured_logs_store.get(collection='test', key='test') == {'test': 'value'} - assert structured_logs_store.delete(collection='test', key='test') - assert structured_logs_store.get(collection='test', key='test') is None - - assert get_messages_from_caplog(capture_logs) == snapshot(['{"status": "start", "action": "PUT", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"ttl": null}}', '{"status": "finish", "action": "PUT", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"ttl": null}}', '{"status": "start", "action": "GET", 
"collection": "test", "keys": "test"}', '{"status": "finish", "action": "GET", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"hit": true}}', '{"status": "start", "action": "DELETE", "collection": "test", "keys": "test"}', '{"status": "finish", "action": "DELETE", "collection": "test", "keys": "test", "extra": {"deleted": true}}', '{"status": "start", "action": "GET", "collection": "test", "keys": "test"}', '{"status": "finish", "action": "GET", "collection": "test", "keys": "test", "extra": {"hit": false}}']) + + structured_logs_store.put(collection="test", key="test", value={"test": "value"}) + assert structured_logs_store.get(collection="test", key="test") == {"test": "value"} + assert structured_logs_store.delete(collection="test", key="test") + assert structured_logs_store.get(collection="test", key="test") is None + + assert get_messages_from_caplog(capture_logs) == snapshot( + [ + '{"status": "start", "action": "PUT", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"ttl": null}}', + '{"status": "finish", "action": "PUT", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"ttl": null}}', + '{"status": "start", "action": "GET", "collection": "test", "keys": "test"}', + '{"status": "finish", "action": "GET", "collection": "test", "keys": "test", "value": {"test": "value"}, "extra": {"hit": true}}', + '{"status": "start", "action": "DELETE", "collection": "test", "keys": "test"}', + '{"status": "finish", "action": "DELETE", "collection": "test", "keys": "test", "extra": {"deleted": true}}', + '{"status": "start", "action": "GET", "collection": "test", "keys": "test"}', + '{"status": "finish", "action": "GET", "collection": "test", "keys": "test", "extra": {"hit": false}}', + ] + ) diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_passthrough_cache.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_passthrough_cache.py index 529e39fd..f71fce63 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_passthrough_cache.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_passthrough_cache.py @@ -16,18 +16,14 @@ class TestPassthroughCacheWrapper(BaseStoreTests): - - @pytest.fixture(scope='session') + @pytest.fixture(scope="session") def primary_store(self) -> Generator[DiskStore, None, None]: - with tempfile.TemporaryDirectory() as temp_dir: - with DiskStore(directory=temp_dir, max_size=DISK_STORE_SIZE_LIMIT) as disk_store: - yield disk_store - + with tempfile.TemporaryDirectory() as temp_dir, DiskStore(directory=temp_dir, max_size=DISK_STORE_SIZE_LIMIT) as disk_store: + yield disk_store @pytest.fixture def cache_store(self, memory_store: MemoryStore) -> MemoryStore: return memory_store - @override @pytest.fixture diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_prefix_collection.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_prefix_collection.py index efe487c2..6a31f566 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_prefix_collection.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_prefix_collection.py @@ -10,8 +10,7 @@ class TestPrefixCollectionWrapper(BaseStoreTests): - @override @pytest.fixture def store(self, memory_store: MemoryStore) -> PrefixCollectionsWrapper: - return PrefixCollectionsWrapper(key_value=memory_store, prefix='collection_prefix') + return PrefixCollectionsWrapper(key_value=memory_store, prefix="collection_prefix") diff --git 
a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_prefix_key.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_prefix_key.py index 77e37107..8d64c8d2 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_prefix_key.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_prefix_key.py @@ -10,8 +10,7 @@ class TestPrefixKeyWrapper(BaseStoreTests): - @override @pytest.fixture def store(self, memory_store: MemoryStore) -> PrefixKeysWrapper: - return PrefixKeysWrapper(key_value=memory_store, prefix='key_prefix') + return PrefixKeysWrapper(key_value=memory_store, prefix="key_prefix") diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_read_only.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_read_only.py index 6a0f94fb..56e4c959 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_read_only.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_read_only.py @@ -10,68 +10,63 @@ class TestReadOnlyWrapper: - @pytest.fixture def memory_store(self) -> MemoryStore: return MemoryStore() - @override @pytest.fixture def store(self, memory_store: MemoryStore) -> ReadOnlyWrapper: # Pre-populate the store with test data - memory_store.put(collection='test', key='test', value={'test': 'test'}) + memory_store.put(collection="test", key="test", value={"test": "test"}) return ReadOnlyWrapper(key_value=memory_store, raise_on_write=False) - def test_read_operations_allowed(self, memory_store: MemoryStore): # Pre-populate store - memory_store.put(collection='test', key='test', value={'test': 'value'}) - + memory_store.put(collection="test", key="test", value={"test": "value"}) + read_only_store = ReadOnlyWrapper(key_value=memory_store, raise_on_write=True) - + # Read operations should work - result = read_only_store.get(collection='test', key='test') - assert result == {'test': 'value'} - - results = read_only_store.get_many(collection='test', keys=['test']) - assert results == [{'test': 'value'}] - - (value, _) = read_only_store.ttl(collection='test', key='test') - assert value == {'test': 'value'} - + result = read_only_store.get(collection="test", key="test") + assert result == {"test": "value"} + + results = read_only_store.get_many(collection="test", keys=["test"]) + assert results == [{"test": "value"}] + + (value, _) = read_only_store.ttl(collection="test", key="test") + assert value == {"test": "value"} def test_write_operations_raise_error(self, memory_store: MemoryStore): read_only_store = ReadOnlyWrapper(key_value=memory_store, raise_on_write=True) - + # Write operations should raise ReadOnlyError with pytest.raises(ReadOnlyError): - read_only_store.put(collection='test', key='test', value={'test': 'value'}) - + read_only_store.put(collection="test", key="test", value={"test": "value"}) + with pytest.raises(ReadOnlyError): - read_only_store.put_many(collection='test', keys=['test'], values=[{'test': 'value'}]) - + read_only_store.put_many(collection="test", keys=["test"], values=[{"test": "value"}]) + with pytest.raises(ReadOnlyError): - read_only_store.delete(collection='test', key='test') - + read_only_store.delete(collection="test", key="test") + with pytest.raises(ReadOnlyError): - read_only_store.delete_many(collection='test', keys=['test']) - + read_only_store.delete_many(collection="test", keys=["test"]) def test_write_operations_silent_ignore(self, memory_store: MemoryStore): read_only_store = ReadOnlyWrapper(key_value=memory_store, raise_on_write=False) - + # 
Write operations should be silently ignored - read_only_store.put(collection='test', key='new_key', value={'test': 'value'}) - + read_only_store.put(collection="test", key="new_key", value={"test": "value"}) + # Verify nothing was written - result = memory_store.get(collection='test', key='new_key') + result = memory_store.get(collection="test", key="new_key") assert result is None - + # Delete should return False - deleted = read_only_store.delete(collection='test', key='test') + deleted = read_only_store.delete(collection="test", key="test") assert deleted is False - + # Delete many should return 0 - deleted_count = read_only_store.delete_many(collection='test', keys=['test']) + deleted_count = read_only_store.delete_many(collection="test", keys=["test"]) assert deleted_count == 0 diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_retry.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_retry.py index 267414cf..aa12b5b8 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_retry.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_retry.py @@ -11,75 +11,67 @@ class FailingStore(MemoryStore): """A store that fails a certain number of times before succeeding.""" - - def __init__(self, failures_before_success: int=2): + def __init__(self, failures_before_success: int = 2): super().__init__() self.failures_before_success = failures_before_success self.attempt_count = 0 - - def get(self, key: str, *, collection: str | None=None): + def get(self, key: str, *, collection: str | None = None): self.attempt_count += 1 if self.attempt_count <= self.failures_before_success: - msg = 'Simulated connection error' + msg = "Simulated connection error" raise ConnectionError(msg) return super().get(key=key, collection=collection) - def reset_attempts(self): self.attempt_count = 0 class TestRetryWrapper(BaseStoreTests): - @override @pytest.fixture def store(self, memory_store: MemoryStore) -> RetryWrapper: return RetryWrapper(key_value=memory_store, max_retries=3, initial_delay=0.01) - def test_retry_succeeds_after_failures(self): failing_store = FailingStore(failures_before_success=2) retry_store = RetryWrapper(key_value=failing_store, max_retries=3, initial_delay=0.01) - + # Store a value first - retry_store.put(collection='test', key='test', value={'test': 'value'}) + retry_store.put(collection="test", key="test", value={"test": "value"}) failing_store.reset_attempts() - + # Should succeed after 2 failures - result = retry_store.get(collection='test', key='test') - assert result == {'test': 'value'} + result = retry_store.get(collection="test", key="test") + assert result == {"test": "value"} assert failing_store.attempt_count == 3 # 2 failures + 1 success - def test_retry_fails_after_max_retries(self): failing_store = FailingStore(failures_before_success=10) # More failures than max_retries retry_store = RetryWrapper(key_value=failing_store, max_retries=2, initial_delay=0.01) - + # Should fail after exhausting retries with pytest.raises(ConnectionError): - retry_store.get(collection='test', key='test') - + retry_store.get(collection="test", key="test") + assert failing_store.attempt_count == 3 # Initial attempt + 2 retries - def test_retry_with_different_exception(self): failing_store = FailingStore(failures_before_success=1) # Only retry on TimeoutError, not ConnectionError retry_store = RetryWrapper(key_value=failing_store, max_retries=3, initial_delay=0.01, retry_on=(TimeoutError,)) - + # Should fail immediately without 
retries with pytest.raises(ConnectionError): - retry_store.get(collection='test', key='test') - + retry_store.get(collection="test", key="test") + assert failing_store.attempt_count == 1 # No retries - def test_retry_no_failures(self, memory_store: MemoryStore): retry_store = RetryWrapper(key_value=memory_store, max_retries=3, initial_delay=0.01) - + # Normal operation should work without retries - retry_store.put(collection='test', key='test', value={'test': 'value'}) - result = retry_store.get(collection='test', key='test') - assert result == {'test': 'value'} + retry_store.put(collection="test", key="test", value={"test": "value"}) + result = retry_store.get(collection="test", key="test") + assert result == {"test": "value"} diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_routing.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_routing.py index 5494eb27..9a75430d 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_routing.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_routing.py @@ -11,37 +11,33 @@ from key_value.sync.code_gen.wrappers.routing import CollectionRoutingWrapper, RoutingWrapper from tests.code_gen.stores.base import BaseStoreTests -KEY_ONE = 'key1' -VALUE_ONE = {'this_key_1': 'this_value_1'} -COLLECTION_ONE = 'first' +KEY_ONE = "key1" +VALUE_ONE = {"this_key_1": "this_value_1"} +COLLECTION_ONE = "first" -KEY_TWO = 'key2' -VALUE_TWO = {'this_key_2': 'this_value_2'} -COLLECTION_TWO = 'second' +KEY_TWO = "key2" +VALUE_TWO = {"this_key_2": "this_value_2"} +COLLECTION_TWO = "second" -KEY_UNMAPPED = 'key3' -VALUE_UNMAPPED = {'this_key_3': 'this_value_3'} -COLLECTION_UNMAPPED = 'unmapped' +KEY_UNMAPPED = "key3" +VALUE_UNMAPPED = {"this_key_3": "this_value_3"} +COLLECTION_UNMAPPED = "unmapped" ALL_KEYS = [KEY_ONE, KEY_TWO, KEY_UNMAPPED] class TestRoutingWrapper(BaseStoreTests): - @pytest.fixture def second_store(self) -> MemoryStore: return MemoryStore() - @pytest.fixture def default_store(self) -> MemoryStore: return MemoryStore() - @pytest.fixture def store(self, memory_store: MemoryStore, second_store: MemoryStore, default_store: MemoryStore) -> RoutingWrapper: first_store = memory_store - def route(collection: str | None) -> KeyValue | None: if collection == COLLECTION_ONE: @@ -49,9 +45,8 @@ def route(collection: str | None) -> KeyValue | None: if collection == COLLECTION_TWO: return second_store return None - + return RoutingWrapper(routing_function=route, default_store=default_store) - @pytest.fixture def store_with_data(self, store: RoutingWrapper) -> RoutingWrapper: @@ -59,74 +54,74 @@ def store_with_data(self, store: RoutingWrapper) -> RoutingWrapper: store.put(key=KEY_TWO, value=VALUE_TWO, collection=COLLECTION_TWO) store.put(key=KEY_UNMAPPED, value=VALUE_UNMAPPED, collection=COLLECTION_UNMAPPED) return store - @override - @pytest.mark.skip(reason='RoutingWrapper is unbounded') - def test_not_unbounded(self, store: BaseStore): - ... - + @pytest.mark.skip(reason="RoutingWrapper is unbounded") + def test_not_unbounded(self, store: BaseStore): ... 
- def test_routing_get_and_get_many(self, store_with_data: RoutingWrapper, memory_store: MemoryStore, second_store: MemoryStore, default_store: MemoryStore): + def test_routing_get_and_get_many( + self, store_with_data: RoutingWrapper, memory_store: MemoryStore, second_store: MemoryStore, default_store: MemoryStore + ): """Test basic routing sends gets""" assert memory_store.get(key=KEY_ONE, collection=COLLECTION_ONE) == VALUE_ONE assert memory_store.get(key=KEY_TWO, collection=COLLECTION_TWO) is None assert memory_store.get(key=KEY_UNMAPPED, collection=COLLECTION_UNMAPPED) is None assert memory_store.get_many(keys=ALL_KEYS, collection=COLLECTION_ONE) == [VALUE_ONE, None, None] - + assert second_store.get(key=KEY_ONE, collection=COLLECTION_ONE) is None assert second_store.get(key=KEY_TWO, collection=COLLECTION_TWO) == VALUE_TWO assert second_store.get(key=KEY_UNMAPPED, collection=COLLECTION_UNMAPPED) is None assert second_store.get_many(keys=ALL_KEYS, collection=COLLECTION_TWO) == [None, VALUE_TWO, None] - + assert default_store.get(key=KEY_ONE, collection=COLLECTION_ONE) is None assert default_store.get(key=KEY_TWO, collection=COLLECTION_TWO) is None assert default_store.get(key=KEY_UNMAPPED, collection=COLLECTION_UNMAPPED) == VALUE_UNMAPPED assert default_store.get_many(keys=ALL_KEYS, collection=COLLECTION_UNMAPPED) == [None, None, VALUE_UNMAPPED] - - def test_routing_delete(self, store_with_data: RoutingWrapper, memory_store: MemoryStore, second_store: MemoryStore, default_store: MemoryStore): + def test_routing_delete( + self, store_with_data: RoutingWrapper, memory_store: MemoryStore, second_store: MemoryStore, default_store: MemoryStore + ): """Test delete operations route correctly.""" - + assert store_with_data.get(key=KEY_ONE, collection=COLLECTION_ONE) == VALUE_ONE store_with_data.delete(key=KEY_ONE, collection=COLLECTION_ONE) assert memory_store.get(key=KEY_ONE, collection=COLLECTION_ONE) is None assert memory_store.get_many(keys=ALL_KEYS, collection=COLLECTION_ONE) == [None, None, None] assert second_store.get_many(keys=ALL_KEYS, collection=COLLECTION_ONE) == [None, None, None] assert default_store.get_many(keys=ALL_KEYS, collection=COLLECTION_ONE) == [None, None, None] - + assert store_with_data.get(key=KEY_TWO, collection=COLLECTION_TWO) == VALUE_TWO store_with_data.delete(key=KEY_TWO, collection=COLLECTION_TWO) assert memory_store.get(key=KEY_TWO, collection=COLLECTION_TWO) is None assert memory_store.get_many(keys=ALL_KEYS, collection=COLLECTION_TWO) == [None, None, None] assert second_store.get_many(keys=ALL_KEYS, collection=COLLECTION_TWO) == [None, None, None] assert default_store.get_many(keys=ALL_KEYS, collection=COLLECTION_TWO) == [None, None, None] - + assert store_with_data.get(key=KEY_UNMAPPED, collection=COLLECTION_UNMAPPED) == VALUE_UNMAPPED store_with_data.delete(key=KEY_UNMAPPED, collection=COLLECTION_UNMAPPED) assert memory_store.get(key=KEY_UNMAPPED, collection=COLLECTION_UNMAPPED) is None assert memory_store.get_many(keys=ALL_KEYS, collection=COLLECTION_UNMAPPED) == [None, None, None] assert second_store.get_many(keys=ALL_KEYS, collection=COLLECTION_UNMAPPED) == [None, None, None] assert default_store.get_many(keys=ALL_KEYS, collection=COLLECTION_UNMAPPED) == [None, None, None] - def test_routing_ttl(self, store: RoutingWrapper, memory_store: MemoryStore, second_store: MemoryStore, default_store: MemoryStore): """Test TTL operations route correctly.""" key_one_ttl = 1800 key_two_ttl = 2700 key_unmapped_ttl = 7200 - + store.put(key=KEY_ONE, value=VALUE_ONE, 
collection=COLLECTION_ONE, ttl=key_one_ttl) store.put(key=KEY_TWO, value=VALUE_TWO, collection=COLLECTION_TWO, ttl=key_two_ttl) store.put(key=KEY_UNMAPPED, value=VALUE_UNMAPPED, collection=COLLECTION_UNMAPPED, ttl=key_unmapped_ttl) - + assert store.ttl(key=KEY_ONE, collection=COLLECTION_ONE) == (VALUE_ONE, IsFloat(approx=key_one_ttl)) assert store.ttl(key=KEY_TWO, collection=COLLECTION_TWO) == (VALUE_TWO, IsFloat(approx=key_two_ttl)) assert store.ttl(key=KEY_UNMAPPED, collection=COLLECTION_UNMAPPED) == (VALUE_UNMAPPED, IsFloat(approx=key_unmapped_ttl)) class TestCollectionRoutingWrapper(TestRoutingWrapper): - @pytest.fixture def store(self, memory_store: MemoryStore, second_store: MemoryStore, default_store: MemoryStore) -> CollectionRoutingWrapper: - return CollectionRoutingWrapper(collection_map={COLLECTION_ONE: memory_store, COLLECTION_TWO: second_store}, default_store=default_store) + return CollectionRoutingWrapper( + collection_map={COLLECTION_ONE: memory_store, COLLECTION_TWO: second_store}, default_store=default_store + ) diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_single_collection.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_single_collection.py index c8ee4512..f4e70e52 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_single_collection.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_single_collection.py @@ -10,8 +10,7 @@ class TestSingleCollectionWrapper(BaseStoreTests): - @override @pytest.fixture def store(self, memory_store: MemoryStore) -> SingleCollectionWrapper: - return SingleCollectionWrapper(key_value=memory_store, single_collection='test') + return SingleCollectionWrapper(key_value=memory_store, single_collection="test") diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_statistics.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_statistics.py index 6dfe9e39..ddf43f0b 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_statistics.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_statistics.py @@ -10,7 +10,6 @@ class TestStatisticsWrapper(BaseStoreTests): - @override @pytest.fixture def store(self, memory_store: MemoryStore) -> StatisticsWrapper: diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_ttl_clamp.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_ttl_clamp.py index 175ffbbf..59ac6473 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_ttl_clamp.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_ttl_clamp.py @@ -11,45 +11,41 @@ class TestTTLClampWrapper(BaseStoreTests): - @override @pytest.fixture def store(self, memory_store: MemoryStore) -> TTLClampWrapper: return TTLClampWrapper(key_value=memory_store, min_ttl=0, max_ttl=100) - def test_put_below_min_ttl(self, memory_store: MemoryStore): ttl_clamp_store: TTLClampWrapper = TTLClampWrapper(key_value=memory_store, min_ttl=50, max_ttl=100) - - ttl_clamp_store.put(collection='test', key='test', value={'test': 'test'}, ttl=5) - assert ttl_clamp_store.get(collection='test', key='test') is not None - - (value, ttl) = ttl_clamp_store.ttl(collection='test', key='test') + + ttl_clamp_store.put(collection="test", key="test", value={"test": "test"}, ttl=5) + assert ttl_clamp_store.get(collection="test", key="test") is not None + + (value, ttl) = ttl_clamp_store.ttl(collection="test", key="test") assert value is not None assert ttl is not None assert ttl == 
IsFloat(approx=50) - def test_put_above_max_ttl(self, memory_store: MemoryStore): ttl_clamp_store: TTLClampWrapper = TTLClampWrapper(key_value=memory_store, min_ttl=0, max_ttl=100) - - ttl_clamp_store.put(collection='test', key='test', value={'test': 'test'}, ttl=1000) - assert ttl_clamp_store.get(collection='test', key='test') is not None - - (value, ttl) = ttl_clamp_store.ttl(collection='test', key='test') + + ttl_clamp_store.put(collection="test", key="test", value={"test": "test"}, ttl=1000) + assert ttl_clamp_store.get(collection="test", key="test") is not None + + (value, ttl) = ttl_clamp_store.ttl(collection="test", key="test") assert value is not None assert ttl is not None assert ttl == IsFloat(approx=100) - def test_put_missing_ttl(self, memory_store: MemoryStore): ttl_clamp_store: TTLClampWrapper = TTLClampWrapper(key_value=memory_store, min_ttl=0, max_ttl=100, missing_ttl=50) - - ttl_clamp_store.put(collection='test', key='test', value={'test': 'test'}, ttl=None) - assert ttl_clamp_store.get(collection='test', key='test') is not None - - (value, ttl) = ttl_clamp_store.ttl(collection='test', key='test') + + ttl_clamp_store.put(collection="test", key="test", value={"test": "test"}, ttl=None) + assert ttl_clamp_store.get(collection="test", key="test") is not None + + (value, ttl) = ttl_clamp_store.ttl(collection="test", key="test") assert value is not None assert ttl is not None - + assert ttl == IsFloat(approx=50) From 49cf6057683060b887bafeb402fdb4f251936c83 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Wed, 29 Oct 2025 17:39:11 +0000 Subject: [PATCH 03/11] refactor: move store-specific serialization adapters into store modules - Simplified shared serialization module to contain only SerializationAdapter base class - Moved FullJsonAdapter to Redis store module - Moved MongoDBAdapter to MongoDB store module - Moved ElasticsearchAdapter to Elasticsearch store module - Updated imports in all affected modules - Ran codegen to sync async changes to sync library - Applied linting fixes Co-authored-by: William Easton --- .../aio/stores/elasticsearch/store.py | 105 +++++- .../src/key_value/aio/stores/mongodb/store.py | 117 ++++++- .../src/key_value/aio/stores/redis/store.py | 41 ++- .../key_value/shared/utils/serialization.py | 310 +----------------- .../code_gen/stores/elasticsearch/store.py | 101 +++++- .../sync/code_gen/stores/mongodb/store.py | 117 ++++++- .../sync/code_gen/stores/redis/store.py | 41 ++- 7 files changed, 518 insertions(+), 314 deletions(-) diff --git a/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py b/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py index 740c74f4..0bd6d0b0 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py @@ -6,15 +6,15 @@ from elastic_transport import ObjectApiResponse from elastic_transport import SerializationError as ElasticsearchSerializationError from key_value.shared.errors import DeserializationError, SerializationError -from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict from key_value.shared.utils.sanitize import ( ALPHANUMERIC_CHARACTERS, LOWERCASE_ALPHABET, NUMBERS, sanitize_string, ) -from key_value.shared.utils.serialization import ElasticsearchAdapter, SerializationAdapter -from key_value.shared.utils.time_to_live import 
now_as_epoch +from key_value.shared.utils.serialization import SerializationAdapter +from key_value.shared.utils.time_to_live import now_as_epoch, try_parse_datetime_str from typing_extensions import override from key_value.aio.stores.base import ( @@ -85,6 +85,105 @@ ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "." +class ElasticsearchAdapter(SerializationAdapter): + """Adapter for Elasticsearch with support for native and string storage modes. + + This adapter supports two storage modes: + - Native mode: Stores values as flattened dicts for efficient querying + - String mode: Stores values as JSON strings for backward compatibility + + Elasticsearch-specific features: + - Stores collection name in the document for multi-tenancy + - Uses ISO format for datetime fields + - Supports migration between storage modes + """ + + def __init__(self, *, native_storage: bool = True) -> None: + """Initialize the Elasticsearch adapter. + + Args: + native_storage: If True (default), store values as flattened dicts. + If False, store values as JSON strings. + """ + self.native_storage = native_storage + + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: + """Convert a ManagedEntry to an Elasticsearch document. + + Args: + key: The key associated with this entry. + entry: The ManagedEntry to serialize. + collection: The collection name to store in the document. + + Returns: + An Elasticsearch document dict with collection, key, value, and metadata. + """ + document: dict[str, Any] = {"collection": collection or "", "key": key, "value": {}} + + # Store in appropriate field based on mode + if self.native_storage: + document["value"]["flattened"] = entry.value_as_dict + else: + document["value"]["string"] = entry.value_as_json + + if entry.created_at: + document["created_at"] = entry.created_at.isoformat() + if entry.expires_at: + document["expires_at"] = entry.expires_at.isoformat() + + return document + + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert an Elasticsearch document back to a ManagedEntry. + + This method supports both native (flattened) and string storage modes, + trying the flattened field first and falling back to the string field. + This allows for seamless migration between storage modes. + + Args: + data: The Elasticsearch document to deserialize. + + Returns: + A ManagedEntry reconstructed from the document. + + Raises: + DeserializationError: If data is not a dict or is malformed. 
+ """ + if not isinstance(data, dict): + msg = "Expected Elasticsearch document to be a dict" + raise DeserializationError(msg) + + document = data + value: dict[str, Any] = {} + + raw_value = document.get("value") + + # Try flattened field first, fall back to string field + if not raw_value or not isinstance(raw_value, dict): + msg = "Value field not found or invalid type" + raise DeserializationError(msg) + + if value_flattened := raw_value.get("flattened"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + value = verify_dict(obj=value_flattened) + elif value_str := raw_value.get("string"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + if not isinstance(value_str, str): + msg = "Value in `value` field is not a string" + raise DeserializationError(msg) + value = load_from_json(value_str) + else: + msg = "Value field not found or invalid type" + raise DeserializationError(msg) + + created_at: datetime | None = try_parse_datetime_str(value=document.get("created_at")) + expires_at: datetime | None = try_parse_datetime_str(value=document.get("expires_at")) + + return ManagedEntry( + value=value, + created_at=created_at, + expires_at=expires_at, + ) + + class ElasticsearchStore( BaseEnumerateCollectionsStore, BaseEnumerateKeysStore, BaseDestroyCollectionStore, BaseCullStore, BaseContextManagerStore, BaseStore ): diff --git a/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py b/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py index b1a3a8c0..6e80f80b 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py @@ -3,9 +3,9 @@ from typing import Any, overload from key_value.shared.errors import DeserializationError -from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.managed_entry import ManagedEntry, verify_dict from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, sanitize_string -from key_value.shared.utils.serialization import MongoDBAdapter, SerializationAdapter +from key_value.shared.utils.serialization import SerializationAdapter from typing_extensions import Self, override from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyCollectionStore, BaseEnumerateCollectionsStore, BaseStore @@ -35,6 +35,119 @@ COLLECTION_ALLOWED_CHARACTERS = ALPHANUMERIC_CHARACTERS + "_" +class MongoDBAdapter(SerializationAdapter): + """MongoDB-specific serialization adapter with native BSON datetime support. + + This adapter is optimized for MongoDB, storing: + - Native BSON datetime types for TTL indexing (created_at, expires_at) + - Values in value.object (native BSON) or value.string (JSON) fields + - Support for migration between native and string storage modes + + The native storage mode is recommended for new deployments as it allows + efficient querying of value fields, while string mode provides backward + compatibility with older data. + """ + + def __init__(self, *, native_storage: bool = True) -> None: + """Initialize the MongoDB adapter. + + Args: + native_storage: If True (default), store value as native BSON dict in value.object field. + If False, store as JSON string in value.string field for backward compatibility. + """ + self.native_storage = native_storage + + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: # noqa: ARG002 + """Convert a ManagedEntry to a MongoDB document. 
+ + Args: + key: The key associated with this entry. + entry: The ManagedEntry to serialize. + collection: The collection (unused, for interface compatibility). + + Returns: + A MongoDB document with key, value, and BSON datetime metadata. + """ + document: dict[str, Any] = {"key": key, "value": {}} + + # We convert to JSON even if we don't need to, this ensures that the value we were provided + # can be serialized to JSON which helps ensure compatibility across stores + json_str = entry.value_as_json + + # Store in appropriate field based on mode + if self.native_storage: + document["value"]["object"] = entry.value_as_dict + else: + document["value"]["string"] = json_str + + # Add metadata fields as BSON datetimes for TTL indexing + if entry.created_at: + document["created_at"] = entry.created_at + if entry.expires_at: + document["expires_at"] = entry.expires_at + + return document + + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert a MongoDB document back to a ManagedEntry. + + This method supports both native (object) and legacy (string) storage modes, + and properly handles BSON datetime types for metadata. + + Args: + data: The MongoDB document to deserialize. + + Returns: + A ManagedEntry reconstructed from the document. + + Raises: + DeserializationError: If data is not a dict or is malformed. + """ + if not isinstance(data, dict): + msg = "Expected MongoDB document to be a dict" + raise DeserializationError(msg) + + document = data + + if not (value_field := document.get("value")): + msg = "Value field not found" + raise DeserializationError(msg) + + if not isinstance(value_field, dict): + msg = "Expected `value` field to be an object" + raise DeserializationError(msg) + + value_holder: dict[str, Any] = verify_dict(obj=value_field) + + entry_data: dict[str, Any] = {} + + # Mongo stores datetimes without timezones as UTC so we mark them as UTC + # Import timezone here to avoid circular import + from key_value.shared.utils.time_to_live import timezone + + if created_at_datetime := document.get("created_at"): + if not isinstance(created_at_datetime, datetime): + msg = "Expected `created_at` field to be a datetime" + raise DeserializationError(msg) + entry_data["created_at"] = created_at_datetime.replace(tzinfo=timezone.utc) + + if expires_at_datetime := document.get("expires_at"): + if not isinstance(expires_at_datetime, datetime): + msg = "Expected `expires_at` field to be a datetime" + raise DeserializationError(msg) + entry_data["expires_at"] = expires_at_datetime.replace(tzinfo=timezone.utc) + + # Support both native (object) and legacy (string) storage + if value_object := value_holder.get("object"): + return ManagedEntry.from_dict(data={"value": value_object, **entry_data}) + + if value_string := value_holder.get("string"): + return ManagedEntry.from_dict(data={"value": value_string, **entry_data}, stringified_value=True) + + msg = "Expected `value` field to be an object with `object` or `string` subfield" + raise DeserializationError(msg) + + class MongoDBStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, BaseContextManagerStore, BaseStore): """MongoDB-based key-value store using pymongo.""" diff --git a/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py b/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py index d8bee19a..52bfbf17 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py @@ -7,7 +7,7 @@ from 
key_value.shared.type_checking.bear_spray import bear_spray from key_value.shared.utils.compound import compound_key, get_keys_from_compound_keys from key_value.shared.utils.managed_entry import ManagedEntry -from key_value.shared.utils.serialization import FullJsonAdapter, SerializationAdapter +from key_value.shared.utils.serialization import SerializationAdapter from typing_extensions import override from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyStore, BaseEnumerateKeysStore, BaseStore @@ -22,6 +22,45 @@ PAGE_LIMIT = 10000 +class FullJsonAdapter(SerializationAdapter): + """Adapter that serializes entries as complete JSON strings. + + This adapter is suitable for Redis which works with string values. + It serializes the entire ManagedEntry (including all metadata) to a JSON string. + """ + + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> str: # noqa: ARG002 + """Convert a ManagedEntry to a JSON string. + + Args: + key: The key (unused, for interface compatibility). + entry: The ManagedEntry to serialize. + collection: The collection (unused, for interface compatibility). + + Returns: + A JSON string containing the entry and all metadata. + """ + return entry.to_json(include_metadata=True, include_expiration=True, include_creation=True) + + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert a JSON string back to a ManagedEntry. + + Args: + data: The JSON string to deserialize. + + Returns: + A ManagedEntry reconstructed from the JSON. + + Raises: + DeserializationError: If data is not a string or cannot be parsed. + """ + if not isinstance(data, str): + msg = "Expected data to be a JSON string" + raise DeserializationError(msg) + + return ManagedEntry.from_json(json_str=data, includes_metadata=True) + + class RedisStore(BaseDestroyStore, BaseEnumerateKeysStore, BaseContextManagerStore, BaseStore): """Redis-based key-value store.""" diff --git a/key-value/key-value-shared/src/key_value/shared/utils/serialization.py b/key-value/key-value-shared/src/key_value/shared/utils/serialization.py index 86649455..3edf55fa 100644 --- a/key-value/key-value-shared/src/key_value/shared/utils/serialization.py +++ b/key-value/key-value-shared/src/key_value/shared/utils/serialization.py @@ -1,16 +1,14 @@ -"""Serialization adapters for converting ManagedEntry objects to/from store-specific formats. +"""Serialization adapter base class for converting ManagedEntry objects to/from store-specific formats. -This module provides a base SerializationAdapter ABC and common adapter implementations -that can be reused across different key-value stores. +This module provides the SerializationAdapter ABC that store implementations should use +to define their own serialization strategy. Store-specific adapter implementations +should be defined within their respective store modules. """ from abc import ABC, abstractmethod -from datetime import datetime from typing import Any -from key_value.shared.errors.key_value import DeserializationError -from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict -from key_value.shared.utils.time_to_live import try_parse_datetime_str +from key_value.shared.utils.managed_entry import ManagedEntry class SerializationAdapter(ABC): @@ -19,6 +17,9 @@ class SerializationAdapter(ABC): Adapters encapsulate the logic for converting between ManagedEntry objects and store-specific storage formats. 
This provides a consistent interface while allowing each store to optimize its serialization strategy. + + Store implementations should subclass this adapter and define their own + to_storage() and from_storage() methods within their store module. """ @abstractmethod @@ -49,298 +50,3 @@ def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: DeserializationError: If the data cannot be deserialized. """ ... - - -class FullJsonAdapter(SerializationAdapter): - """Adapter that serializes entries as complete JSON strings. - - This adapter is suitable for stores that work with string values, - such as Redis or Valkey. It serializes the entire ManagedEntry - (including all metadata) to a JSON string. - """ - - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> str: # noqa: ARG002 - """Convert a ManagedEntry to a JSON string. - - Args: - key: The key (unused, for interface compatibility). - entry: The ManagedEntry to serialize. - collection: The collection (unused, for interface compatibility). - - Returns: - A JSON string containing the entry and all metadata. - """ - return entry.to_json(include_metadata=True, include_expiration=True, include_creation=True) - - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert a JSON string back to a ManagedEntry. - - Args: - data: The JSON string to deserialize. - - Returns: - A ManagedEntry reconstructed from the JSON. - - Raises: - DeserializationError: If data is not a string or cannot be parsed. - """ - if not isinstance(data, str): - msg = "Expected data to be a JSON string" - raise DeserializationError(msg) - - return ManagedEntry.from_json(json_str=data, includes_metadata=True) - - -class StringifiedDictAdapter(SerializationAdapter): - """Adapter that serializes entries as dicts with stringified values. - - This adapter is suitable for stores that prefer to store entries as - documents with the value field serialized as a JSON string. This allows - stores to index and query metadata fields while keeping the value opaque. - """ - - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: # noqa: ARG002 - """Convert a ManagedEntry to a dict with stringified value. - - Args: - key: The key associated with this entry. - entry: The ManagedEntry to serialize. - collection: The collection (unused, for interface compatibility). - - Returns: - A dict with key, stringified value, and metadata fields. - """ - return { - "key": key, - **entry.to_dict(include_metadata=True, include_expiration=True, include_creation=True, stringify_value=True), - } - - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert a dict with stringified value back to a ManagedEntry. - - Args: - data: The dict to deserialize. - - Returns: - A ManagedEntry reconstructed from the dict. - - Raises: - DeserializationError: If data is not a dict or is malformed. - """ - if not isinstance(data, dict): - msg = "Expected data to be a dict" - raise DeserializationError(msg) - - return ManagedEntry.from_dict(obj=data, expects_stringified_value=True, includes_metadata=True) - - -class MongoDBAdapter(SerializationAdapter): - """MongoDB-specific serialization adapter with native BSON datetime support. 
- - This adapter is optimized for MongoDB, storing: - - Native BSON datetime types for TTL indexing (created_at, expires_at) - - Values in value.object (native BSON) or value.string (JSON) fields - - Support for migration between native and string storage modes - - The native storage mode is recommended for new deployments as it allows - efficient querying of value fields, while string mode provides backward - compatibility with older data. - """ - - def __init__(self, *, native_storage: bool = True) -> None: - """Initialize the MongoDB adapter. - - Args: - native_storage: If True (default), store value as native BSON dict in value.object field. - If False, store as JSON string in value.string field for backward compatibility. - """ - self.native_storage = native_storage - - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: # noqa: ARG002 - """Convert a ManagedEntry to a MongoDB document. - - Args: - key: The key associated with this entry. - entry: The ManagedEntry to serialize. - collection: The collection (unused, for interface compatibility). - - Returns: - A MongoDB document with key, value, and BSON datetime metadata. - """ - document: dict[str, Any] = {"key": key, "value": {}} - - # We convert to JSON even if we don't need to, this ensures that the value we were provided - # can be serialized to JSON which helps ensure compatibility across stores - json_str = entry.value_as_json - - # Store in appropriate field based on mode - if self.native_storage: - document["value"]["object"] = entry.value_as_dict - else: - document["value"]["string"] = json_str - - # Add metadata fields as BSON datetimes for TTL indexing - if entry.created_at: - document["created_at"] = entry.created_at - if entry.expires_at: - document["expires_at"] = entry.expires_at - - return document - - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert a MongoDB document back to a ManagedEntry. - - This method supports both native (object) and legacy (string) storage modes, - and properly handles BSON datetime types for metadata. - - Args: - data: The MongoDB document to deserialize. - - Returns: - A ManagedEntry reconstructed from the document. - - Raises: - DeserializationError: If data is not a dict or is malformed. 
- """ - if not isinstance(data, dict): - msg = "Expected MongoDB document to be a dict" - raise DeserializationError(msg) - - document = data - - if not (value_field := document.get("value")): - msg = "Value field not found" - raise DeserializationError(msg) - - if not isinstance(value_field, dict): - msg = "Expected `value` field to be an object" - raise DeserializationError(msg) - - value_holder: dict[str, Any] = verify_dict(obj=value_field) - - entry_data: dict[str, Any] = {} - - # Mongo stores datetimes without timezones as UTC so we mark them as UTC - # Import timezone here to avoid circular import - from key_value.shared.utils.time_to_live import timezone - - if created_at_datetime := document.get("created_at"): - if not isinstance(created_at_datetime, datetime): - msg = "Expected `created_at` field to be a datetime" - raise DeserializationError(msg) - entry_data["created_at"] = created_at_datetime.replace(tzinfo=timezone.utc) - - if expires_at_datetime := document.get("expires_at"): - if not isinstance(expires_at_datetime, datetime): - msg = "Expected `expires_at` field to be a datetime" - raise DeserializationError(msg) - entry_data["expires_at"] = expires_at_datetime.replace(tzinfo=timezone.utc) - - # Support both native (object) and legacy (string) storage - if value_object := value_holder.get("object"): - return ManagedEntry.from_dict(data={"value": value_object, **entry_data}) - - if value_string := value_holder.get("string"): - return ManagedEntry.from_dict(data={"value": value_string, **entry_data}, stringified_value=True) - - msg = "Expected `value` field to be an object with `object` or `string` subfield" - raise DeserializationError(msg) - - -class ElasticsearchAdapter(SerializationAdapter): - """Adapter for Elasticsearch with support for native and string storage modes. - - This adapter supports two storage modes: - - Native mode: Stores values as flattened dicts for efficient querying - - String mode: Stores values as JSON strings for backward compatibility - - Elasticsearch-specific features: - - Stores collection name in the document for multi-tenancy - - Uses ISO format for datetime fields - - Supports migration between storage modes - """ - - def __init__(self, *, native_storage: bool = True) -> None: - """Initialize the Elasticsearch adapter. - - Args: - native_storage: If True (default), store values as flattened dicts. - If False, store values as JSON strings. - """ - self.native_storage = native_storage - - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: - """Convert a ManagedEntry to an Elasticsearch document. - - Args: - key: The key associated with this entry. - entry: The ManagedEntry to serialize. - collection: The collection name to store in the document. - - Returns: - An Elasticsearch document dict with collection, key, value, and metadata. - """ - document: dict[str, Any] = {"collection": collection or "", "key": key, "value": {}} - - # Store in appropriate field based on mode - if self.native_storage: - document["value"]["flattened"] = entry.value_as_dict - else: - document["value"]["string"] = entry.value_as_json - - if entry.created_at: - document["created_at"] = entry.created_at.isoformat() - if entry.expires_at: - document["expires_at"] = entry.expires_at.isoformat() - - return document - - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert an Elasticsearch document back to a ManagedEntry. 
- - This method supports both native (flattened) and string storage modes, - trying the flattened field first and falling back to the string field. - This allows for seamless migration between storage modes. - - Args: - data: The Elasticsearch document to deserialize. - - Returns: - A ManagedEntry reconstructed from the document. - - Raises: - DeserializationError: If data is not a dict or is malformed. - """ - if not isinstance(data, dict): - msg = "Expected Elasticsearch document to be a dict" - raise DeserializationError(msg) - - document = data - value: dict[str, Any] = {} - - raw_value = document.get("value") - - # Try flattened field first, fall back to string field - if not raw_value or not isinstance(raw_value, dict): - msg = "Value field not found or invalid type" - raise DeserializationError(msg) - - if value_flattened := raw_value.get("flattened"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - value = verify_dict(obj=value_flattened) - elif value_str := raw_value.get("string"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - if not isinstance(value_str, str): - msg = "Value in `value` field is not a string" - raise DeserializationError(msg) - value = load_from_json(value_str) - else: - msg = "Value field not found or invalid type" - raise DeserializationError(msg) - - created_at: datetime | None = try_parse_datetime_str(value=document.get("created_at")) - expires_at: datetime | None = try_parse_datetime_str(value=document.get("expires_at")) - - return ManagedEntry( - value=value, - created_at=created_at, - expires_at=expires_at, - ) diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py index c56b0a23..ee0c6a8b 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py @@ -9,10 +9,10 @@ from elastic_transport import ObjectApiResponse from elastic_transport import SerializationError as ElasticsearchSerializationError from key_value.shared.errors import DeserializationError, SerializationError -from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, LOWERCASE_ALPHABET, NUMBERS, sanitize_string -from key_value.shared.utils.serialization import ElasticsearchAdapter, SerializationAdapter -from key_value.shared.utils.time_to_live import now_as_epoch +from key_value.shared.utils.serialization import SerializationAdapter +from key_value.shared.utils.time_to_live import now_as_epoch, try_parse_datetime_str from typing_extensions import override from key_value.sync.code_gen.stores.base import ( @@ -65,6 +65,101 @@ ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "." +class ElasticsearchAdapter(SerializationAdapter): + """Adapter for Elasticsearch with support for native and string storage modes. 
+ + This adapter supports two storage modes: + - Native mode: Stores values as flattened dicts for efficient querying + - String mode: Stores values as JSON strings for backward compatibility + + Elasticsearch-specific features: + - Stores collection name in the document for multi-tenancy + - Uses ISO format for datetime fields + - Supports migration between storage modes + """ + + def __init__(self, *, native_storage: bool = True) -> None: + """Initialize the Elasticsearch adapter. + + Args: + native_storage: If True (default), store values as flattened dicts. + If False, store values as JSON strings. + """ + self.native_storage = native_storage + + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: + """Convert a ManagedEntry to an Elasticsearch document. + + Args: + key: The key associated with this entry. + entry: The ManagedEntry to serialize. + collection: The collection name to store in the document. + + Returns: + An Elasticsearch document dict with collection, key, value, and metadata. + """ + document: dict[str, Any] = {"collection": collection or "", "key": key, "value": {}} + + # Store in appropriate field based on mode + if self.native_storage: + document["value"]["flattened"] = entry.value_as_dict + else: + document["value"]["string"] = entry.value_as_json + + if entry.created_at: + document["created_at"] = entry.created_at.isoformat() + if entry.expires_at: + document["expires_at"] = entry.expires_at.isoformat() + + return document + + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert an Elasticsearch document back to a ManagedEntry. + + This method supports both native (flattened) and string storage modes, + trying the flattened field first and falling back to the string field. + This allows for seamless migration between storage modes. + + Args: + data: The Elasticsearch document to deserialize. + + Returns: + A ManagedEntry reconstructed from the document. + + Raises: + DeserializationError: If data is not a dict or is malformed. 
+ """ + if not isinstance(data, dict): + msg = "Expected Elasticsearch document to be a dict" + raise DeserializationError(msg) + + document = data + value: dict[str, Any] = {} + + raw_value = document.get("value") + + # Try flattened field first, fall back to string field + if not raw_value or not isinstance(raw_value, dict): + msg = "Value field not found or invalid type" + raise DeserializationError(msg) + + if value_flattened := raw_value.get("flattened"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + value = verify_dict(obj=value_flattened) + elif value_str := raw_value.get("string"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + if not isinstance(value_str, str): + msg = "Value in `value` field is not a string" + raise DeserializationError(msg) + value = load_from_json(value_str) + else: + msg = "Value field not found or invalid type" + raise DeserializationError(msg) + + created_at: datetime | None = try_parse_datetime_str(value=document.get("created_at")) + expires_at: datetime | None = try_parse_datetime_str(value=document.get("expires_at")) + + return ManagedEntry(value=value, created_at=created_at, expires_at=expires_at) + + class ElasticsearchStore( BaseEnumerateCollectionsStore, BaseEnumerateKeysStore, BaseDestroyCollectionStore, BaseCullStore, BaseContextManagerStore, BaseStore ): diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py index 8c3e6c48..5c1a1055 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py @@ -6,9 +6,9 @@ from typing import Any, overload from key_value.shared.errors import DeserializationError -from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.managed_entry import ManagedEntry, verify_dict from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, sanitize_string -from key_value.shared.utils.serialization import MongoDBAdapter, SerializationAdapter +from key_value.shared.utils.serialization import SerializationAdapter from typing_extensions import Self, override from key_value.sync.code_gen.stores.base import ( @@ -42,6 +42,119 @@ COLLECTION_ALLOWED_CHARACTERS = ALPHANUMERIC_CHARACTERS + "_" +class MongoDBAdapter(SerializationAdapter): + """MongoDB-specific serialization adapter with native BSON datetime support. + + This adapter is optimized for MongoDB, storing: + - Native BSON datetime types for TTL indexing (created_at, expires_at) + - Values in value.object (native BSON) or value.string (JSON) fields + - Support for migration between native and string storage modes + + The native storage mode is recommended for new deployments as it allows + efficient querying of value fields, while string mode provides backward + compatibility with older data. + """ + + def __init__(self, *, native_storage: bool = True) -> None: + """Initialize the MongoDB adapter. + + Args: + native_storage: If True (default), store value as native BSON dict in value.object field. + If False, store as JSON string in value.string field for backward compatibility. + """ + self.native_storage = native_storage + + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: + """Convert a ManagedEntry to a MongoDB document. + + Args: + key: The key associated with this entry. 
+ entry: The ManagedEntry to serialize. + collection: The collection (unused, for interface compatibility). + + Returns: + A MongoDB document with key, value, and BSON datetime metadata. + """ + document: dict[str, Any] = {"key": key, "value": {}} + + # We convert to JSON even if we don't need to, this ensures that the value we were provided + # can be serialized to JSON which helps ensure compatibility across stores + json_str = entry.value_as_json + + # Store in appropriate field based on mode + if self.native_storage: + document["value"]["object"] = entry.value_as_dict + else: + document["value"]["string"] = json_str + + # Add metadata fields as BSON datetimes for TTL indexing + if entry.created_at: + document["created_at"] = entry.created_at + if entry.expires_at: + document["expires_at"] = entry.expires_at + + return document + + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert a MongoDB document back to a ManagedEntry. + + This method supports both native (object) and legacy (string) storage modes, + and properly handles BSON datetime types for metadata. + + Args: + data: The MongoDB document to deserialize. + + Returns: + A ManagedEntry reconstructed from the document. + + Raises: + DeserializationError: If data is not a dict or is malformed. + """ + if not isinstance(data, dict): + msg = "Expected MongoDB document to be a dict" + raise DeserializationError(msg) + + document = data + + if not (value_field := document.get("value")): + msg = "Value field not found" + raise DeserializationError(msg) + + if not isinstance(value_field, dict): + msg = "Expected `value` field to be an object" + raise DeserializationError(msg) + + value_holder: dict[str, Any] = verify_dict(obj=value_field) + + entry_data: dict[str, Any] = {} + + # Mongo stores datetimes without timezones as UTC so we mark them as UTC + # Import timezone here to avoid circular import + from key_value.shared.utils.time_to_live import timezone + + if created_at_datetime := document.get("created_at"): + if not isinstance(created_at_datetime, datetime): + msg = "Expected `created_at` field to be a datetime" + raise DeserializationError(msg) + entry_data["created_at"] = created_at_datetime.replace(tzinfo=timezone.utc) + + if expires_at_datetime := document.get("expires_at"): + if not isinstance(expires_at_datetime, datetime): + msg = "Expected `expires_at` field to be a datetime" + raise DeserializationError(msg) + entry_data["expires_at"] = expires_at_datetime.replace(tzinfo=timezone.utc) + + # Support both native (object) and legacy (string) storage + if value_object := value_holder.get("object"): + return ManagedEntry.from_dict(data={"value": value_object, **entry_data}) + + if value_string := value_holder.get("string"): + return ManagedEntry.from_dict(data={"value": value_string, **entry_data}, stringified_value=True) + + msg = "Expected `value` field to be an object with `object` or `string` subfield" + raise DeserializationError(msg) + + class MongoDBStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, BaseContextManagerStore, BaseStore): """MongoDB-based key-value store using pymongo.""" diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py index 76238fa5..d2f8659a 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py @@ -10,7 +10,7 @@ from 
key_value.shared.type_checking.bear_spray import bear_spray from key_value.shared.utils.compound import compound_key, get_keys_from_compound_keys from key_value.shared.utils.managed_entry import ManagedEntry -from key_value.shared.utils.serialization import FullJsonAdapter, SerializationAdapter +from key_value.shared.utils.serialization import SerializationAdapter from typing_extensions import override from key_value.sync.code_gen.stores.base import BaseContextManagerStore, BaseDestroyStore, BaseEnumerateKeysStore, BaseStore @@ -25,6 +25,45 @@ PAGE_LIMIT = 10000 +class FullJsonAdapter(SerializationAdapter): + """Adapter that serializes entries as complete JSON strings. + + This adapter is suitable for Redis which works with string values. + It serializes the entire ManagedEntry (including all metadata) to a JSON string. + """ + + def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> str: + """Convert a ManagedEntry to a JSON string. + + Args: + key: The key (unused, for interface compatibility). + entry: The ManagedEntry to serialize. + collection: The collection (unused, for interface compatibility). + + Returns: + A JSON string containing the entry and all metadata. + """ + return entry.to_json(include_metadata=True, include_expiration=True, include_creation=True) + + def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: + """Convert a JSON string back to a ManagedEntry. + + Args: + data: The JSON string to deserialize. + + Returns: + A ManagedEntry reconstructed from the JSON. + + Raises: + DeserializationError: If data is not a string or cannot be parsed. + """ + if not isinstance(data, str): + msg = "Expected data to be a JSON string" + raise DeserializationError(msg) + + return ManagedEntry.from_json(json_str=data, includes_metadata=True) + + class RedisStore(BaseDestroyStore, BaseEnumerateKeysStore, BaseContextManagerStore, BaseStore): """Redis-based key-value store.""" From 714b0cba15cf94267d67c3140732f9fa780624e6 Mon Sep 17 00:00:00 2001 From: William Easton Date: Wed, 29 Oct 2025 20:56:23 -0500 Subject: [PATCH 04/11] Refactor serialization --- .../src/key_value/aio/stores/base.py | 13 +- .../aio/stores/elasticsearch/store.py | 126 ++++--------- .../src/key_value/aio/stores/mongodb/store.py | 174 ++++++------------ .../src/key_value/aio/stores/redis/store.py | 64 +------ .../elasticsearch/test_elasticsearch.py | 24 +-- .../tests/stores/mongodb/test_mongodb.py | 13 +- .../tests/stores/redis/test_redis.py | 91 ++++++--- key-value/key-value-shared/pyproject.toml | 5 - .../key_value/shared/utils/managed_entry.py | 28 +-- .../key_value/shared/utils/serialization.py | 131 ++++++++++--- .../key_value/sync/code_gen/stores/base.py | 13 +- .../code_gen/stores/elasticsearch/store.py | 122 ++++-------- .../sync/code_gen/stores/mongodb/store.py | 166 +++++------------ .../sync/code_gen/stores/redis/store.py | 64 +------ .../elasticsearch/test_elasticsearch.py | 46 ++--- .../code_gen/stores/mongodb/test_mongodb.py | 13 +- .../tests/code_gen/stores/redis/test_redis.py | 76 +++++--- 17 files changed, 474 insertions(+), 695 deletions(-) diff --git a/key-value/key-value-aio/src/key_value/aio/stores/base.py b/key-value/key-value-aio/src/key_value/aio/stores/base.py index 074f7850..b23bcec9 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/base.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/base.py @@ -14,6 +14,7 @@ from key_value.shared.errors import StoreSetupError from key_value.shared.type_checking.bear_spray import bear_enforce 
from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.serialization import BasicSerializationAdapter, SerializationAdapter from key_value.shared.utils.time_to_live import prepare_entry_timestamps from typing_extensions import Self, override @@ -67,6 +68,8 @@ class BaseStore(AsyncKeyValueProtocol, ABC): _setup_collection_locks: defaultdict[str, Lock] _setup_collection_complete: defaultdict[str, bool] + _serialization_adapter: SerializationAdapter + _seed: FROZEN_SEED_DATA_TYPE default_collection: str @@ -91,6 +94,8 @@ def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYP self.default_collection = default_collection or DEFAULT_COLLECTION_NAME + self._serialization_adapter = BasicSerializationAdapter() + if not hasattr(self, "_stable_api"): self._stable_api = False @@ -286,9 +291,9 @@ async def put(self, key: str, value: Mapping[str, Any], *, collection: str | Non collection = collection or self.default_collection await self.setup_collection(collection=collection) - created_at, ttl_seconds, expires_at = prepare_entry_timestamps(ttl=ttl) + created_at, _, expires_at = prepare_entry_timestamps(ttl=ttl) - managed_entry: ManagedEntry = ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at) + managed_entry: ManagedEntry = ManagedEntry(value=value, created_at=created_at, expires_at=expires_at) await self._put_managed_entry( collection=collection, @@ -316,9 +321,7 @@ async def put_many( created_at, ttl_seconds, expires_at = prepare_entry_timestamps(ttl=ttl) - managed_entries: list[ManagedEntry] = [ - ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at) for value in values - ] + managed_entries: list[ManagedEntry] = [ManagedEntry(value=value, created_at=created_at, expires_at=expires_at) for value in values] await self._put_managed_entries( collection=collection, diff --git a/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py b/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py index 0bd6d0b0..c395279e 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py @@ -6,7 +6,7 @@ from elastic_transport import ObjectApiResponse from elastic_transport import SerializationError as ElasticsearchSerializationError from key_value.shared.errors import DeserializationError, SerializationError -from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict +from key_value.shared.utils.managed_entry import ManagedEntry from key_value.shared.utils.sanitize import ( ALPHANUMERIC_CHARACTERS, LOWERCASE_ALPHABET, @@ -14,7 +14,7 @@ sanitize_string, ) from key_value.shared.utils.serialization import SerializationAdapter -from key_value.shared.utils.time_to_live import now_as_epoch, try_parse_datetime_str +from key_value.shared.utils.time_to_live import now_as_epoch from typing_extensions import override from key_value.aio.stores.base import ( @@ -85,103 +85,50 @@ ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "." -class ElasticsearchAdapter(SerializationAdapter): - """Adapter for Elasticsearch with support for native and string storage modes. 
+class ElasticsearchSerializationAdapter(SerializationAdapter): + """Adapter for Elasticsearch with support for native and string storage modes.""" - This adapter supports two storage modes: - - Native mode: Stores values as flattened dicts for efficient querying - - String mode: Stores values as JSON strings for backward compatibility - - Elasticsearch-specific features: - - Stores collection name in the document for multi-tenancy - - Uses ISO format for datetime fields - - Supports migration between storage modes - """ + _native_storage: bool def __init__(self, *, native_storage: bool = True) -> None: """Initialize the Elasticsearch adapter. Args: native_storage: If True (default), store values as flattened dicts. - If False, store values as JSON strings. + If False, store values as JSON strings. """ - self.native_storage = native_storage + super().__init__() - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: - """Convert a ManagedEntry to an Elasticsearch document. + self._native_storage = native_storage + self._date_format = "isoformat" + self._value_format = "dict" if native_storage else "string" - Args: - key: The key associated with this entry. - entry: The ManagedEntry to serialize. - collection: The collection name to store in the document. + @override + def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: + value = data.pop("value") - Returns: - An Elasticsearch document dict with collection, key, value, and metadata. - """ - document: dict[str, Any] = {"collection": collection or "", "key": key, "value": {}} + data["value"] = {} - # Store in appropriate field based on mode - if self.native_storage: - document["value"]["flattened"] = entry.value_as_dict + if self._native_storage: + data["value"]["flattened"] = value else: - document["value"]["string"] = entry.value_as_json - - if entry.created_at: - document["created_at"] = entry.created_at.isoformat() - if entry.expires_at: - document["expires_at"] = entry.expires_at.isoformat() - - return document + data["value"]["string"] = value - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert an Elasticsearch document back to a ManagedEntry. + return data - This method supports both native (flattened) and string storage modes, - trying the flattened field first and falling back to the string field. - This allows for seamless migration between storage modes. - - Args: - data: The Elasticsearch document to deserialize. - - Returns: - A ManagedEntry reconstructed from the document. + @override + def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: + value = data.pop("value") - Raises: - DeserializationError: If data is not a dict or is malformed. 
- """ - if not isinstance(data, dict): - msg = "Expected Elasticsearch document to be a dict" - raise DeserializationError(msg) - - document = data - value: dict[str, Any] = {} - - raw_value = document.get("value") - - # Try flattened field first, fall back to string field - if not raw_value or not isinstance(raw_value, dict): - msg = "Value field not found or invalid type" - raise DeserializationError(msg) - - if value_flattened := raw_value.get("flattened"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - value = verify_dict(obj=value_flattened) - elif value_str := raw_value.get("string"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - if not isinstance(value_str, str): - msg = "Value in `value` field is not a string" - raise DeserializationError(msg) - value = load_from_json(value_str) + if flattened := value.get("flattened"): + data["value"] = flattened + elif string := value.get("string"): + data["value"] = string else: - msg = "Value field not found or invalid type" - raise DeserializationError(msg) - - created_at: datetime | None = try_parse_datetime_str(value=document.get("created_at")) - expires_at: datetime | None = try_parse_datetime_str(value=document.get("expires_at")) + msg = "Value field not found in Elasticsearch document" + raise DeserializationError(message=msg) - return ManagedEntry( - value=value, - created_at=created_at, - expires_at=expires_at, - ) + return data class ElasticsearchStore( @@ -262,7 +209,7 @@ def __init__( self._index_prefix = index_prefix self._native_storage = native_storage self._is_serverless = False - self._adapter = ElasticsearchAdapter(native_storage=native_storage) + self._adapter = ElasticsearchSerializationAdapter(native_storage=native_storage) super().__init__(default_collection=default_collection) @@ -315,7 +262,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry return None try: - return self._adapter.from_storage(data=source) + return self._adapter.load_dict(data=source) except DeserializationError: return None @@ -348,7 +295,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> continue try: - entries_by_id[doc_id] = self._adapter.from_storage(data=source) + entries_by_id[doc_id] = self._adapter.load_dict(data=source) except DeserializationError as e: logger.error( "Failed to deserialize Elasticsearch document in batch operation", @@ -379,10 +326,7 @@ async def _put_managed_entry( index_name: str = self._sanitize_index_name(collection=collection) document_id: str = self._sanitize_document_id(key=key) - document: dict[str, Any] = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) - if not isinstance(document, dict): - msg = "Elasticsearch adapter must return dict" - raise TypeError(msg) + document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry) try: _ = await self._client.index( @@ -420,12 +364,10 @@ async def _put_managed_entries( index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id) - document: dict[str, Any] = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) - if not isinstance(document, dict): - msg = "Elasticsearch adapter must return dict" - raise TypeError(msg) + document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry) operations.extend([index_action, document]) + try: _ = await self._client.bulk(operations=operations, refresh=self._should_refresh_on_put) # pyright: 
ignore[reportUnknownMemberType] except ElasticsearchSerializationError as e: diff --git a/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py b/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py index 6e80f80b..34d6e2e1 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py @@ -1,9 +1,9 @@ from collections.abc import Sequence -from datetime import datetime +from datetime import datetime, timezone from typing import Any, overload from key_value.shared.errors import DeserializationError -from key_value.shared.utils.managed_entry import ManagedEntry, verify_dict +from key_value.shared.utils.managed_entry import ManagedEntry from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, sanitize_string from key_value.shared.utils.serialization import SerializationAdapter from typing_extensions import Self, override @@ -11,7 +11,7 @@ from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyCollectionStore, BaseEnumerateCollectionsStore, BaseStore try: - from pymongo import AsyncMongoClient + from pymongo import AsyncMongoClient, UpdateOne from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.database import AsyncDatabase from pymongo.results import DeleteResult # noqa: TC002 @@ -35,117 +35,56 @@ COLLECTION_ALLOWED_CHARACTERS = ALPHANUMERIC_CHARACTERS + "_" -class MongoDBAdapter(SerializationAdapter): - """MongoDB-specific serialization adapter with native BSON datetime support. +class MongoDBSerializationAdapter(SerializationAdapter): + """Adapter for MongoDB with support for native and string storage modes.""" - This adapter is optimized for MongoDB, storing: - - Native BSON datetime types for TTL indexing (created_at, expires_at) - - Values in value.object (native BSON) or value.string (JSON) fields - - Support for migration between native and string storage modes - - The native storage mode is recommended for new deployments as it allows - efficient querying of value fields, while string mode provides backward - compatibility with older data. - """ + _native_storage: bool def __init__(self, *, native_storage: bool = True) -> None: - """Initialize the MongoDB adapter. - - Args: - native_storage: If True (default), store value as native BSON dict in value.object field. - If False, store as JSON string in value.string field for backward compatibility. - """ - self.native_storage = native_storage + """Initialize the MongoDB adapter.""" + super().__init__() - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: # noqa: ARG002 - """Convert a ManagedEntry to a MongoDB document. + self._native_storage = native_storage + self._date_format = "datetime" + self._value_format = "dict" if native_storage else "string" - Args: - key: The key associated with this entry. - entry: The ManagedEntry to serialize. - collection: The collection (unused, for interface compatibility). - - Returns: - A MongoDB document with key, value, and BSON datetime metadata. 
- """ - document: dict[str, Any] = {"key": key, "value": {}} + @override + def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: + value = data.pop("value") - # We convert to JSON even if we don't need to, this ensures that the value we were provided - # can be serialized to JSON which helps ensure compatibility across stores - json_str = entry.value_as_json + data["value"] = {} - # Store in appropriate field based on mode - if self.native_storage: - document["value"]["object"] = entry.value_as_dict + if self._native_storage: + data["value"]["object"] = value else: - document["value"]["string"] = json_str - - # Add metadata fields as BSON datetimes for TTL indexing - if entry.created_at: - document["created_at"] = entry.created_at - if entry.expires_at: - document["expires_at"] = entry.expires_at - - return document - - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert a MongoDB document back to a ManagedEntry. - - This method supports both native (object) and legacy (string) storage modes, - and properly handles BSON datetime types for metadata. - - Args: - data: The MongoDB document to deserialize. + data["value"]["string"] = value - Returns: - A ManagedEntry reconstructed from the document. - - Raises: - DeserializationError: If data is not a dict or is malformed. - """ - if not isinstance(data, dict): - msg = "Expected MongoDB document to be a dict" - raise DeserializationError(msg) - - document = data + return data - if not (value_field := document.get("value")): - msg = "Value field not found" - raise DeserializationError(msg) - - if not isinstance(value_field, dict): - msg = "Expected `value` field to be an object" - raise DeserializationError(msg) - - value_holder: dict[str, Any] = verify_dict(obj=value_field) - - entry_data: dict[str, Any] = {} + @override + def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: + value = data.pop("value") - # Mongo stores datetimes without timezones as UTC so we mark them as UTC - # Import timezone here to avoid circular import - from key_value.shared.utils.time_to_live import timezone + if value_object := value.get("object"): + data["value"] = value_object + elif value_string := value.get("string"): + data["value"] = value_string + else: + msg = "Value field not found in MongoDB document" + raise DeserializationError(message=msg) - if created_at_datetime := document.get("created_at"): - if not isinstance(created_at_datetime, datetime): + if date_created := data.get("created_at"): + if not isinstance(date_created, datetime): msg = "Expected `created_at` field to be a datetime" - raise DeserializationError(msg) - entry_data["created_at"] = created_at_datetime.replace(tzinfo=timezone.utc) - - if expires_at_datetime := document.get("expires_at"): - if not isinstance(expires_at_datetime, datetime): + raise DeserializationError(message=msg) + data["created_at"] = date_created.replace(tzinfo=timezone.utc) + if date_expires := data.get("expires_at"): + if not isinstance(date_expires, datetime): msg = "Expected `expires_at` field to be a datetime" - raise DeserializationError(msg) - entry_data["expires_at"] = expires_at_datetime.replace(tzinfo=timezone.utc) - - # Support both native (object) and legacy (string) storage - if value_object := value_holder.get("object"): - return ManagedEntry.from_dict(data={"value": value_object, **entry_data}) - - if value_string := value_holder.get("string"): - return ManagedEntry.from_dict(data={"value": value_string, **entry_data}, stringified_value=True) + raise 
DeserializationError(message=msg) + data["expires_at"] = date_expires.replace(tzinfo=timezone.utc) - msg = "Expected `value` field to be an object with `object` or `string` subfield" - raise DeserializationError(msg) + return data class MongoDBStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, BaseContextManagerStore, BaseStore): @@ -232,7 +171,7 @@ def __init__( self._db = self._client[db_name] self._collections_by_name = {} - self._adapter = MongoDBAdapter(native_storage=native_storage) + self._adapter = MongoDBSerializationAdapter(native_storage=native_storage) super().__init__(default_collection=default_collection) @@ -290,7 +229,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry if doc := await self._collections_by_name[sanitized_collection].find_one(filter={"key": key}): try: - return self._adapter.from_storage(data=doc) + return self._adapter.load_dict(data=doc) except DeserializationError: return None @@ -311,7 +250,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> async for doc in cursor: if key := doc.get("key"): try: - managed_entries_by_key[key] = self._adapter.from_storage(data=doc) + managed_entries_by_key[key] = self._adapter.load_dict(data=doc) except DeserializationError: managed_entries_by_key[key] = None @@ -325,17 +264,19 @@ async def _put_managed_entry( collection: str, managed_entry: ManagedEntry, ) -> None: - mongo_doc = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) - - if not isinstance(mongo_doc, dict): - msg = "MongoDB adapter must return dict" - raise TypeError(msg) + mongo_doc = self._adapter.dump_dict(entry=managed_entry) sanitized_collection = self._sanitize_collection_name(collection=collection) _ = await self._collections_by_name[sanitized_collection].update_one( filter={"key": key}, - update={"$set": mongo_doc}, + update={ + "$set": { + "collection": collection, + "key": key, + **mongo_doc, + } + }, upsert=True, ) @@ -355,21 +296,20 @@ async def _put_managed_entries( sanitized_collection = self._sanitize_collection_name(collection=collection) - # Use bulk_write for efficient batch operations - from pymongo import UpdateOne - operations: list[UpdateOne] = [] for key, managed_entry in zip(keys, managed_entries, strict=True): - mongo_doc = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) - - if not isinstance(mongo_doc, dict): - msg = "MongoDB adapter must return dict" - raise TypeError(msg) + mongo_doc = self._adapter.dump_dict(entry=managed_entry) operations.append( UpdateOne( filter={"key": key}, - update={"$set": mongo_doc}, + update={ + "$set": { + "collection": collection, + "key": key, + **mongo_doc, + } + }, upsert=True, ) ) diff --git a/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py b/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py index 52bfbf17..13e60655 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/redis/store.py @@ -7,7 +7,7 @@ from key_value.shared.type_checking.bear_spray import bear_spray from key_value.shared.utils.compound import compound_key, get_keys_from_compound_keys from key_value.shared.utils.managed_entry import ManagedEntry -from key_value.shared.utils.serialization import SerializationAdapter +from key_value.shared.utils.serialization import BasicSerializationAdapter, SerializationAdapter from typing_extensions import override from key_value.aio.stores.base import 
BaseContextManagerStore, BaseDestroyStore, BaseEnumerateKeysStore, BaseStore @@ -22,45 +22,6 @@ PAGE_LIMIT = 10000 -class FullJsonAdapter(SerializationAdapter): - """Adapter that serializes entries as complete JSON strings. - - This adapter is suitable for Redis which works with string values. - It serializes the entire ManagedEntry (including all metadata) to a JSON string. - """ - - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> str: # noqa: ARG002 - """Convert a ManagedEntry to a JSON string. - - Args: - key: The key (unused, for interface compatibility). - entry: The ManagedEntry to serialize. - collection: The collection (unused, for interface compatibility). - - Returns: - A JSON string containing the entry and all metadata. - """ - return entry.to_json(include_metadata=True, include_expiration=True, include_creation=True) - - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert a JSON string back to a ManagedEntry. - - Args: - data: The JSON string to deserialize. - - Returns: - A ManagedEntry reconstructed from the JSON. - - Raises: - DeserializationError: If data is not a string or cannot be parsed. - """ - if not isinstance(data, str): - msg = "Expected data to be a JSON string" - raise DeserializationError(msg) - - return ManagedEntry.from_json(json_str=data, includes_metadata=True) - - class RedisStore(BaseDestroyStore, BaseEnumerateKeysStore, BaseContextManagerStore, BaseStore): """Redis-based key-value store.""" @@ -122,7 +83,7 @@ def __init__( ) self._stable_api = True - self._adapter = FullJsonAdapter() + self._adapter = BasicSerializationAdapter(date_format="isoformat", value_format="dict") super().__init__(default_collection=default_collection) @@ -136,7 +97,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry return None try: - return self._adapter.from_storage(data=redis_response) + return self._adapter.load_json(json_str=redis_response) except DeserializationError: return None @@ -153,7 +114,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> for redis_response in redis_responses: if isinstance(redis_response, str): try: - entries.append(self._adapter.from_storage(data=redis_response)) + entries.append(self._adapter.load_json(json_str=redis_response)) except DeserializationError: entries.append(None) else: @@ -171,11 +132,7 @@ async def _put_managed_entry( ) -> None: combo_key: str = compound_key(collection=collection, key=key) - json_value = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) - - if not isinstance(json_value, str): - msg = "Redis adapter must return str" - raise TypeError(msg) + json_value: str = self._adapter.dump_json(entry=managed_entry) if managed_entry.ttl is not None: # Redis does not support <= 0 TTLs @@ -203,10 +160,7 @@ async def _put_managed_entries( # If there is no TTL, we can just do a simple mset mapping: dict[str, str] = {} for key, managed_entry in zip(keys, managed_entries, strict=True): - json_value = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) - if not isinstance(json_value, str): - msg = "Redis adapter must return str" - raise TypeError(msg) + json_value = self._adapter.dump_json(entry=managed_entry) mapping[compound_key(collection=collection, key=key)] = json_value await self._client.mset(mapping=mapping) @@ -221,11 +175,7 @@ async def _put_managed_entries( for key, managed_entry in zip(keys, managed_entries, strict=True): combo_key: str = 
compound_key(collection=collection, key=key) - json_value = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) - - if not isinstance(json_value, str): - msg = "Redis adapter must return str" - raise TypeError(msg) + json_value = self._adapter.dump_json(entry=managed_entry) pipeline.setex(name=combo_key, time=ttl_seconds, value=json_value) diff --git a/key-value/key-value-aio/tests/stores/elasticsearch/test_elasticsearch.py b/key-value/key-value-aio/tests/stores/elasticsearch/test_elasticsearch.py index 079d8a23..ec217a5d 100644 --- a/key-value/key-value-aio/tests/stores/elasticsearch/test_elasticsearch.py +++ b/key-value/key-value-aio/tests/stores/elasticsearch/test_elasticsearch.py @@ -11,7 +11,7 @@ from key_value.aio.stores.base import BaseStore from key_value.aio.stores.elasticsearch import ElasticsearchStore -from key_value.aio.stores.elasticsearch.store import managed_entry_to_document, source_to_managed_entry +from key_value.aio.stores.elasticsearch.store import ElasticsearchSerializationAdapter from tests.conftest import docker_container, should_skip_docker_tests from tests.stores.base import BaseStoreTests, ContextManagerStoreTestMixin @@ -55,19 +55,18 @@ def test_managed_entry_document_conversion(): expires_at = created_at + timedelta(seconds=10) managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(collection="test_collection", key="test_key", managed_entry=managed_entry) + adapter = ElasticsearchSerializationAdapter(native_storage=False) + document = adapter.dump_dict(entry=managed_entry) assert document == snapshot( { - "collection": "test_collection", - "key": "test_key", "value": {"string": '{"test": "test"}'}, "created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00", } ) - round_trip_managed_entry = source_to_managed_entry(source=document) + round_trip_managed_entry = adapter.load_dict(data=document) assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at @@ -80,19 +79,18 @@ def test_managed_entry_document_conversion_native_storage(): expires_at = created_at + timedelta(seconds=10) managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(collection="test_collection", key="test_key", managed_entry=managed_entry, native_storage=True) + adapter = ElasticsearchSerializationAdapter(native_storage=True) + document = adapter.dump_dict(entry=managed_entry) assert document == snapshot( { - "collection": "test_collection", - "key": "test_key", "value": {"flattened": {"test": "test"}}, "created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00", } ) - round_trip_managed_entry = source_to_managed_entry(source=document) + round_trip_managed_entry = adapter.load_dict(data=document) assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at @@ -170,8 +168,6 @@ async def test_value_stored_as_flattened_object(self, store: ElasticsearchStore, response = await es_client.get(index=index_name, id=doc_id) assert response.body["_source"] == snapshot( { - "collection": "test", - "key": "test_key", "value": {"flattened": {"name": "Alice", "age": 30}}, "created_at": IsStr(min_length=20, max_length=40), } @@ -182,8 +178,6 @@ async def test_value_stored_as_flattened_object(self, store: ElasticsearchStore, response = await 
es_client.get(index=index_name, id=doc_id) assert response.body["_source"] == snapshot( { - "collection": "test", - "key": "test_key", "value": {"flattened": {"name": "Bob", "age": 25}}, "created_at": IsStr(min_length=20, max_length=40), "expires_at": IsStr(min_length=20, max_length=40), @@ -231,8 +225,6 @@ async def test_value_stored_as_json_string(self, store: ElasticsearchStore, es_c response = await es_client.get(index=index_name, id=doc_id) assert response.body["_source"] == snapshot( { - "collection": "test", - "key": "test_key", "value": {"string": '{"age": 30, "name": "Alice"}'}, "created_at": IsStr(min_length=20, max_length=40), } @@ -243,8 +235,6 @@ async def test_value_stored_as_json_string(self, store: ElasticsearchStore, es_c response = await es_client.get(index=index_name, id=doc_id) assert response.body["_source"] == snapshot( { - "collection": "test", - "key": "test_key", "value": {"string": '{"age": 25, "name": "Bob"}'}, "created_at": IsStr(min_length=20, max_length=40), "expires_at": IsStr(min_length=20, max_length=40), diff --git a/key-value/key-value-aio/tests/stores/mongodb/test_mongodb.py b/key-value/key-value-aio/tests/stores/mongodb/test_mongodb.py index a638a152..6c3f74de 100644 --- a/key-value/key-value-aio/tests/stores/mongodb/test_mongodb.py +++ b/key-value/key-value-aio/tests/stores/mongodb/test_mongodb.py @@ -14,7 +14,7 @@ from key_value.aio.stores.base import BaseStore from key_value.aio.stores.mongodb import MongoDBStore -from key_value.aio.stores.mongodb.store import document_to_managed_entry, managed_entry_to_document +from key_value.aio.stores.mongodb.store import MongoDBSerializationAdapter from tests.conftest import docker_container, should_skip_docker_tests from tests.stores.base import BaseStoreTests, ContextManagerStoreTestMixin @@ -50,7 +50,9 @@ def test_managed_entry_document_conversion_native_mode(): expires_at = created_at + timedelta(seconds=10) managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(key="test", managed_entry=managed_entry, native_storage=True) + + adapter = MongoDBSerializationAdapter(native_storage=True) + document = adapter.dump_dict(entry=managed_entry) assert document == snapshot( { @@ -61,7 +63,7 @@ def test_managed_entry_document_conversion_native_mode(): } ) - round_trip_managed_entry = document_to_managed_entry(document=document) + round_trip_managed_entry = adapter.load_dict(data=document) assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at @@ -74,7 +76,8 @@ def test_managed_entry_document_conversion_legacy_mode(): expires_at = created_at + timedelta(seconds=10) managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(key="test", managed_entry=managed_entry, native_storage=False) + adapter = MongoDBSerializationAdapter(native_storage=False) + document = adapter.dump_dict(entry=managed_entry) assert document == snapshot( { @@ -85,7 +88,7 @@ def test_managed_entry_document_conversion_legacy_mode(): } ) - round_trip_managed_entry = document_to_managed_entry(document=document) + round_trip_managed_entry = adapter.load_dict(data=document) assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at diff --git a/key-value/key-value-aio/tests/stores/redis/test_redis.py b/key-value/key-value-aio/tests/stores/redis/test_redis.py index 
407fc799..c7e7c6bb 100644 --- a/key-value/key-value-aio/tests/stores/redis/test_redis.py +++ b/key-value/key-value-aio/tests/stores/redis/test_redis.py @@ -1,17 +1,16 @@ +import json from collections.abc import AsyncGenerator -from datetime import datetime, timedelta, timezone +from typing import Any import pytest -from dirty_equals import IsFloat +from dirty_equals import IsDatetime from inline_snapshot import snapshot from key_value.shared.stores.wait import async_wait_for_true -from key_value.shared.utils.managed_entry import ManagedEntry from redis.asyncio.client import Redis from typing_extensions import override from key_value.aio.stores.base import BaseStore from key_value.aio.stores.redis import RedisStore -from key_value.aio.stores.redis.store import json_to_managed_entry, managed_entry_to_json from tests.conftest import docker_container, should_skip_docker_tests from tests.stores.base import BaseStoreTests, ContextManagerStoreTestMixin @@ -28,25 +27,6 @@ ] -def test_managed_entry_document_conversion(): - created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) - expires_at = created_at + timedelta(seconds=10) - - managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_json(managed_entry=managed_entry) - - assert document == snapshot( - '{"created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00", "value": {"test": "test"}}' - ) - - round_trip_managed_entry = json_to_managed_entry(json_str=document) - - assert round_trip_managed_entry.value == managed_entry.value - assert round_trip_managed_entry.created_at == created_at - assert round_trip_managed_entry.ttl == IsFloat(lt=0) - assert round_trip_managed_entry.expires_at == expires_at - - async def ping_redis() -> bool: client: Redis = Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) try: @@ -59,6 +39,10 @@ class RedisFailedToStartError(Exception): pass +def get_client_from_store(store: RedisStore) -> Redis: + return store._client # pyright: ignore[reportPrivateUsage] + + @pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running") class TestRedisStore(ContextManagerStoreTestMixin, BaseStoreTests): @pytest.fixture(autouse=True, scope="session", params=REDIS_VERSIONS_TO_TEST) @@ -78,14 +62,18 @@ async def store(self, setup_redis: RedisStore) -> RedisStore: """Create a Redis store for testing.""" # Create the store with test database redis_store = RedisStore(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB) - _ = await redis_store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] + _ = await get_client_from_store(store=redis_store).flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] return redis_store + @pytest.fixture + def redis_client(self, store: RedisStore) -> Redis: + return get_client_from_store(store=store) + async def test_redis_url_connection(self): """Test Redis store creation with URL.""" redis_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}" store = RedisStore(url=redis_url) - _ = await store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] + _ = await get_client_from_store(store=store).flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] await store.put(collection="test", key="url_test", value={"test": "value"}) result = await store.get(collection="test", key="url_test") assert result == 
{"test": "value"} @@ -97,11 +85,62 @@ async def test_redis_client_connection(self): client = Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) store = RedisStore(client=client) - _ = await store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] + _ = await get_client_from_store(store=store).flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] await store.put(collection="test", key="client_test", value={"test": "value"}) result = await store.get(collection="test", key="client_test") assert result == {"test": "value"} + async def test_redis_document_format(self, store: RedisStore, redis_client: Redis): + """Test Redis store document format.""" + await store.put(collection="test", key="document_format_test_1", value={"test_1": "value_1"}) + await store.put(collection="test", key="document_format_test_2", value={"test_2": "value_2"}, ttl=10) + + raw_documents = await redis_client.mget(keys=["test::document_format_test_1", "test::document_format_test_2"]) + raw_documents_dicts: list[dict[str, Any]] = [json.loads(raw_document) for raw_document in raw_documents] + assert raw_documents_dicts == snapshot( + [ + { + "created_at": IsDatetime(iso_string=True), + "value": {"test_1": "value_1"}, + }, + { + "created_at": IsDatetime(iso_string=True), + "expires_at": IsDatetime(iso_string=True), + "value": {"test_2": "value_2"}, + }, + ] + ) + + await store.put_many( + collection="test", + keys=["document_format_test_3", "document_format_test_4"], + values=[{"test_3": "value_3"}, {"test_4": "value_4"}], + ttl=10, + ) + raw_documents = await redis_client.mget(keys=["test::document_format_test_3", "test::document_format_test_4"]) + raw_documents_dicts = [json.loads(raw_document) for raw_document in raw_documents] + assert raw_documents_dicts == snapshot( + [ + { + "created_at": IsDatetime(iso_string=True), + "expires_at": IsDatetime(iso_string=True), + "value": {"test_3": "value_3"}, + }, + { + "created_at": IsDatetime(iso_string=True), + "expires_at": IsDatetime(iso_string=True), + "value": {"test_4": "value_4"}, + }, + ] + ) + + await store.put(collection="test", key="document_format_test", value={"test": "value"}, ttl=10) + raw_document = await redis_client.get(name="test::document_format_test") + raw_document_dict = json.loads(raw_document) + assert raw_document_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "expires_at": IsDatetime(iso_string=True), "value": {"test": "value"}} + ) + @pytest.mark.skip(reason="Distributed Caches are unbounded") @override async def test_not_unbounded(self, store: BaseStore): ... 
diff --git a/key-value/key-value-shared/pyproject.toml b/key-value/key-value-shared/pyproject.toml index f4054e96..67d570fe 100644 --- a/key-value/key-value-shared/pyproject.toml +++ b/key-value/key-value-shared/pyproject.toml @@ -56,8 +56,3 @@ extend="../../pyproject.toml" [tool.pyright] extends = "../../pyproject.toml" - -executionEnvironments = [ - { root = "tests", reportPrivateUsage = false, extraPaths = ["src"]}, - { root = "src" } -] diff --git a/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py b/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py index 90cb5490..9128610d 100644 --- a/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py +++ b/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py @@ -7,7 +7,7 @@ from typing_extensions import Self from key_value.shared.errors import DeserializationError, SerializationError -from key_value.shared.utils.time_to_live import now, now_plus, prepare_ttl, try_parse_datetime_str +from key_value.shared.utils.time_to_live import now, prepare_ttl, seconds_to, try_parse_datetime_str @dataclass(kw_only=True) @@ -22,22 +22,20 @@ class ManagedEntry: value: Mapping[str, Any] created_at: datetime | None = field(default=None) - ttl: float | None = field(default=None) expires_at: datetime | None = field(default=None) - def __post_init__(self) -> None: - if self.ttl is not None and self.expires_at is None: - self.expires_at = now_plus(seconds=self.ttl) - - elif self.expires_at is not None and self.ttl is None: - self.recalculate_ttl() - @property def is_expired(self) -> bool: if self.expires_at is None: return False return self.expires_at <= now() + @property + def ttl(self) -> float | None: + if self.expires_at is None: + return None + return seconds_to(datetime=self.expires_at) + @property def value_as_json(self) -> str: """Return the value as a JSON string.""" @@ -45,11 +43,15 @@ def value_as_json(self) -> str: @property def value_as_dict(self) -> dict[str, Any]: - return dict(self.value) + return verify_dict(obj=self.value) + + @property + def created_at_isoformat(self) -> str | None: + return self.created_at.isoformat() if self.created_at else None - def recalculate_ttl(self) -> None: - if self.expires_at is not None and self.ttl is None: - self.ttl = (self.expires_at - now()).total_seconds() + @property + def expires_at_isoformat(self) -> str | None: + return self.expires_at.isoformat() if self.expires_at else None def to_dict( self, include_metadata: bool = True, include_expiration: bool = True, include_creation: bool = True, stringify_value: bool = False diff --git a/key-value/key-value-shared/src/key_value/shared/utils/serialization.py b/key-value/key-value-shared/src/key_value/shared/utils/serialization.py index 3edf55fa..bebdab77 100644 --- a/key-value/key-value-shared/src/key_value/shared/utils/serialization.py +++ b/key-value/key-value-shared/src/key_value/shared/utils/serialization.py @@ -6,9 +6,23 @@ """ from abc import ABC, abstractmethod -from typing import Any +from datetime import datetime +from typing import Any, Literal, TypeVar -from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.errors import DeserializationError +from key_value.shared.utils.managed_entry import ManagedEntry, dump_to_json, load_from_json, verify_dict +from key_value.shared.utils.time_to_live import try_parse_datetime_str + +T = TypeVar("T") + + +def key_must_be(dictionary: dict[str, Any], /, key: str, expected_type: type[T]) -> T | None: + if key not in 
dictionary: + return None + if not isinstance(dictionary[key], expected_type): + msg = f"{key} must be a {expected_type.__name__}" + raise TypeError(msg) + return dictionary[key] class SerializationAdapter(ABC): @@ -17,36 +31,103 @@ class SerializationAdapter(ABC): Adapters encapsulate the logic for converting between ManagedEntry objects and store-specific storage formats. This provides a consistent interface while allowing each store to optimize its serialization strategy. - - Store implementations should subclass this adapter and define their own - to_storage() and from_storage() methods within their store module. """ + _date_format: Literal["isoformat", "datetime"] | None = "isoformat" + _value_format: Literal["string", "dict"] | None = "dict" + + def __init__( + self, *, date_format: Literal["isoformat", "datetime"] | None = "isoformat", value_format: Literal["string", "dict"] | None = "dict" + ) -> None: + self._date_format = date_format + self._value_format = value_format + + def load_json(self, json_str: str) -> ManagedEntry: + """Convert a JSON string to a dictionary.""" + loaded_data: dict[str, Any] = load_from_json(json_str=json_str) + + return self.load_dict(data=loaded_data) + @abstractmethod - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any] | str: - """Convert a ManagedEntry to the store's storage format. + def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: + """Prepare data for loading. + + This method is used by subclasses to handle any required transformations before loading the data into a ManagedEntry.""" + + def load_dict(self, data: dict[str, Any]) -> ManagedEntry: + """Convert a dictionary to a ManagedEntry.""" + + data = self.prepare_load(data=data) + + managed_entry_proto: dict[str, Any] = {} - Args: - key: The key associated with this entry. - entry: The ManagedEntry to serialize. - collection: Optional collection name. + if self._date_format == "isoformat": + if created_at := key_must_be(data, key="created_at", expected_type=str): + managed_entry_proto["created_at"] = try_parse_datetime_str(value=created_at) + if expires_at := key_must_be(data, key="expires_at", expected_type=str): + managed_entry_proto["expires_at"] = try_parse_datetime_str(value=expires_at) - Returns: - The serialized representation (dict or str depending on store). - """ - ... + if self._date_format == "datetime": + if created_at := key_must_be(data, key="created_at", expected_type=datetime): + managed_entry_proto["created_at"] = created_at + if expires_at := key_must_be(data, key="expires_at", expected_type=datetime): + managed_entry_proto["expires_at"] = expires_at + + if not (value := data.get("value")): + msg = "Value field not found" + raise DeserializationError(message=msg) + + managed_entry_value: dict[str, Any] = {} + + if isinstance(value, str): + managed_entry_value = load_from_json(json_str=value) + elif isinstance(value, dict): + managed_entry_value = verify_dict(obj=value) + else: + msg = "Value field is not a string or dictionary" + raise DeserializationError(message=msg) + + return ManagedEntry( + value=managed_entry_value, + created_at=managed_entry_proto.get("created_at"), + expires_at=managed_entry_proto.get("expires_at"), + ) @abstractmethod - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert stored data back to a ManagedEntry. + def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: + """Prepare data for dumping to a dictionary. 
+ + This method is used by subclasses to handle any required transformations before dumping the data to a dictionary.""" + + def dump_dict(self, entry: ManagedEntry, exclude_none: bool = True) -> dict[str, Any]: + """Convert a ManagedEntry to a dictionary.""" + + data: dict[str, Any] = { + "value": entry.value_as_dict if self._value_format == "dict" else entry.value_as_json, + "created_at": entry.created_at_isoformat, + "expires_at": entry.expires_at_isoformat, + } + + if exclude_none: + data = {k: v for k, v in data.items() if v is not None} + + return self.prepare_dump(data=data) + + def dump_json(self, entry: ManagedEntry, exclude_none: bool = True) -> str: + """Convert a ManagedEntry to a JSON string.""" + return dump_to_json(obj=self.dump_dict(entry=entry, exclude_none=exclude_none)) + + +class BasicSerializationAdapter(SerializationAdapter): + """Basic serialization adapter that does not perform any transformations.""" - Args: - data: The stored representation to deserialize. + def __init__( + self, *, date_format: Literal["isoformat", "datetime"] | None = "isoformat", value_format: Literal["string", "dict"] | None = "dict" + ) -> None: + super().__init__(date_format=date_format, value_format=value_format) - Returns: - A ManagedEntry reconstructed from storage. + def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: + return data - Raises: - DeserializationError: If the data cannot be deserialized. - """ - ... + def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: + return data diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py index 1c02abda..4c73b699 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/base.py @@ -17,6 +17,7 @@ from key_value.shared.errors import StoreSetupError from key_value.shared.type_checking.bear_spray import bear_enforce from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.serialization import BasicSerializationAdapter, SerializationAdapter from key_value.shared.utils.time_to_live import prepare_entry_timestamps from typing_extensions import Self, override @@ -73,6 +74,8 @@ class BaseStore(KeyValueProtocol, ABC): _setup_collection_locks: defaultdict[str, Lock] _setup_collection_complete: defaultdict[str, bool] + _serialization_adapter: SerializationAdapter + _seed: FROZEN_SEED_DATA_TYPE default_collection: str @@ -97,6 +100,8 @@ def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYP self.default_collection = default_collection or DEFAULT_COLLECTION_NAME + self._serialization_adapter = BasicSerializationAdapter() + if not hasattr(self, "_stable_api"): self._stable_api = False @@ -272,9 +277,9 @@ def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = No collection = collection or self.default_collection self.setup_collection(collection=collection) - (created_at, ttl_seconds, expires_at) = prepare_entry_timestamps(ttl=ttl) + (created_at, _, expires_at) = prepare_entry_timestamps(ttl=ttl) - managed_entry: ManagedEntry = ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at) + managed_entry: ManagedEntry = ManagedEntry(value=value, created_at=created_at, expires_at=expires_at) self._put_managed_entry(collection=collection, key=key, managed_entry=managed_entry) @@ -293,9 +298,7 @@ def put_many( (created_at, ttl_seconds, expires_at) = 
prepare_entry_timestamps(ttl=ttl) - managed_entries: list[ManagedEntry] = [ - ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at) for value in values - ] + managed_entries: list[ManagedEntry] = [ManagedEntry(value=value, created_at=created_at, expires_at=expires_at) for value in values] self._put_managed_entries( collection=collection, keys=keys, managed_entries=managed_entries, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py index ee0c6a8b..6fb4c6ba 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py @@ -9,10 +9,10 @@ from elastic_transport import ObjectApiResponse from elastic_transport import SerializationError as ElasticsearchSerializationError from key_value.shared.errors import DeserializationError, SerializationError -from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict +from key_value.shared.utils.managed_entry import ManagedEntry from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, LOWERCASE_ALPHABET, NUMBERS, sanitize_string from key_value.shared.utils.serialization import SerializationAdapter -from key_value.shared.utils.time_to_live import now_as_epoch, try_parse_datetime_str +from key_value.shared.utils.time_to_live import now_as_epoch from typing_extensions import override from key_value.sync.code_gen.stores.base import ( @@ -65,99 +65,50 @@ ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "." -class ElasticsearchAdapter(SerializationAdapter): - """Adapter for Elasticsearch with support for native and string storage modes. +class ElasticsearchSerializationAdapter(SerializationAdapter): + """Adapter for Elasticsearch with support for native and string storage modes.""" - This adapter supports two storage modes: - - Native mode: Stores values as flattened dicts for efficient querying - - String mode: Stores values as JSON strings for backward compatibility - - Elasticsearch-specific features: - - Stores collection name in the document for multi-tenancy - - Uses ISO format for datetime fields - - Supports migration between storage modes - """ + _native_storage: bool def __init__(self, *, native_storage: bool = True) -> None: """Initialize the Elasticsearch adapter. Args: native_storage: If True (default), store values as flattened dicts. - If False, store values as JSON strings. + If False, store values as JSON strings. """ - self.native_storage = native_storage + super().__init__() - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: - """Convert a ManagedEntry to an Elasticsearch document. + self._native_storage = native_storage + self._date_format = "isoformat" + self._value_format = "dict" if native_storage else "string" - Args: - key: The key associated with this entry. - entry: The ManagedEntry to serialize. - collection: The collection name to store in the document. + @override + def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: + value = data.pop("value") - Returns: - An Elasticsearch document dict with collection, key, value, and metadata. 
- """ - document: dict[str, Any] = {"collection": collection or "", "key": key, "value": {}} + data["value"] = {} - # Store in appropriate field based on mode - if self.native_storage: - document["value"]["flattened"] = entry.value_as_dict + if self._native_storage: + data["value"]["flattened"] = value else: - document["value"]["string"] = entry.value_as_json - - if entry.created_at: - document["created_at"] = entry.created_at.isoformat() - if entry.expires_at: - document["expires_at"] = entry.expires_at.isoformat() - - return document + data["value"]["string"] = value - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert an Elasticsearch document back to a ManagedEntry. + return data - This method supports both native (flattened) and string storage modes, - trying the flattened field first and falling back to the string field. - This allows for seamless migration between storage modes. - - Args: - data: The Elasticsearch document to deserialize. - - Returns: - A ManagedEntry reconstructed from the document. + @override + def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: + value = data.pop("value") - Raises: - DeserializationError: If data is not a dict or is malformed. - """ - if not isinstance(data, dict): - msg = "Expected Elasticsearch document to be a dict" - raise DeserializationError(msg) - - document = data - value: dict[str, Any] = {} - - raw_value = document.get("value") - - # Try flattened field first, fall back to string field - if not raw_value or not isinstance(raw_value, dict): - msg = "Value field not found or invalid type" - raise DeserializationError(msg) - - if value_flattened := raw_value.get("flattened"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - value = verify_dict(obj=value_flattened) - elif value_str := raw_value.get("string"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - if not isinstance(value_str, str): - msg = "Value in `value` field is not a string" - raise DeserializationError(msg) - value = load_from_json(value_str) + if flattened := value.get("flattened"): + data["value"] = flattened + elif string := value.get("string"): + data["value"] = string else: - msg = "Value field not found or invalid type" - raise DeserializationError(msg) - - created_at: datetime | None = try_parse_datetime_str(value=document.get("created_at")) - expires_at: datetime | None = try_parse_datetime_str(value=document.get("expires_at")) + msg = "Value field not found in Elasticsearch document" + raise DeserializationError(message=msg) - return ManagedEntry(value=value, created_at=created_at, expires_at=expires_at) + return data class ElasticsearchStore( @@ -227,7 +178,7 @@ def __init__( self._index_prefix = index_prefix self._native_storage = native_storage self._is_serverless = False - self._adapter = ElasticsearchAdapter(native_storage=native_storage) + self._adapter = ElasticsearchSerializationAdapter(native_storage=native_storage) super().__init__(default_collection=default_collection) @@ -275,7 +226,7 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non return None try: - return self._adapter.from_storage(data=source) + return self._adapter.load_dict(data=source) except DeserializationError: return None @@ -308,7 +259,7 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ continue try: - entries_by_id[doc_id] = self._adapter.from_storage(data=source) + entries_by_id[doc_id] = self._adapter.load_dict(data=source) except 
DeserializationError as e: logger.error( "Failed to deserialize Elasticsearch document in batch operation", @@ -329,10 +280,7 @@ def _put_managed_entry(self, *, key: str, collection: str, managed_entry: Manage index_name: str = self._sanitize_index_name(collection=collection) document_id: str = self._sanitize_document_id(key=key) - document: dict[str, Any] = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) - if not isinstance(document, dict): - msg = "Elasticsearch adapter must return dict" - raise TypeError(msg) + document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry) try: _ = self._client.index(index=index_name, id=document_id, body=document, refresh=self._should_refresh_on_put) @@ -365,12 +313,10 @@ def _put_managed_entries( index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id) - document: dict[str, Any] = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) - if not isinstance(document, dict): - msg = "Elasticsearch adapter must return dict" - raise TypeError(msg) + document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry) operations.extend([index_action, document]) + try: _ = self._client.bulk(operations=operations, refresh=self._should_refresh_on_put) # pyright: ignore[reportUnknownMemberType] except ElasticsearchSerializationError as e: diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py index 5c1a1055..f996f15f 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py @@ -2,11 +2,11 @@ # from the original file 'store.py' # DO NOT CHANGE! Change the original file instead. from collections.abc import Sequence -from datetime import datetime +from datetime import datetime, timezone from typing import Any, overload from key_value.shared.errors import DeserializationError -from key_value.shared.utils.managed_entry import ManagedEntry, verify_dict +from key_value.shared.utils.managed_entry import ManagedEntry from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS, sanitize_string from key_value.shared.utils.serialization import SerializationAdapter from typing_extensions import Self, override @@ -19,7 +19,7 @@ ) try: - from pymongo import MongoClient + from pymongo import MongoClient, UpdateOne from pymongo.collection import Collection from pymongo.database import Database from pymongo.results import DeleteResult # noqa: TC002 @@ -42,117 +42,56 @@ COLLECTION_ALLOWED_CHARACTERS = ALPHANUMERIC_CHARACTERS + "_" -class MongoDBAdapter(SerializationAdapter): - """MongoDB-specific serialization adapter with native BSON datetime support. +class MongoDBSerializationAdapter(SerializationAdapter): + """Adapter for MongoDB with support for native and string storage modes.""" - This adapter is optimized for MongoDB, storing: - - Native BSON datetime types for TTL indexing (created_at, expires_at) - - Values in value.object (native BSON) or value.string (JSON) fields - - Support for migration between native and string storage modes - - The native storage mode is recommended for new deployments as it allows - efficient querying of value fields, while string mode provides backward - compatibility with older data. 
- """ + _native_storage: bool def __init__(self, *, native_storage: bool = True) -> None: - """Initialize the MongoDB adapter. - - Args: - native_storage: If True (default), store value as native BSON dict in value.object field. - If False, store as JSON string in value.string field for backward compatibility. - """ - self.native_storage = native_storage + """Initialize the MongoDB adapter.""" + super().__init__() - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> dict[str, Any]: - """Convert a ManagedEntry to a MongoDB document. + self._native_storage = native_storage + self._date_format = "datetime" + self._value_format = "dict" if native_storage else "string" - Args: - key: The key associated with this entry. - entry: The ManagedEntry to serialize. - collection: The collection (unused, for interface compatibility). - - Returns: - A MongoDB document with key, value, and BSON datetime metadata. - """ - document: dict[str, Any] = {"key": key, "value": {}} + @override + def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: + value = data.pop("value") - # We convert to JSON even if we don't need to, this ensures that the value we were provided - # can be serialized to JSON which helps ensure compatibility across stores - json_str = entry.value_as_json + data["value"] = {} - # Store in appropriate field based on mode - if self.native_storage: - document["value"]["object"] = entry.value_as_dict + if self._native_storage: + data["value"]["object"] = value else: - document["value"]["string"] = json_str - - # Add metadata fields as BSON datetimes for TTL indexing - if entry.created_at: - document["created_at"] = entry.created_at - if entry.expires_at: - document["expires_at"] = entry.expires_at - - return document - - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert a MongoDB document back to a ManagedEntry. - - This method supports both native (object) and legacy (string) storage modes, - and properly handles BSON datetime types for metadata. - - Args: - data: The MongoDB document to deserialize. + data["value"]["string"] = value - Returns: - A ManagedEntry reconstructed from the document. - - Raises: - DeserializationError: If data is not a dict or is malformed. 
- """ - if not isinstance(data, dict): - msg = "Expected MongoDB document to be a dict" - raise DeserializationError(msg) - - document = data + return data - if not (value_field := document.get("value")): - msg = "Value field not found" - raise DeserializationError(msg) - - if not isinstance(value_field, dict): - msg = "Expected `value` field to be an object" - raise DeserializationError(msg) - - value_holder: dict[str, Any] = verify_dict(obj=value_field) - - entry_data: dict[str, Any] = {} + @override + def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: + value = data.pop("value") - # Mongo stores datetimes without timezones as UTC so we mark them as UTC - # Import timezone here to avoid circular import - from key_value.shared.utils.time_to_live import timezone + if value_object := value.get("object"): + data["value"] = value_object + elif value_string := value.get("string"): + data["value"] = value_string + else: + msg = "Value field not found in MongoDB document" + raise DeserializationError(message=msg) - if created_at_datetime := document.get("created_at"): - if not isinstance(created_at_datetime, datetime): + if date_created := data.get("created_at"): + if not isinstance(date_created, datetime): msg = "Expected `created_at` field to be a datetime" - raise DeserializationError(msg) - entry_data["created_at"] = created_at_datetime.replace(tzinfo=timezone.utc) - - if expires_at_datetime := document.get("expires_at"): - if not isinstance(expires_at_datetime, datetime): + raise DeserializationError(message=msg) + data["created_at"] = date_created.replace(tzinfo=timezone.utc) + if date_expires := data.get("expires_at"): + if not isinstance(date_expires, datetime): msg = "Expected `expires_at` field to be a datetime" - raise DeserializationError(msg) - entry_data["expires_at"] = expires_at_datetime.replace(tzinfo=timezone.utc) - - # Support both native (object) and legacy (string) storage - if value_object := value_holder.get("object"): - return ManagedEntry.from_dict(data={"value": value_object, **entry_data}) - - if value_string := value_holder.get("string"): - return ManagedEntry.from_dict(data={"value": value_string, **entry_data}, stringified_value=True) + raise DeserializationError(message=msg) + data["expires_at"] = date_expires.replace(tzinfo=timezone.utc) - msg = "Expected `value` field to be an object with `object` or `string` subfield" - raise DeserializationError(msg) + return data class MongoDBStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, BaseContextManagerStore, BaseStore): @@ -239,7 +178,7 @@ def __init__( self._db = self._client[db_name] self._collections_by_name = {} - self._adapter = MongoDBAdapter(native_storage=native_storage) + self._adapter = MongoDBSerializationAdapter(native_storage=native_storage) super().__init__(default_collection=default_collection) @@ -297,7 +236,7 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non if doc := self._collections_by_name[sanitized_collection].find_one(filter={"key": key}): try: - return self._adapter.from_storage(data=doc) + return self._adapter.load_dict(data=doc) except DeserializationError: return None @@ -318,7 +257,7 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ for doc in cursor: if key := doc.get("key"): try: - managed_entries_by_key[key] = self._adapter.from_storage(data=doc) + managed_entries_by_key[key] = self._adapter.load_dict(data=doc) except DeserializationError: managed_entries_by_key[key] = None @@ -326,15 +265,13 
@@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ @override def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: - mongo_doc = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) - - if not isinstance(mongo_doc, dict): - msg = "MongoDB adapter must return dict" - raise TypeError(msg) + mongo_doc = self._adapter.dump_dict(entry=managed_entry) sanitized_collection = self._sanitize_collection_name(collection=collection) - _ = self._collections_by_name[sanitized_collection].update_one(filter={"key": key}, update={"$set": mongo_doc}, upsert=True) + _ = self._collections_by_name[sanitized_collection].update_one( + filter={"key": key}, update={"$set": {"collection": collection, "key": key, **mongo_doc}}, upsert=True + ) @override def _put_managed_entries( @@ -352,18 +289,13 @@ def _put_managed_entries( sanitized_collection = self._sanitize_collection_name(collection=collection) - # Use bulk_write for efficient batch operations - from pymongo import UpdateOne - operations: list[UpdateOne] = [] for key, managed_entry in zip(keys, managed_entries, strict=True): - mongo_doc = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) - - if not isinstance(mongo_doc, dict): - msg = "MongoDB adapter must return dict" - raise TypeError(msg) + mongo_doc = self._adapter.dump_dict(entry=managed_entry) - operations.append(UpdateOne(filter={"key": key}, update={"$set": mongo_doc}, upsert=True)) + operations.append( + UpdateOne(filter={"key": key}, update={"$set": {"collection": collection, "key": key, **mongo_doc}}, upsert=True) + ) _ = self._collections_by_name[sanitized_collection].bulk_write(operations) # pyright: ignore[reportUnknownMemberType] diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py index d2f8659a..b20b7154 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/redis/store.py @@ -10,7 +10,7 @@ from key_value.shared.type_checking.bear_spray import bear_spray from key_value.shared.utils.compound import compound_key, get_keys_from_compound_keys from key_value.shared.utils.managed_entry import ManagedEntry -from key_value.shared.utils.serialization import SerializationAdapter +from key_value.shared.utils.serialization import BasicSerializationAdapter, SerializationAdapter from typing_extensions import override from key_value.sync.code_gen.stores.base import BaseContextManagerStore, BaseDestroyStore, BaseEnumerateKeysStore, BaseStore @@ -25,45 +25,6 @@ PAGE_LIMIT = 10000 -class FullJsonAdapter(SerializationAdapter): - """Adapter that serializes entries as complete JSON strings. - - This adapter is suitable for Redis which works with string values. - It serializes the entire ManagedEntry (including all metadata) to a JSON string. - """ - - def to_storage(self, key: str, entry: ManagedEntry, collection: str | None = None) -> str: - """Convert a ManagedEntry to a JSON string. - - Args: - key: The key (unused, for interface compatibility). - entry: The ManagedEntry to serialize. - collection: The collection (unused, for interface compatibility). - - Returns: - A JSON string containing the entry and all metadata. 
- """ - return entry.to_json(include_metadata=True, include_expiration=True, include_creation=True) - - def from_storage(self, data: dict[str, Any] | str) -> ManagedEntry: - """Convert a JSON string back to a ManagedEntry. - - Args: - data: The JSON string to deserialize. - - Returns: - A ManagedEntry reconstructed from the JSON. - - Raises: - DeserializationError: If data is not a string or cannot be parsed. - """ - if not isinstance(data, str): - msg = "Expected data to be a JSON string" - raise DeserializationError(msg) - - return ManagedEntry.from_json(json_str=data, includes_metadata=True) - - class RedisStore(BaseDestroyStore, BaseEnumerateKeysStore, BaseContextManagerStore, BaseStore): """Redis-based key-value store.""" @@ -119,7 +80,7 @@ def __init__( self._client = Redis(host=host, port=port, db=db, password=password, decode_responses=True) self._stable_api = True - self._adapter = FullJsonAdapter() + self._adapter = BasicSerializationAdapter(date_format="isoformat", value_format="dict") super().__init__(default_collection=default_collection) @@ -133,7 +94,7 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non return None try: - return self._adapter.from_storage(data=redis_response) + return self._adapter.load_json(json_str=redis_response) except DeserializationError: return None @@ -150,7 +111,7 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ for redis_response in redis_responses: if isinstance(redis_response, str): try: - entries.append(self._adapter.from_storage(data=redis_response)) + entries.append(self._adapter.load_json(json_str=redis_response)) except DeserializationError: entries.append(None) else: @@ -162,11 +123,7 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: combo_key: str = compound_key(collection=collection, key=key) - json_value = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) - - if not isinstance(json_value, str): - msg = "Redis adapter must return str" - raise TypeError(msg) + json_value: str = self._adapter.dump_json(entry=managed_entry) if managed_entry.ttl is not None: # Redis does not support <= 0 TTLs @@ -194,10 +151,7 @@ def _put_managed_entries( # If there is no TTL, we can just do a simple mset mapping: dict[str, str] = {} for key, managed_entry in zip(keys, managed_entries, strict=True): - json_value = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) - if not isinstance(json_value, str): - msg = "Redis adapter must return str" - raise TypeError(msg) + json_value = self._adapter.dump_json(entry=managed_entry) mapping[compound_key(collection=collection, key=key)] = json_value self._client.mset(mapping=mapping) @@ -212,11 +166,7 @@ def _put_managed_entries( for key, managed_entry in zip(keys, managed_entries, strict=True): combo_key: str = compound_key(collection=collection, key=key) - json_value = self._adapter.to_storage(key=key, entry=managed_entry, collection=collection) - - if not isinstance(json_value, str): - msg = "Redis adapter must return str" - raise TypeError(msg) + json_value = self._adapter.dump_json(entry=managed_entry) pipeline.setex(name=combo_key, time=ttl_seconds, value=json_value) diff --git a/key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py b/key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py index 
69851a03..499b194a 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py +++ b/key-value/key-value-sync/tests/code_gen/stores/elasticsearch/test_elasticsearch.py @@ -14,7 +14,7 @@ from key_value.sync.code_gen.stores.base import BaseStore from key_value.sync.code_gen.stores.elasticsearch import ElasticsearchStore -from key_value.sync.code_gen.stores.elasticsearch.store import managed_entry_to_document, source_to_managed_entry +from key_value.sync.code_gen.stores.elasticsearch.store import ElasticsearchSerializationAdapter from tests.code_gen.conftest import docker_container, should_skip_docker_tests from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin @@ -56,19 +56,14 @@ def test_managed_entry_document_conversion(): expires_at = created_at + timedelta(seconds=10) managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(collection="test_collection", key="test_key", managed_entry=managed_entry) + adapter = ElasticsearchSerializationAdapter(native_storage=False) + document = adapter.dump_dict(entry=managed_entry) assert document == snapshot( - { - "collection": "test_collection", - "key": "test_key", - "value": {"string": '{"test": "test"}'}, - "created_at": "2025-01-01T00:00:00+00:00", - "expires_at": "2025-01-01T00:00:10+00:00", - } + {"value": {"string": '{"test": "test"}'}, "created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00"} ) - round_trip_managed_entry = source_to_managed_entry(source=document) + round_trip_managed_entry = adapter.load_dict(data=document) assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at @@ -81,19 +76,14 @@ def test_managed_entry_document_conversion_native_storage(): expires_at = created_at + timedelta(seconds=10) managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(collection="test_collection", key="test_key", managed_entry=managed_entry, native_storage=True) + adapter = ElasticsearchSerializationAdapter(native_storage=True) + document = adapter.dump_dict(entry=managed_entry) assert document == snapshot( - { - "collection": "test_collection", - "key": "test_key", - "value": {"flattened": {"test": "test"}}, - "created_at": "2025-01-01T00:00:00+00:00", - "expires_at": "2025-01-01T00:00:10+00:00", - } + {"value": {"flattened": {"test": "test"}}, "created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00"} ) - round_trip_managed_entry = source_to_managed_entry(source=document) + round_trip_managed_entry = adapter.load_dict(data=document) assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at @@ -170,12 +160,7 @@ def test_value_stored_as_flattened_object(self, store: ElasticsearchStore, es_cl response = es_client.get(index=index_name, id=doc_id) assert response.body["_source"] == snapshot( - { - "collection": "test", - "key": "test_key", - "value": {"flattened": {"name": "Alice", "age": 30}}, - "created_at": IsStr(min_length=20, max_length=40), - } + {"value": {"flattened": {"name": "Alice", "age": 30}}, "created_at": IsStr(min_length=20, max_length=40)} ) # Test with TTL @@ -183,8 +168,6 @@ def test_value_stored_as_flattened_object(self, store: ElasticsearchStore, es_cl response = es_client.get(index=index_name, id=doc_id) assert response.body["_source"] 
== snapshot( { - "collection": "test", - "key": "test_key", "value": {"flattened": {"name": "Bob", "age": 25}}, "created_at": IsStr(min_length=20, max_length=40), "expires_at": IsStr(min_length=20, max_length=40), @@ -223,12 +206,7 @@ def test_value_stored_as_json_string(self, store: ElasticsearchStore, es_client: response = es_client.get(index=index_name, id=doc_id) assert response.body["_source"] == snapshot( - { - "collection": "test", - "key": "test_key", - "value": {"string": '{"age": 30, "name": "Alice"}'}, - "created_at": IsStr(min_length=20, max_length=40), - } + {"value": {"string": '{"age": 30, "name": "Alice"}'}, "created_at": IsStr(min_length=20, max_length=40)} ) # Test with TTL @@ -236,8 +214,6 @@ def test_value_stored_as_json_string(self, store: ElasticsearchStore, es_client: response = es_client.get(index=index_name, id=doc_id) assert response.body["_source"] == snapshot( { - "collection": "test", - "key": "test_key", "value": {"string": '{"age": 25, "name": "Bob"}'}, "created_at": IsStr(min_length=20, max_length=40), "expires_at": IsStr(min_length=20, max_length=40), diff --git a/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py b/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py index a9a3a9a3..6a90c054 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py +++ b/key-value/key-value-sync/tests/code_gen/stores/mongodb/test_mongodb.py @@ -17,7 +17,7 @@ from key_value.sync.code_gen.stores.base import BaseStore from key_value.sync.code_gen.stores.mongodb import MongoDBStore -from key_value.sync.code_gen.stores.mongodb.store import document_to_managed_entry, managed_entry_to_document +from key_value.sync.code_gen.stores.mongodb.store import MongoDBSerializationAdapter from tests.code_gen.conftest import docker_container, should_skip_docker_tests from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin @@ -51,7 +51,9 @@ def test_managed_entry_document_conversion_native_mode(): expires_at = created_at + timedelta(seconds=10) managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(key="test", managed_entry=managed_entry, native_storage=True) + + adapter = MongoDBSerializationAdapter(native_storage=True) + document = adapter.dump_dict(entry=managed_entry) assert document == snapshot( { @@ -62,7 +64,7 @@ def test_managed_entry_document_conversion_native_mode(): } ) - round_trip_managed_entry = document_to_managed_entry(document=document) + round_trip_managed_entry = adapter.load_dict(data=document) assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at @@ -75,7 +77,8 @@ def test_managed_entry_document_conversion_legacy_mode(): expires_at = created_at + timedelta(seconds=10) managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_document(key="test", managed_entry=managed_entry, native_storage=False) + adapter = MongoDBSerializationAdapter(native_storage=False) + document = adapter.dump_dict(entry=managed_entry) assert document == snapshot( { @@ -86,7 +89,7 @@ def test_managed_entry_document_conversion_legacy_mode(): } ) - round_trip_managed_entry = document_to_managed_entry(document=document) + round_trip_managed_entry = adapter.load_dict(data=document) assert round_trip_managed_entry.value == managed_entry.value assert round_trip_managed_entry.created_at == created_at 
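The two conversion tests above boil down to the adapter's dump path: MongoDBSerializationAdapter nests the value under value.object in native mode and under value.string in legacy mode, while the timestamps come from the shared dump_dict. A minimal sketch of just the value shaping, using the sync import path from this patch (the sample value and timestamps are arbitrary):

from datetime import datetime, timedelta, timezone

from key_value.shared.utils.managed_entry import ManagedEntry
from key_value.sync.code_gen.stores.mongodb.store import MongoDBSerializationAdapter

created_at = datetime(2025, 1, 1, tzinfo=timezone.utc)
entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=created_at + timedelta(seconds=10))

# Native mode keeps the value as a dict under value.object ...
native_doc = MongoDBSerializationAdapter(native_storage=True).dump_dict(entry=entry)
# native_doc["value"] == {"object": {"test": "test"}}

# ... while legacy mode stores it as a JSON string under value.string.
legacy_doc = MongoDBSerializationAdapter(native_storage=False).dump_dict(entry=entry)
# legacy_doc["value"] == {"string": '{"test": "test"}'}
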
diff --git a/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py b/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py index 4620fba8..1bc18625 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py +++ b/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py @@ -1,20 +1,19 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'test_redis.py' # DO NOT CHANGE! Change the original file instead. +import json from collections.abc import Generator -from datetime import datetime, timedelta, timezone +from typing import Any import pytest -from dirty_equals import IsFloat +from dirty_equals import IsDatetime from inline_snapshot import snapshot from key_value.shared.stores.wait import wait_for_true -from key_value.shared.utils.managed_entry import ManagedEntry from redis.client import Redis from typing_extensions import override from key_value.sync.code_gen.stores.base import BaseStore from key_value.sync.code_gen.stores.redis import RedisStore -from key_value.sync.code_gen.stores.redis.store import json_to_managed_entry, managed_entry_to_json from tests.code_gen.conftest import docker_container, should_skip_docker_tests from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin @@ -28,25 +27,6 @@ REDIS_VERSIONS_TO_TEST = ["4.0.0", "7.0.0"] -def test_managed_entry_document_conversion(): - created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc) - expires_at = created_at + timedelta(seconds=10) - - managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at) - document = managed_entry_to_json(managed_entry=managed_entry) - - assert document == snapshot( - '{"created_at": "2025-01-01T00:00:00+00:00", "expires_at": "2025-01-01T00:00:10+00:00", "value": {"test": "test"}}' - ) - - round_trip_managed_entry = json_to_managed_entry(json_str=document) - - assert round_trip_managed_entry.value == managed_entry.value - assert round_trip_managed_entry.created_at == created_at - assert round_trip_managed_entry.ttl == IsFloat(lt=0) - assert round_trip_managed_entry.expires_at == expires_at - - def ping_redis() -> bool: client: Redis = Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) try: @@ -59,6 +39,10 @@ class RedisFailedToStartError(Exception): pass +def get_client_from_store(store: RedisStore) -> Redis: + return store._client # pyright: ignore[reportPrivateUsage] + + @pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not running") class TestRedisStore(ContextManagerStoreTestMixin, BaseStoreTests): @pytest.fixture(autouse=True, scope="session", params=REDIS_VERSIONS_TO_TEST) @@ -78,14 +62,18 @@ def store(self, setup_redis: RedisStore) -> RedisStore: """Create a Redis store for testing.""" # Create the store with test database redis_store = RedisStore(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB) - _ = redis_store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] + _ = get_client_from_store(store=redis_store).flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] return redis_store + @pytest.fixture + def redis_client(self, store: RedisStore) -> Redis: + return get_client_from_store(store=store) + def test_redis_url_connection(self): """Test Redis store creation with URL.""" redis_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}" store = RedisStore(url=redis_url) - _ = 
store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] + _ = get_client_from_store(store=store).flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] store.put(collection="test", key="url_test", value={"test": "value"}) result = store.get(collection="test", key="url_test") assert result == {"test": "value"} @@ -97,11 +85,47 @@ def test_redis_client_connection(self): client = Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) store = RedisStore(client=client) - _ = store._client.flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] + _ = get_client_from_store(store=store).flushdb() # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAny] store.put(collection="test", key="client_test", value={"test": "value"}) result = store.get(collection="test", key="client_test") assert result == {"test": "value"} + def test_redis_document_format(self, store: RedisStore, redis_client: Redis): + """Test Redis store document format.""" + store.put(collection="test", key="document_format_test_1", value={"test_1": "value_1"}) + store.put(collection="test", key="document_format_test_2", value={"test_2": "value_2"}, ttl=10) + + raw_documents = redis_client.mget(keys=["test::document_format_test_1", "test::document_format_test_2"]) + raw_documents_dicts: list[dict[str, Any]] = [json.loads(raw_document) for raw_document in raw_documents] + assert raw_documents_dicts == snapshot( + [ + {"created_at": IsDatetime(iso_string=True), "value": {"test_1": "value_1"}}, + {"created_at": IsDatetime(iso_string=True), "expires_at": IsDatetime(iso_string=True), "value": {"test_2": "value_2"}}, + ] + ) + + store.put_many( + collection="test", + keys=["document_format_test_3", "document_format_test_4"], + values=[{"test_3": "value_3"}, {"test_4": "value_4"}], + ttl=10, + ) + raw_documents = redis_client.mget(keys=["test::document_format_test_3", "test::document_format_test_4"]) + raw_documents_dicts = [json.loads(raw_document) for raw_document in raw_documents] + assert raw_documents_dicts == snapshot( + [ + {"created_at": IsDatetime(iso_string=True), "expires_at": IsDatetime(iso_string=True), "value": {"test_3": "value_3"}}, + {"created_at": IsDatetime(iso_string=True), "expires_at": IsDatetime(iso_string=True), "value": {"test_4": "value_4"}}, + ] + ) + + store.put(collection="test", key="document_format_test", value={"test": "value"}, ttl=10) + raw_document = redis_client.get(name="test::document_format_test") + raw_document_dict = json.loads(raw_document) + assert raw_document_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "expires_at": IsDatetime(iso_string=True), "value": {"test": "value"}} + ) + @pytest.mark.skip(reason="Distributed Caches are unbounded") @override def test_not_unbounded(self, store: BaseStore): ... 
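Both Redis stores now delegate serialization to BasicSerializationAdapter configured with ISO-8601 timestamps and a dict-shaped value, so the round trip below mirrors what _put_managed_entry and _get_managed_entry do with each stored Redis string. This is a minimal sketch; the sample entry is arbitrary.

from datetime import datetime, timedelta, timezone

from key_value.shared.utils.managed_entry import ManagedEntry
from key_value.shared.utils.serialization import BasicSerializationAdapter

created_at = datetime(2025, 1, 1, tzinfo=timezone.utc)
entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=created_at + timedelta(seconds=10))

adapter = BasicSerializationAdapter(date_format="isoformat", value_format="dict")

# dump_json produces the JSON string that gets written to Redis ...
json_value = adapter.dump_json(entry=entry)

# ... and load_json rebuilds an equivalent ManagedEntry from that string.
round_trip = adapter.load_json(json_str=json_value)
assert round_trip.value == entry.value
assert round_trip.created_at == entry.created_at
assert round_trip.expires_at == entry.expires_at
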
From 17743fe986260a5b2c3861263ce1ddfc222b30a0 Mon Sep 17 00:00:00 2001 From: William Easton Date: Wed, 29 Oct 2025 21:01:01 -0500 Subject: [PATCH 05/11] Updates for tests --- .../src/key_value/shared/utils/managed_entry.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py b/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py index 9128610d..0d100651 100644 --- a/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py +++ b/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py @@ -7,7 +7,7 @@ from typing_extensions import Self from key_value.shared.errors import DeserializationError, SerializationError -from key_value.shared.utils.time_to_live import now, prepare_ttl, seconds_to, try_parse_datetime_str +from key_value.shared.utils.time_to_live import now, seconds_to, try_parse_datetime_str @dataclass(kw_only=True) @@ -82,7 +82,11 @@ def to_json( @classmethod def from_dict( # noqa: PLR0912 - cls, data: dict[str, Any], includes_metadata: bool = True, ttl: SupportsFloat | None = None, stringified_value: bool = False + cls, + data: dict[str, Any], + includes_metadata: bool = True, + ttl: SupportsFloat | None = None, + stringified_value: bool = False, ) -> Self: if not includes_metadata: return cls( @@ -127,12 +131,9 @@ def from_dict( # noqa: PLR0912 raise DeserializationError(msg) value = verify_dict(obj=raw_value) - ttl_seconds: float | None = prepare_ttl(t=ttl) - return cls( created_at=created_at, expires_at=expires_at, - ttl=ttl_seconds, value=value, ) From 0dbf7d61892df05e73c8f4d779e855b41cbcd15f Mon Sep 17 00:00:00 2001 From: William Easton Date: Wed, 29 Oct 2025 21:11:42 -0500 Subject: [PATCH 06/11] More clean-up --- .../src/key_value/aio/stores/disk/multi_store.py | 12 +++++++----- .../src/key_value/aio/stores/disk/store.py | 9 +++++---- .../src/key_value/shared/utils/managed_entry.py | 15 +++++++++++---- .../sync/code_gen/stores/disk/multi_store.py | 12 +++++++----- .../key_value/sync/code_gen/stores/disk/store.py | 9 +++++---- 5 files changed, 35 insertions(+), 22 deletions(-) diff --git a/key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py b/key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py index 5ff14b9d..5ef4645a 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py @@ -1,10 +1,9 @@ -import time from collections.abc import Callable from pathlib import Path from typing import overload from key_value.shared.utils.compound import compound_key -from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.managed_entry import ManagedEntry, datetime from typing_extensions import override from key_value.aio.stores.base import BaseContextManagerStore, BaseStore @@ -118,9 +117,10 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry if not isinstance(managed_entry_str, str): return None - ttl = (expire_epoch - time.time()) if expire_epoch else None + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str) - managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str, ttl=ttl) + if expire_epoch: + managed_entry.expires_at = datetime.fromtimestamp(expire_epoch) return managed_entry @@ -134,7 +134,9 @@ async def _put_managed_entry( ) -> None: combo_key: str = 
compound_key(collection=collection, key=key) - _ = self._cache[collection].set(key=combo_key, value=managed_entry.to_json(include_expiration=False), expire=managed_entry.ttl) + _ = self._cache[collection].set( + key=combo_key, value=self._serialization_adapter.dump_json(entry=managed_entry, exclude_none=False), expire=managed_entry.ttl + ) @override async def _delete_managed_entry(self, *, key: str, collection: str) -> bool: diff --git a/key-value/key-value-aio/src/key_value/aio/stores/disk/store.py b/key-value/key-value-aio/src/key_value/aio/stores/disk/store.py index 8ae36cf6..f9a00f0b 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/disk/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/disk/store.py @@ -1,4 +1,4 @@ -import time +from datetime import datetime, timezone from pathlib import Path from typing import overload @@ -90,9 +90,10 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry if not isinstance(managed_entry_str, str): return None - ttl = (expire_epoch - time.time()) if expire_epoch else None + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str) - managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str, ttl=ttl) + if expire_epoch: + managed_entry.expires_at = datetime.fromtimestamp(expire_epoch, tz=timezone.utc) return managed_entry @@ -106,7 +107,7 @@ async def _put_managed_entry( ) -> None: combo_key: str = compound_key(collection=collection, key=key) - _ = self._cache.set(key=combo_key, value=managed_entry.to_json(include_expiration=False), expire=managed_entry.ttl) + _ = self._cache.set(key=combo_key, value=self._serialization_adapter.dump_json(entry=managed_entry), expire=managed_entry.ttl) @override async def _delete_managed_entry(self, *, key: str, collection: str) -> bool: diff --git a/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py b/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py index 0d100651..fbf7ee28 100644 --- a/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py +++ b/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py @@ -7,7 +7,7 @@ from typing_extensions import Self from key_value.shared.errors import DeserializationError, SerializationError -from key_value.shared.utils.time_to_live import now, seconds_to, try_parse_datetime_str +from key_value.shared.utils.time_to_live import now, now_plus, seconds_to, try_parse_datetime_str @dataclass(kw_only=True) @@ -53,6 +53,14 @@ def created_at_isoformat(self) -> str | None: def expires_at_isoformat(self) -> str | None: return self.expires_at.isoformat() if self.expires_at else None + @classmethod + def from_ttl(cls, *, value: Mapping[str, Any], created_at: datetime | None = None, ttl: SupportsFloat) -> Self: + return cls( + value=value, + created_at=created_at, + expires_at=(now_plus(seconds=ttl) if created_at else None), + ) + def to_dict( self, include_metadata: bool = True, include_expiration: bool = True, include_creation: bool = True, stringify_value: bool = False ) -> dict[str, Any]: @@ -85,7 +93,6 @@ def from_dict( # noqa: PLR0912 cls, data: dict[str, Any], includes_metadata: bool = True, - ttl: SupportsFloat | None = None, stringified_value: bool = False, ) -> Self: if not includes_metadata: @@ -138,10 +145,10 @@ def from_dict( # noqa: PLR0912 ) @classmethod - def from_json(cls, json_str: str, includes_metadata: bool = True, ttl: SupportsFloat | None = None) -> Self: + def from_json(cls, json_str: str, 
includes_metadata: bool = True) -> Self: data: dict[str, Any] = load_from_json(json_str=json_str) - return cls.from_dict(data=data, includes_metadata=includes_metadata, ttl=ttl) + return cls.from_dict(data=data, includes_metadata=includes_metadata) def dump_to_json(obj: dict[str, Any]) -> str: diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/multi_store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/multi_store.py index 2c9fafaa..3245a7c0 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/multi_store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/multi_store.py @@ -1,13 +1,12 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'multi_store.py' # DO NOT CHANGE! Change the original file instead. -import time from collections.abc import Callable from pathlib import Path from typing import overload from key_value.shared.utils.compound import compound_key -from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.managed_entry import ManagedEntry, datetime from typing_extensions import override from key_value.sync.code_gen.stores.base import BaseContextManagerStore, BaseStore @@ -121,9 +120,10 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non if not isinstance(managed_entry_str, str): return None - ttl = expire_epoch - time.time() if expire_epoch else None + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str) - managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str, ttl=ttl) + if expire_epoch: + managed_entry.expires_at = datetime.fromtimestamp(expire_epoch) return managed_entry @@ -131,7 +131,9 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: combo_key: str = compound_key(collection=collection, key=key) - _ = self._cache[collection].set(key=combo_key, value=managed_entry.to_json(include_expiration=False), expire=managed_entry.ttl) + _ = self._cache[collection].set( + key=combo_key, value=self._serialization_adapter.dump_json(entry=managed_entry, exclude_none=False), expire=managed_entry.ttl + ) @override def _delete_managed_entry(self, *, key: str, collection: str) -> bool: diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/store.py index d75b32c9..ceabfa5d 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/store.py @@ -1,7 +1,7 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'store.py' # DO NOT CHANGE! Change the original file instead. 
-import time +from datetime import datetime, timezone from pathlib import Path from typing import overload @@ -93,9 +93,10 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non if not isinstance(managed_entry_str, str): return None - ttl = expire_epoch - time.time() if expire_epoch else None + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str) - managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str, ttl=ttl) + if expire_epoch: + managed_entry.expires_at = datetime.fromtimestamp(expire_epoch, tz=timezone.utc) return managed_entry @@ -103,7 +104,7 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: combo_key: str = compound_key(collection=collection, key=key) - _ = self._cache.set(key=combo_key, value=managed_entry.to_json(include_expiration=False), expire=managed_entry.ttl) + _ = self._cache.set(key=combo_key, value=self._serialization_adapter.dump_json(entry=managed_entry), expire=managed_entry.ttl) @override def _delete_managed_entry(self, *, key: str, collection: str) -> bool: From 8c4b6d3283d17605bd218792dca99d4665cfb66b Mon Sep 17 00:00:00 2001 From: William Easton Date: Thu, 30 Oct 2025 15:47:34 -0500 Subject: [PATCH 07/11] PR Cleanup --- .../src/key_value/aio/stores/disk/multi_store.py | 3 ++- .../src/key_value/aio/stores/elasticsearch/store.py | 8 ++++---- .../src/key_value/aio/stores/mongodb/store.py | 8 ++++---- .../src/key_value/shared/utils/managed_entry.py | 8 ++++---- .../key_value/sync/code_gen/stores/disk/multi_store.py | 3 ++- .../key_value/sync/code_gen/stores/elasticsearch/store.py | 8 ++++---- .../src/key_value/sync/code_gen/stores/mongodb/store.py | 8 ++++---- 7 files changed, 24 insertions(+), 22 deletions(-) diff --git a/key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py b/key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py index 5ef4645a..4adb0670 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from datetime import timezone from pathlib import Path from typing import overload @@ -120,7 +121,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str) if expire_epoch: - managed_entry.expires_at = datetime.fromtimestamp(expire_epoch) + managed_entry.expires_at = datetime.fromtimestamp(expire_epoch, tz=timezone.utc) return managed_entry diff --git a/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py b/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py index c395279e..fbe7bfc2 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/elasticsearch/store.py @@ -120,10 +120,10 @@ def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: value = data.pop("value") - if flattened := value.get("flattened"): - data["value"] = flattened - elif string := value.get("string"): - data["value"] = string + if "flattened" in value: + data["value"] = value["flattened"] + elif "string" in value: + data["value"] = value["string"] else: msg = "Value field not 
found in Elasticsearch document" raise DeserializationError(message=msg) diff --git a/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py b/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py index 34d6e2e1..17f41592 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py @@ -65,10 +65,10 @@ def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: value = data.pop("value") - if value_object := value.get("object"): - data["value"] = value_object - elif value_string := value.get("string"): - data["value"] = value_string + if "object" in value: + data["value"] = value["object"] + elif "string" in value: + data["value"] = value["string"] else: msg = "Value field not found in MongoDB document" raise DeserializationError(message=msg) diff --git a/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py b/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py index fbf7ee28..4bd8bcca 100644 --- a/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py +++ b/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py @@ -2,7 +2,7 @@ from collections.abc import Mapping from dataclasses import dataclass, field from datetime import datetime -from typing import Any, SupportsFloat, cast +from typing import Any, SupportsFloat from typing_extensions import Self @@ -58,7 +58,7 @@ def from_ttl(cls, *, value: Mapping[str, Any], created_at: datetime | None = Non return cls( value=value, created_at=created_at, - expires_at=(now_plus(seconds=ttl) if created_at else None), + expires_at=(now_plus(seconds=float(ttl)) if ttl else None), ) def to_dict( @@ -169,7 +169,7 @@ def load_from_json(json_str: str) -> dict[str, Any]: def verify_dict(obj: Any) -> dict[str, Any]: - if not isinstance(obj, dict): + if not isinstance(obj, Mapping): msg = "Object is not a dictionary" raise DeserializationError(msg) @@ -177,7 +177,7 @@ def verify_dict(obj: Any) -> dict[str, Any]: msg = "Object contains non-string keys" raise DeserializationError(msg) - return cast(typ="dict[str, Any]", val=obj) + return dict(obj) # pyright: ignore[reportUnknownArgumentType] def estimate_serialized_size(value: Mapping[str, Any]) -> int: diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/multi_store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/multi_store.py index 3245a7c0..325b5c30 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/multi_store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/multi_store.py @@ -2,6 +2,7 @@ # from the original file 'multi_store.py' # DO NOT CHANGE! Change the original file instead. 
from collections.abc import Callable +from datetime import timezone from pathlib import Path from typing import overload @@ -123,7 +124,7 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str) if expire_epoch: - managed_entry.expires_at = datetime.fromtimestamp(expire_epoch) + managed_entry.expires_at = datetime.fromtimestamp(expire_epoch, tz=timezone.utc) return managed_entry diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py index 6fb4c6ba..2bc36453 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/elasticsearch/store.py @@ -100,10 +100,10 @@ def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: value = data.pop("value") - if flattened := value.get("flattened"): - data["value"] = flattened - elif string := value.get("string"): - data["value"] = string + if "flattened" in value: + data["value"] = value["flattened"] + elif "string" in value: + data["value"] = value["string"] else: msg = "Value field not found in Elasticsearch document" raise DeserializationError(message=msg) diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py index f996f15f..af3fa326 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/mongodb/store.py @@ -72,10 +72,10 @@ def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: value = data.pop("value") - if value_object := value.get("object"): - data["value"] = value_object - elif value_string := value.get("string"): - data["value"] = value_string + if "object" in value: + data["value"] = value["object"] + elif "string" in value: + data["value"] = value["string"] else: msg = "Value field not found in MongoDB document" raise DeserializationError(message=msg) From 59cbaee9fbd7172e78578d6d1f74844bdfc5960e Mon Sep 17 00:00:00 2001 From: William Easton Date: Thu, 30 Oct 2025 15:52:29 -0500 Subject: [PATCH 08/11] fix redis type checks --- key-value/key-value-aio/tests/stores/redis/test_redis.py | 4 ++-- .../key-value-sync/tests/code_gen/stores/redis/test_redis.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/key-value/key-value-aio/tests/stores/redis/test_redis.py b/key-value/key-value-aio/tests/stores/redis/test_redis.py index c7e7c6bb..f75e228a 100644 --- a/key-value/key-value-aio/tests/stores/redis/test_redis.py +++ b/key-value/key-value-aio/tests/stores/redis/test_redis.py @@ -95,7 +95,7 @@ async def test_redis_document_format(self, store: RedisStore, redis_client: Redi await store.put(collection="test", key="document_format_test_1", value={"test_1": "value_1"}) await store.put(collection="test", key="document_format_test_2", value={"test_2": "value_2"}, ttl=10) - raw_documents = await redis_client.mget(keys=["test::document_format_test_1", "test::document_format_test_2"]) + raw_documents: Any = await redis_client.mget(keys=["test::document_format_test_1", "test::document_format_test_2"]) raw_documents_dicts: list[dict[str, Any]] = 
[json.loads(raw_document) for raw_document in raw_documents] assert raw_documents_dicts == snapshot( [ @@ -135,7 +135,7 @@ async def test_redis_document_format(self, store: RedisStore, redis_client: Redi ) await store.put(collection="test", key="document_format_test", value={"test": "value"}, ttl=10) - raw_document = await redis_client.get(name="test::document_format_test") + raw_document: Any = await redis_client.get(name="test::document_format_test") raw_document_dict = json.loads(raw_document) assert raw_document_dict == snapshot( {"created_at": IsDatetime(iso_string=True), "expires_at": IsDatetime(iso_string=True), "value": {"test": "value"}} diff --git a/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py b/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py index 1bc18625..dd5841e8 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py +++ b/key-value/key-value-sync/tests/code_gen/stores/redis/test_redis.py @@ -95,7 +95,7 @@ def test_redis_document_format(self, store: RedisStore, redis_client: Redis): store.put(collection="test", key="document_format_test_1", value={"test_1": "value_1"}) store.put(collection="test", key="document_format_test_2", value={"test_2": "value_2"}, ttl=10) - raw_documents = redis_client.mget(keys=["test::document_format_test_1", "test::document_format_test_2"]) + raw_documents: Any = redis_client.mget(keys=["test::document_format_test_1", "test::document_format_test_2"]) raw_documents_dicts: list[dict[str, Any]] = [json.loads(raw_document) for raw_document in raw_documents] assert raw_documents_dicts == snapshot( [ @@ -120,7 +120,7 @@ def test_redis_document_format(self, store: RedisStore, redis_client: Redis): ) store.put(collection="test", key="document_format_test", value={"test": "value"}, ttl=10) - raw_document = redis_client.get(name="test::document_format_test") + raw_document: Any = redis_client.get(name="test::document_format_test") raw_document_dict = json.loads(raw_document) assert raw_document_dict == snapshot( {"created_at": IsDatetime(iso_string=True), "expires_at": IsDatetime(iso_string=True), "value": {"test": "value"}} From 2cb6d1aabdf8daca591442a2e6e10ede02a4dcb5 Mon Sep 17 00:00:00 2001 From: William Easton Date: Thu, 30 Oct 2025 18:21:38 -0500 Subject: [PATCH 09/11] Finish refactor --- .../key_value/aio/stores/disk/multi_store.py | 17 +--- .../key_value/aio/stores/dynamodb/store.py | 4 +- .../src/key_value/aio/stores/keyring/store.py | 4 +- .../key_value/aio/stores/memcached/store.py | 11 ++- .../src/key_value/aio/stores/memory/store.py | 36 +++---- .../src/key_value/aio/stores/rocksdb/store.py | 4 +- .../src/key_value/aio/stores/simple/store.py | 21 ++-- .../src/key_value/aio/stores/valkey/store.py | 4 +- .../src/key_value/aio/stores/vault/store.py | 4 +- .../aio/stores/windows_registry/store.py | 4 +- .../tests/stores/disk/test_disk.py | 23 +++++ .../tests/stores/disk/test_multi_disk.py | 28 ++++++ .../tests/stores/dynamodb/test_dynamodb.py | 33 +++++++ .../tests/stores/memcached/test_memcached.py | 28 ++++++ .../tests/stores/rocksdb/test_rocksdb.py | 25 +++++ .../tests/stores/valkey/test_valkey.py | 26 +++++ .../tests/stores/wrappers/test_limit_size.py | 7 +- .../key_value/shared/utils/managed_entry.py | 95 +------------------ .../key_value/shared/utils/serialization.py | 15 +++ .../sync/code_gen/stores/disk/multi_store.py | 17 +--- .../sync/code_gen/stores/keyring/store.py | 4 +- .../sync/code_gen/stores/memory/store.py | 39 ++++---- .../sync/code_gen/stores/rocksdb/store.py | 4 
+- .../sync/code_gen/stores/simple/store.py | 21 ++-- .../sync/code_gen/stores/valkey/store.py | 4 +- .../sync/code_gen/stores/vault/store.py | 4 +- .../code_gen/stores/windows_registry/store.py | 4 +- .../tests/code_gen/stores/disk/test_disk.py | 23 +++++ .../code_gen/stores/disk/test_multi_disk.py | 23 +++++ .../code_gen/stores/rocksdb/test_rocksdb.py | 25 +++++ .../code_gen/stores/valkey/test_valkey.py | 26 +++++ .../stores/wrappers/test_limit_size.py | 7 +- 32 files changed, 385 insertions(+), 205 deletions(-) diff --git a/key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py b/key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py index 4adb0670..4c3ee1ed 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/disk/multi_store.py @@ -3,8 +3,8 @@ from pathlib import Path from typing import overload -from key_value.shared.utils.compound import compound_key from key_value.shared.utils.managed_entry import ManagedEntry, datetime +from key_value.shared.utils.serialization import BasicSerializationAdapter from typing_extensions import override from key_value.aio.stores.base import BaseContextManagerStore, BaseStore @@ -100,6 +100,7 @@ def default_disk_cache_factory(collection: str) -> Cache: self._cache = {} self._stable_api = True + self._serialization_adapter = BasicSerializationAdapter() super().__init__(default_collection=default_collection) @@ -109,11 +110,9 @@ async def _setup_collection(self, *, collection: str) -> None: @override async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None: - combo_key: str = compound_key(collection=collection, key=key) - expire_epoch: float - managed_entry_str, expire_epoch = self._cache[collection].get(key=combo_key, expire_time=True) # pyright: ignore[reportAny] + managed_entry_str, expire_epoch = self._cache[collection].get(key=key, expire_time=True) # pyright: ignore[reportAny] if not isinstance(managed_entry_str, str): return None @@ -133,17 +132,11 @@ async def _put_managed_entry( collection: str, managed_entry: ManagedEntry, ) -> None: - combo_key: str = compound_key(collection=collection, key=key) - - _ = self._cache[collection].set( - key=combo_key, value=self._serialization_adapter.dump_json(entry=managed_entry, exclude_none=False), expire=managed_entry.ttl - ) + _ = self._cache[collection].set(key=key, value=self._serialization_adapter.dump_json(entry=managed_entry), expire=managed_entry.ttl) @override async def _delete_managed_entry(self, *, key: str, collection: str) -> bool: - combo_key: str = compound_key(collection=collection, key=key) - - return self._cache[collection].delete(key=combo_key, retry=True) + return self._cache[collection].delete(key=key, retry=True) def _sync_close(self) -> None: for cache in self._cache.values(): diff --git a/key-value/key-value-aio/src/key_value/aio/stores/dynamodb/store.py b/key-value/key-value-aio/src/key_value/aio/stores/dynamodb/store.py index e2c87fb9..a7754253 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/dynamodb/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/dynamodb/store.py @@ -199,7 +199,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry if not json_value: return None - return ManagedEntry.from_json(json_str=json_value) # pyright: ignore[reportUnknownArgumentType] + return self._serialization_adapter.load_json(json_str=json_value) @override async def _put_managed_entry( @@ -210,7 +210,7 
@@ async def _put_managed_entry( managed_entry: ManagedEntry, ) -> None: """Store a managed entry in DynamoDB.""" - json_value = managed_entry.to_json() + json_value = self._serialization_adapter.dump_json(entry=managed_entry) item: dict[str, Any] = { "collection": {"S": collection}, diff --git a/key-value/key-value-aio/src/key_value/aio/stores/keyring/store.py b/key-value/key-value-aio/src/key_value/aio/stores/keyring/store.py index cec5b268..3967d3d5 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/keyring/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/keyring/store.py @@ -79,7 +79,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry if json_str is None: return None - return ManagedEntry.from_json(json_str=json_str) + return self._serialization_adapter.load_json(json_str=json_str) @override async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: @@ -88,7 +88,7 @@ async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: combo_key: str = compound_key(collection=sanitized_collection, key=sanitized_key) - json_str: str = managed_entry.to_json() + json_str: str = self._serialization_adapter.dump_json(entry=managed_entry) keyring.set_password(service_name=self._service_name, username=combo_key, password=json_str) diff --git a/key-value/key-value-aio/src/key_value/aio/stores/memcached/store.py b/key-value/key-value-aio/src/key_value/aio/stores/memcached/store.py index 5197c145..26c0054d 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/memcached/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/memcached/store.py @@ -4,6 +4,7 @@ from key_value.shared.utils.compound import compound_key from key_value.shared.utils.managed_entry import ManagedEntry +from key_value.shared.utils.serialization import BasicSerializationAdapter from typing_extensions import override from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyStore, BaseStore @@ -44,9 +45,11 @@ def __init__( port: Memcached port. Defaults to 11211. default_collection: The default collection to use if no collection is provided. 
""" + super().__init__(default_collection=default_collection) + self._client = client or Client(host=host, port=port) - super().__init__(default_collection=default_collection) + self._serialization_adapter = BasicSerializationAdapter(value_format="dict") def sanitize_key(self, key: str) -> str: if len(key) > MAX_KEY_LENGTH: @@ -65,7 +68,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry json_str: str = raw_value.decode(encoding="utf-8") - return ManagedEntry.from_json(json_str=json_str) + return self._serialization_adapter.load_json(json_str=json_str) @override async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]: @@ -82,7 +85,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> for raw_value in raw_values: if isinstance(raw_value, (bytes, bytearray)): json_str: str = raw_value.decode(encoding="utf-8") - entries.append(ManagedEntry.from_json(json_str=json_str)) + entries.append(self._serialization_adapter.load_json(json_str=json_str)) else: entries.append(None) @@ -106,7 +109,7 @@ async def _put_managed_entry( else: exptime = max(int(managed_entry.ttl), 1) - json_value: str = managed_entry.to_json() + json_value: str = self._serialization_adapter.dump_json(entry=managed_entry) _ = await self._client.set( key=combo_key.encode(encoding="utf-8"), diff --git a/key-value/key-value-aio/src/key_value/aio/stores/memory/store.py b/key-value/key-value-aio/src/key_value/aio/stores/memory/store.py index 07caf85a..62c05ecc 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/memory/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/memory/store.py @@ -4,7 +4,8 @@ from typing import Any from key_value.shared.utils.managed_entry import ManagedEntry -from typing_extensions import Self, override +from key_value.shared.utils.serialization import BasicSerializationAdapter +from typing_extensions import override from key_value.aio.stores.base import ( SEED_DATA_TYPE, @@ -29,16 +30,6 @@ class MemoryCacheEntry: expires_at: datetime | None - @classmethod - def from_managed_entry(cls, managed_entry: ManagedEntry) -> Self: - return cls( - json_str=managed_entry.to_json(), - expires_at=managed_entry.expires_at, - ) - - def to_managed_entry(self) -> ManagedEntry: - return ManagedEntry.from_json(json_str=self.json_str) - def _memory_cache_ttu(_key: Any, value: MemoryCacheEntry, _now: float) -> float: """Calculate time-to-use for cache entries based on their expiration time.""" @@ -53,8 +44,6 @@ def _memory_cache_getsizeof(value: MemoryCacheEntry) -> int: # noqa: ARG001 return 1 -DEFAULT_MAX_ENTRIES_PER_COLLECTION = 10000 - DEFAULT_PAGE_SIZE = 10000 PAGE_LIMIT = 10000 @@ -62,30 +51,33 @@ def _memory_cache_getsizeof(value: MemoryCacheEntry) -> int: # noqa: ARG001 class MemoryCollection: _cache: TLRUCache[str, MemoryCacheEntry] - def __init__(self, max_entries: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION): + def __init__(self, max_entries: int | None = None): """Initialize a fixed-size in-memory collection. Args: - max_entries: The maximum number of entries per collection. Defaults to 10,000 entries. + max_entries: The maximum number of entries per collection. Defaults to no limit. 
""" self._cache = TLRUCache[str, MemoryCacheEntry]( - maxsize=max_entries, + maxsize=max_entries or sys.maxsize, ttu=_memory_cache_ttu, getsizeof=_memory_cache_getsizeof, ) + self._serialization_adapter = BasicSerializationAdapter() + def get(self, key: str) -> ManagedEntry | None: managed_entry_str: MemoryCacheEntry | None = self._cache.get(key) if managed_entry_str is None: return None - managed_entry: ManagedEntry = managed_entry_str.to_managed_entry() + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str.json_str) return managed_entry def put(self, key: str, value: ManagedEntry) -> None: - self._cache[key] = MemoryCacheEntry.from_managed_entry(managed_entry=value) + json_str: str = self._serialization_adapter.dump_json(entry=value) + self._cache[key] = MemoryCacheEntry(json_str=json_str, expires_at=value.expires_at) def delete(self, key: str) -> bool: return self._cache.pop(key, None) is not None @@ -105,26 +97,28 @@ class MemoryStore(BaseDestroyStore, BaseDestroyCollectionStore, BaseEnumerateCol def __init__( self, *, - max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION, + max_entries_per_collection: int | None = None, default_collection: str | None = None, seed: SEED_DATA_TYPE | None = None, ): """Initialize a fixed-size in-memory store. Args: - max_entries_per_collection: The maximum number of entries per collection. Defaults to 10000. + max_entries_per_collection: The maximum number of entries per collection. Defaults to no limit. default_collection: The default collection to use if no collection is provided. seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}. Each value must be a mapping (dict) that will be stored as the entry's value. Seeding occurs lazily when each collection is first accessed. 
""" - self.max_entries_per_collection = max_entries_per_collection + self.max_entries_per_collection = max_entries_per_collection or sys.maxsize self._cache = {} self._stable_api = True + self._serialization_adapter = BasicSerializationAdapter() + super().__init__(default_collection=default_collection, seed=seed) @override diff --git a/key-value/key-value-aio/src/key_value/aio/stores/rocksdb/store.py b/key-value/key-value-aio/src/key_value/aio/stores/rocksdb/store.py index 61b98f68..0a6542b6 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/rocksdb/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/rocksdb/store.py @@ -124,7 +124,7 @@ async def _put_managed_entry( self._fail_on_closed_store() combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry.to_json() + json_value: str = self._serialization_adapter.dump_json(entry=managed_entry) self._db[combo_key] = json_value.encode("utf-8") @@ -147,7 +147,7 @@ async def _put_managed_entries( batch = WriteBatch() for key, managed_entry in zip(keys, managed_entries, strict=True): combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry.to_json() + json_value: str = self._serialization_adapter.dump_json(entry=managed_entry) batch.put(combo_key, json_value.encode("utf-8")) self._db.write(batch) diff --git a/key-value/key-value-aio/src/key_value/aio/stores/simple/store.py b/key-value/key-value-aio/src/key_value/aio/stores/simple/store.py index 98ce2df9..ce93cb9c 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/simple/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/simple/store.py @@ -1,9 +1,11 @@ +import sys from collections import defaultdict from dataclasses import dataclass from datetime import datetime from key_value.shared.utils.compound import compound_key, get_collections_from_compound_keys, get_keys_from_compound_keys from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json +from key_value.shared.utils.serialization import ValueOnlySerializationAdapter from key_value.shared.utils.time_to_live import seconds_to from typing_extensions import override @@ -48,19 +50,21 @@ class SimpleStore(BaseEnumerateCollectionsStore, BaseEnumerateKeysStore, BaseDes _data: dict[str, SimpleStoreEntry] - def __init__(self, max_entries: int = DEFAULT_SIMPLE_STORE_MAX_ENTRIES, default_collection: str | None = None): + def __init__(self, max_entries: int | None = None, default_collection: str | None = None): """Initialize the simple store. Args: - max_entries: The maximum number of entries to store. Defaults to 10000. + max_entries: The maximum number of entries to store. Defaults to no limit. default_collection: The default collection to use if no collection is provided. 
""" - self.max_entries = max_entries + super().__init__(default_collection=default_collection) + + self.max_entries = max_entries or sys.maxsize self._data = defaultdict[str, SimpleStoreEntry]() - super().__init__(default_collection=default_collection) + self._serialization_adapter = ValueOnlySerializationAdapter() @override async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None: @@ -71,7 +75,10 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry if store_entry is None: return None - return store_entry.to_managed_entry() + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=store_entry.json_str) + managed_entry.expires_at = store_entry.expires_at + managed_entry.created_at = store_entry.created_at + return managed_entry @override async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: @@ -81,7 +88,9 @@ async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: _ = self._data.pop(next(iter(self._data))) self._data[combo_key] = SimpleStoreEntry( - json_str=managed_entry.to_json(include_metadata=False), expires_at=managed_entry.expires_at, created_at=managed_entry.created_at + json_str=self._serialization_adapter.dump_json(entry=managed_entry), + expires_at=managed_entry.expires_at, + created_at=managed_entry.created_at, ) @override diff --git a/key-value/key-value-aio/src/key_value/aio/stores/valkey/store.py b/key-value/key-value-aio/src/key_value/aio/stores/valkey/store.py index aa2530a7..57398ee1 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/valkey/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/valkey/store.py @@ -110,7 +110,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> for response in responses: if isinstance(response, bytes): decoded_response: str = response.decode("utf-8") - entries.append(ManagedEntry.from_json(json_str=decoded_response)) + entries.append(self._serialization_adapter.load_json(json_str=decoded_response)) else: entries.append(None) @@ -126,7 +126,7 @@ async def _put_managed_entry( ) -> None: combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry.to_json() + json_value: str = self._serialization_adapter.dump_json(entry=managed_entry) expiry: ExpirySet | None = ExpirySet(expiry_type=ExpiryType.SEC, value=int(managed_entry.ttl)) if managed_entry.ttl else None diff --git a/key-value/key-value-aio/src/key_value/aio/stores/vault/store.py b/key-value/key-value-aio/src/key_value/aio/stores/vault/store.py index cfc86afe..772c9fff 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/vault/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/vault/store.py @@ -102,13 +102,13 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry return None json_str: str = secret_data["value"] # pyright: ignore[reportUnknownVariableType] - return ManagedEntry.from_json(json_str=json_str) # pyright: ignore[reportUnknownArgumentType] + return self._serialization_adapter.load_json(json_str=json_str) # pyright: ignore[reportUnknownArgumentType] @override async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: combo_key: str = compound_key(collection=collection, key=key) - json_str: str = managed_entry.to_json() + json_str: str = self._serialization_adapter.dump_json(entry=managed_entry) # Store the JSON string in a 'value' field secret_data 
= {"value": json_str} diff --git a/key-value/key-value-aio/src/key_value/aio/stores/windows_registry/store.py b/key-value/key-value-aio/src/key_value/aio/stores/windows_registry/store.py index 1ce3dcc5..1a3c4add 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/windows_registry/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/windows_registry/store.py @@ -88,14 +88,14 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry if not (json_str := get_reg_sz_value(hive=self._hive, sub_key=registry_path, value_name=sanitized_key)): return None - return ManagedEntry.from_json(json_str=json_str) + return self._serialization_adapter.load_json(json_str=json_str) @override async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: sanitized_key = self._sanitize_key(key=key) registry_path = self._get_registry_path(collection=collection) - json_str: str = managed_entry.to_json() + json_str: str = self._serialization_adapter.dump_json(entry=managed_entry) set_reg_sz_value(hive=self._hive, sub_key=registry_path, value_name=sanitized_key, value=json_str) diff --git a/key-value/key-value-aio/tests/stores/disk/test_disk.py b/key-value/key-value-aio/tests/stores/disk/test_disk.py index 2aaf7ceb..e2c3453c 100644 --- a/key-value/key-value-aio/tests/stores/disk/test_disk.py +++ b/key-value/key-value-aio/tests/stores/disk/test_disk.py @@ -1,7 +1,11 @@ +import json import tempfile from collections.abc import AsyncGenerator import pytest +from dirty_equals import IsDatetime +from diskcache.core import Cache +from inline_snapshot import snapshot from typing_extensions import override from key_value.aio.stores.disk import DiskStore @@ -22,3 +26,22 @@ async def store(self, disk_store: DiskStore) -> DiskStore: disk_store._cache.clear() # pyright: ignore[reportPrivateUsage] return disk_store + + @pytest.fixture + async def disk_cache(self, disk_store: DiskStore) -> Cache: + return disk_store._cache # pyright: ignore[reportPrivateUsage] + + async def test_value_stored(self, store: DiskStore, disk_cache: Cache): + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + value = disk_cache.get(key="test::test_key") + value_as_dict = json.loads(value) + assert value_as_dict == snapshot({"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}}) + + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = disk_cache.get(key="test::test_key") + value_as_dict = json.loads(value) + assert value_as_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-aio/tests/stores/disk/test_multi_disk.py b/key-value/key-value-aio/tests/stores/disk/test_multi_disk.py index 09d226e3..832eeccb 100644 --- a/key-value/key-value-aio/tests/stores/disk/test_multi_disk.py +++ b/key-value/key-value-aio/tests/stores/disk/test_multi_disk.py @@ -1,13 +1,20 @@ +import json import tempfile from collections.abc import AsyncGenerator from pathlib import Path +from typing import TYPE_CHECKING import pytest +from dirty_equals import IsDatetime +from inline_snapshot import snapshot from typing_extensions import override from key_value.aio.stores.disk.multi_store import MultiDiskStore from tests.stores.base import BaseStoreTests, ContextManagerStoreTestMixin +if TYPE_CHECKING: + from diskcache.core import Cache + TEST_SIZE_LIMIT = 
100 * 1024 # 100KB @@ -24,3 +31,24 @@ async def store(self, multi_disk_store: MultiDiskStore) -> MultiDiskStore: multi_disk_store._cache[collection].clear() # pyright: ignore[reportPrivateUsage] return multi_disk_store + + async def test_value_stored(self, store: MultiDiskStore): + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + disk_cache: Cache = store._cache["test"] # pyright: ignore[reportPrivateUsage] + + value = disk_cache.get(key="test_key") + value_as_dict = json.loads(value) + assert value_as_dict == snapshot( + { + "value": {"name": "Alice", "age": 30}, + "created_at": IsDatetime(iso_string=True), + } + ) + + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = disk_cache.get(key="test_key") + value_as_dict = json.loads(value) + assert value_as_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-aio/tests/stores/dynamodb/test_dynamodb.py b/key-value/key-value-aio/tests/stores/dynamodb/test_dynamodb.py index e0af2a47..d71d818d 100644 --- a/key-value/key-value-aio/tests/stores/dynamodb/test_dynamodb.py +++ b/key-value/key-value-aio/tests/stores/dynamodb/test_dynamodb.py @@ -1,8 +1,14 @@ import contextlib +import json from collections.abc import AsyncGenerator +from typing import Any import pytest +from dirty_equals import IsDatetime +from inline_snapshot import snapshot from key_value.shared.stores.wait import async_wait_for_true +from types_aiobotocore_dynamodb.client import DynamoDBClient +from types_aiobotocore_dynamodb.type_defs import GetItemOutputTypeDef from typing_extensions import override from key_value.aio.stores.base import BaseStore @@ -48,6 +54,10 @@ class DynamoDBFailedToStartError(Exception): pass +def get_value_from_response(response: GetItemOutputTypeDef) -> dict[str, Any]: + return json.loads(response.get("Item", {}).get("value", {}).get("S", {})) # pyright: ignore[reportArgumentType] + + @pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not available") class TestDynamoDBStore(ContextManagerStoreTestMixin, BaseStoreTests): @pytest.fixture(autouse=True, scope="session", params=DYNAMODB_VERSIONS_TO_TEST) @@ -98,6 +108,29 @@ async def store(self, setup_dynamodb: None) -> DynamoDBStore: async def dynamodb_store(self, store: DynamoDBStore) -> DynamoDBStore: return store + @pytest.fixture + async def dynamodb_client(self, store: DynamoDBStore) -> DynamoDBClient: + return store._connected_client # pyright: ignore[reportPrivateUsage] + @pytest.mark.skip(reason="Distributed Caches are unbounded") @override async def test_not_unbounded(self, store: BaseStore): ... 
+ + async def test_value_stored(self, store: DynamoDBStore, dynamodb_client: DynamoDBClient): + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + response = await dynamodb_client.get_item( + TableName=DYNAMODB_TEST_TABLE, Key={"collection": {"S": "test"}, "key": {"S": "test_key"}} + ) + assert get_value_from_response(response=response) == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}} + ) + + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + response = await dynamodb_client.get_item( + TableName=DYNAMODB_TEST_TABLE, Key={"collection": {"S": "test"}, "key": {"S": "test_key"}} + ) + assert get_value_from_response(response=response) == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-aio/tests/stores/memcached/test_memcached.py b/key-value/key-value-aio/tests/stores/memcached/test_memcached.py index d8089848..2461de3c 100644 --- a/key-value/key-value-aio/tests/stores/memcached/test_memcached.py +++ b/key-value/key-value-aio/tests/stores/memcached/test_memcached.py @@ -1,8 +1,11 @@ import contextlib +import json from collections.abc import AsyncGenerator import pytest from aiomcache import Client +from dirty_equals import IsDatetime +from inline_snapshot import snapshot from key_value.shared.stores.wait import async_wait_for_true from typing_extensions import override @@ -64,3 +67,28 @@ async def store(self, setup_memcached: None) -> MemcachedStore: @pytest.mark.skip(reason="Distributed Caches are unbounded") @override async def test_not_unbounded(self, store: BaseStore): ... + + @pytest.fixture + async def memcached_client(self, store: MemcachedStore) -> Client: + return store._client # pyright: ignore[reportPrivateUsage] + + async def test_value_stored(self, store: MemcachedStore, memcached_client: Client): + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + value = await memcached_client.get(key=b"test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot({"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}}) + + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = await memcached_client.get(key=b"test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot( + { + "created_at": IsDatetime(iso_string=True), + "expires_at": IsDatetime(iso_string=True), + "value": {"age": 30, "name": "Alice"}, + } + ) diff --git a/key-value/key-value-aio/tests/stores/rocksdb/test_rocksdb.py b/key-value/key-value-aio/tests/stores/rocksdb/test_rocksdb.py index e3f37694..2aee7837 100644 --- a/key-value/key-value-aio/tests/stores/rocksdb/test_rocksdb.py +++ b/key-value/key-value-aio/tests/stores/rocksdb/test_rocksdb.py @@ -1,8 +1,12 @@ +import json from collections.abc import AsyncGenerator from pathlib import Path from tempfile import TemporaryDirectory import pytest +from dirty_equals import IsDatetime +from inline_snapshot import snapshot +from rocksdict import Rdict from typing_extensions import override from key_value.aio.stores.base import BaseStore @@ -59,3 +63,24 @@ async def test_rocksdb_db_connection(self): @pytest.mark.skip(reason="Local disk stores are unbounded") @override async 
def test_not_unbounded(self, store: BaseStore): ... + + @pytest.fixture + async def rocksdb_client(self, store: RocksDBStore) -> Rdict: + return store._db # pyright: ignore[reportPrivateUsage] + + async def test_value_stored(self, store: RocksDBStore, rocksdb_client: Rdict): + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + value = rocksdb_client.get(key="test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot({"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}}) + + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = rocksdb_client.get(key="test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-aio/tests/stores/valkey/test_valkey.py b/key-value/key-value-aio/tests/stores/valkey/test_valkey.py index 9dc6895b..e187e78f 100644 --- a/key-value/key-value-aio/tests/stores/valkey/test_valkey.py +++ b/key-value/key-value-aio/tests/stores/valkey/test_valkey.py @@ -1,11 +1,16 @@ import contextlib +import json from collections.abc import AsyncGenerator import pytest +from dirty_equals import IsDatetime +from glide.glide_client import BaseClient +from inline_snapshot import snapshot from key_value.shared.stores.wait import async_wait_for_true from typing_extensions import override from key_value.aio.stores.base import BaseStore +from key_value.aio.stores.valkey import ValkeyStore from tests.conftest import detect_on_windows, docker_container, should_skip_docker_tests from tests.stores.base import ( BaseStoreTests, @@ -81,6 +86,27 @@ async def store(self, setup_valkey: None): return store + @pytest.fixture + async def valkey_client(self, store: ValkeyStore): + return store._connected_client # pyright: ignore[reportPrivateUsage] + @pytest.mark.skip(reason="Distributed Caches are unbounded") @override async def test_not_unbounded(self, store: BaseStore): ... 
+ + async def test_value_stored(self, store: ValkeyStore, valkey_client: BaseClient): + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + value = await valkey_client.get(key="test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot({"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}}) + + await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = await valkey_client.get(key="test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-aio/tests/stores/wrappers/test_limit_size.py b/key-value/key-value-aio/tests/stores/wrappers/test_limit_size.py index 8585ceb0..cd4b8ec0 100644 --- a/key-value/key-value-aio/tests/stores/wrappers/test_limit_size.py +++ b/key-value/key-value-aio/tests/stores/wrappers/test_limit_size.py @@ -1,5 +1,6 @@ import pytest from key_value.shared.errors.wrappers.limit_size import EntryTooLargeError, EntryTooSmallError +from key_value.shared.utils.managed_entry import estimate_serialized_size from typing_extensions import override from key_value.aio.stores.memory.store import MemoryStore @@ -137,12 +138,8 @@ async def test_put_many_all_too_large_without_raise(self, memory_store: MemorySt async def test_exact_size_limit(self, memory_store: MemoryStore): # First, determine the exact size of a small value - from key_value.shared.utils.managed_entry import ManagedEntry - test_value = {"test": "value"} - managed_entry = ManagedEntry(value=test_value) - json_str = managed_entry.to_json() - exact_size = len(json_str.encode("utf-8")) + exact_size = estimate_serialized_size(value=test_value) # Create a store with exact size limit limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=exact_size, raise_on_too_large=True) diff --git a/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py b/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py index 4bd8bcca..b947308d 100644 --- a/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py +++ b/key-value/key-value-shared/src/key_value/shared/utils/managed_entry.py @@ -7,7 +7,8 @@ from typing_extensions import Self from key_value.shared.errors import DeserializationError, SerializationError -from key_value.shared.utils.time_to_live import now, now_plus, seconds_to, try_parse_datetime_str +from key_value.shared.type_checking.bear_spray import bear_enforce +from key_value.shared.utils.time_to_live import now, now_plus, seconds_to @dataclass(kw_only=True) @@ -61,96 +62,8 @@ def from_ttl(cls, *, value: Mapping[str, Any], created_at: datetime | None = Non expires_at=(now_plus(seconds=float(ttl)) if ttl else None), ) - def to_dict( - self, include_metadata: bool = True, include_expiration: bool = True, include_creation: bool = True, stringify_value: bool = False - ) -> dict[str, Any]: - if not include_metadata: - return dict(self.value) - - data: dict[str, Any] = {"value": self.value_as_json if stringify_value else self.value} - - if include_creation and self.created_at: - data["created_at"] = self.created_at.isoformat() - if include_expiration and self.expires_at: - data["expires_at"] = self.expires_at.isoformat() - - return data - - def 
to_json( - self, include_metadata: bool = True, include_expiration: bool = True, include_creation: bool = True, stringify_value: bool = False - ) -> str: - return dump_to_json( - obj=self.to_dict( - include_metadata=include_metadata, - include_expiration=include_expiration, - include_creation=include_creation, - stringify_value=stringify_value, - ) - ) - - @classmethod - def from_dict( # noqa: PLR0912 - cls, - data: dict[str, Any], - includes_metadata: bool = True, - stringified_value: bool = False, - ) -> Self: - if not includes_metadata: - return cls( - value=data, - ) - - created_at: datetime | None = None - expires_at: datetime | None = None - - if created_at_value := data.get("created_at"): - if isinstance(created_at_value, str): - created_at = try_parse_datetime_str(value=created_at_value) - elif isinstance(created_at_value, datetime): - created_at = created_at_value - else: - msg = "Expected `created_at` field to be a string or datetime" - raise DeserializationError(msg) - - if expires_at_value := data.get("expires_at"): - if isinstance(expires_at_value, str): - expires_at = try_parse_datetime_str(value=expires_at_value) - elif isinstance(expires_at_value, datetime): - expires_at = expires_at_value - else: - msg = "Expected `expires_at` field to be a string or datetime" - raise DeserializationError(msg) - - if not (raw_value := data.get("value")): - msg = "Value is None" - raise DeserializationError(msg) - - value: dict[str, Any] - - if stringified_value: - if not isinstance(raw_value, str): - msg = "Value is not a string" - raise DeserializationError(msg) - value = load_from_json(json_str=raw_value) - else: - if not isinstance(raw_value, dict): - msg = "Value is not a dictionary" - raise DeserializationError(msg) - value = verify_dict(obj=raw_value) - - return cls( - created_at=created_at, - expires_at=expires_at, - value=value, - ) - - @classmethod - def from_json(cls, json_str: str, includes_metadata: bool = True) -> Self: - data: dict[str, Any] = load_from_json(json_str=json_str) - - return cls.from_dict(data=data, includes_metadata=includes_metadata) - +@bear_enforce def dump_to_json(obj: dict[str, Any]) -> str: try: return json.dumps(obj, sort_keys=True) @@ -159,6 +72,7 @@ def dump_to_json(obj: dict[str, Any]) -> str: raise SerializationError(msg) from e +@bear_enforce def load_from_json(json_str: str) -> dict[str, Any]: try: return verify_dict(obj=json.loads(json_str)) # pyright: ignore[reportAny] @@ -168,6 +82,7 @@ def load_from_json(json_str: str) -> dict[str, Any]: raise DeserializationError(msg) from e +@bear_enforce def verify_dict(obj: Any) -> dict[str, Any]: if not isinstance(obj, Mapping): msg = "Object is not a dictionary" diff --git a/key-value/key-value-shared/src/key_value/shared/utils/serialization.py b/key-value/key-value-shared/src/key_value/shared/utils/serialization.py index bebdab77..20b6d2d0 100644 --- a/key-value/key-value-shared/src/key_value/shared/utils/serialization.py +++ b/key-value/key-value-shared/src/key_value/shared/utils/serialization.py @@ -131,3 +131,18 @@ def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: return data + + +class ValueOnlySerializationAdapter(SerializationAdapter): + """Serialization adapter that only serializes the value.""" + + def __init__(self, *, value_format: Literal["string", "dict"] | None = "dict") -> None: + super().__init__(date_format=None, value_format=value_format) + + def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: + 
return { + "value": data, + } + + def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: + return data.get("value") diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/multi_store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/multi_store.py index 325b5c30..5e579c27 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/multi_store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/disk/multi_store.py @@ -6,8 +6,8 @@ from pathlib import Path from typing import overload -from key_value.shared.utils.compound import compound_key from key_value.shared.utils.managed_entry import ManagedEntry, datetime +from key_value.shared.utils.serialization import BasicSerializationAdapter from typing_extensions import override from key_value.sync.code_gen.stores.base import BaseContextManagerStore, BaseStore @@ -103,6 +103,7 @@ def default_disk_cache_factory(collection: str) -> Cache: self._cache = {} self._stable_api = True + self._serialization_adapter = BasicSerializationAdapter() super().__init__(default_collection=default_collection) @@ -112,11 +113,9 @@ def _setup_collection(self, *, collection: str) -> None: @override def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None: - combo_key: str = compound_key(collection=collection, key=key) - expire_epoch: float - (managed_entry_str, expire_epoch) = self._cache[collection].get(key=combo_key, expire_time=True) # pyright: ignore[reportAny] + (managed_entry_str, expire_epoch) = self._cache[collection].get(key=key, expire_time=True) # pyright: ignore[reportAny] if not isinstance(managed_entry_str, str): return None @@ -130,17 +129,11 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non @override def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: - combo_key: str = compound_key(collection=collection, key=key) - - _ = self._cache[collection].set( - key=combo_key, value=self._serialization_adapter.dump_json(entry=managed_entry, exclude_none=False), expire=managed_entry.ttl - ) + _ = self._cache[collection].set(key=key, value=self._serialization_adapter.dump_json(entry=managed_entry), expire=managed_entry.ttl) @override def _delete_managed_entry(self, *, key: str, collection: str) -> bool: - combo_key: str = compound_key(collection=collection, key=key) - - return self._cache[collection].delete(key=combo_key, retry=True) + return self._cache[collection].delete(key=key, retry=True) def _sync_close(self) -> None: for cache in self._cache.values(): diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/keyring/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/keyring/store.py index f3eb41c8..db868392 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/keyring/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/keyring/store.py @@ -69,7 +69,7 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non if json_str is None: return None - return ManagedEntry.from_json(json_str=json_str) + return self._serialization_adapter.load_json(json_str=json_str) @override def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: @@ -78,7 +78,7 @@ def _put_managed_entry(self, *, key: str, collection: str, managed_entry: Manage combo_key: str = compound_key(collection=sanitized_collection, key=sanitized_key) - json_str: str = 
managed_entry.to_json() + json_str: str = self._serialization_adapter.dump_json(entry=managed_entry) keyring.set_password(service_name=self._service_name, username=combo_key, password=json_str) diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/memory/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/memory/store.py index b206291a..08ea7838 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/memory/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/memory/store.py @@ -7,7 +7,8 @@ from typing import Any from key_value.shared.utils.managed_entry import ManagedEntry -from typing_extensions import Self, override +from key_value.shared.utils.serialization import BasicSerializationAdapter +from typing_extensions import override from key_value.sync.code_gen.stores.base import ( SEED_DATA_TYPE, @@ -32,13 +33,6 @@ class MemoryCacheEntry: expires_at: datetime | None - @classmethod - def from_managed_entry(cls, managed_entry: ManagedEntry) -> Self: - return cls(json_str=managed_entry.to_json(), expires_at=managed_entry.expires_at) - - def to_managed_entry(self) -> ManagedEntry: - return ManagedEntry.from_json(json_str=self.json_str) - def _memory_cache_ttu(_key: Any, value: MemoryCacheEntry, _now: float) -> float: """Calculate time-to-use for cache entries based on their expiration time.""" @@ -53,8 +47,6 @@ def _memory_cache_getsizeof(value: MemoryCacheEntry) -> int: return 1 -DEFAULT_MAX_ENTRIES_PER_COLLECTION = 10000 - DEFAULT_PAGE_SIZE = 10000 PAGE_LIMIT = 10000 @@ -62,13 +54,17 @@ def _memory_cache_getsizeof(value: MemoryCacheEntry) -> int: class MemoryCollection: _cache: TLRUCache[str, MemoryCacheEntry] - def __init__(self, max_entries: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION): + def __init__(self, max_entries: int | None = None): """Initialize a fixed-size in-memory collection. Args: - max_entries: The maximum number of entries per collection. Defaults to 10,000 entries. + max_entries: The maximum number of entries per collection. Defaults to no limit. 
""" - self._cache = TLRUCache[str, MemoryCacheEntry](maxsize=max_entries, ttu=_memory_cache_ttu, getsizeof=_memory_cache_getsizeof) + self._cache = TLRUCache[str, MemoryCacheEntry]( + maxsize=max_entries or sys.maxsize, ttu=_memory_cache_ttu, getsizeof=_memory_cache_getsizeof + ) + + self._serialization_adapter = BasicSerializationAdapter() def get(self, key: str) -> ManagedEntry | None: managed_entry_str: MemoryCacheEntry | None = self._cache.get(key) @@ -76,12 +72,13 @@ def get(self, key: str) -> ManagedEntry | None: if managed_entry_str is None: return None - managed_entry: ManagedEntry = managed_entry_str.to_managed_entry() + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str.json_str) return managed_entry def put(self, key: str, value: ManagedEntry) -> None: - self._cache[key] = MemoryCacheEntry.from_managed_entry(managed_entry=value) + json_str: str = self._serialization_adapter.dump_json(entry=value) + self._cache[key] = MemoryCacheEntry(json_str=json_str, expires_at=value.expires_at) def delete(self, key: str) -> bool: return self._cache.pop(key, None) is not None @@ -99,28 +96,26 @@ class MemoryStore(BaseDestroyStore, BaseDestroyCollectionStore, BaseEnumerateCol _cache: dict[str, MemoryCollection] def __init__( - self, - *, - max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION, - default_collection: str | None = None, - seed: SEED_DATA_TYPE | None = None, + self, *, max_entries_per_collection: int | None = None, default_collection: str | None = None, seed: SEED_DATA_TYPE | None = None ): """Initialize a fixed-size in-memory store. Args: - max_entries_per_collection: The maximum number of entries per collection. Defaults to 10000. + max_entries_per_collection: The maximum number of entries per collection. Defaults to no limit. default_collection: The default collection to use if no collection is provided. seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}. Each value must be a mapping (dict) that will be stored as the entry's value. Seeding occurs lazily when each collection is first accessed. 
""" - self.max_entries_per_collection = max_entries_per_collection + self.max_entries_per_collection = max_entries_per_collection or sys.maxsize self._cache = {} self._stable_api = True + self._serialization_adapter = BasicSerializationAdapter() + super().__init__(default_collection=default_collection, seed=seed) @override diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/rocksdb/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/rocksdb/store.py index 4356ac3d..91da8d23 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/rocksdb/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/rocksdb/store.py @@ -115,7 +115,7 @@ def _put_managed_entry(self, *, key: str, collection: str, managed_entry: Manage self._fail_on_closed_store() combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry.to_json() + json_value: str = self._serialization_adapter.dump_json(entry=managed_entry) self._db[combo_key] = json_value.encode("utf-8") @@ -138,7 +138,7 @@ def _put_managed_entries( batch = WriteBatch() for key, managed_entry in zip(keys, managed_entries, strict=True): combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry.to_json() + json_value: str = self._serialization_adapter.dump_json(entry=managed_entry) batch.put(combo_key, json_value.encode("utf-8")) self._db.write(batch) diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/simple/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/simple/store.py index d24d1d19..5725b4b5 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/simple/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/simple/store.py @@ -1,12 +1,14 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'store.py' # DO NOT CHANGE! Change the original file instead. +import sys from collections import defaultdict from dataclasses import dataclass from datetime import datetime from key_value.shared.utils.compound import compound_key, get_collections_from_compound_keys, get_keys_from_compound_keys from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json +from key_value.shared.utils.serialization import ValueOnlySerializationAdapter from key_value.shared.utils.time_to_live import seconds_to from typing_extensions import override @@ -44,19 +46,21 @@ class SimpleStore(BaseEnumerateCollectionsStore, BaseEnumerateKeysStore, BaseDes _data: dict[str, SimpleStoreEntry] - def __init__(self, max_entries: int = DEFAULT_SIMPLE_STORE_MAX_ENTRIES, default_collection: str | None = None): + def __init__(self, max_entries: int | None = None, default_collection: str | None = None): """Initialize the simple store. Args: - max_entries: The maximum number of entries to store. Defaults to 10000. + max_entries: The maximum number of entries to store. Defaults to no limit. default_collection: The default collection to use if no collection is provided. 
""" - self.max_entries = max_entries + super().__init__(default_collection=default_collection) + + self.max_entries = max_entries or sys.maxsize self._data = defaultdict[str, SimpleStoreEntry]() - super().__init__(default_collection=default_collection) + self._serialization_adapter = ValueOnlySerializationAdapter() @override def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None: @@ -67,7 +71,10 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non if store_entry is None: return None - return store_entry.to_managed_entry() + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=store_entry.json_str) + managed_entry.expires_at = store_entry.expires_at + managed_entry.created_at = store_entry.created_at + return managed_entry @override def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: @@ -77,7 +84,9 @@ def _put_managed_entry(self, *, key: str, collection: str, managed_entry: Manage _ = self._data.pop(next(iter(self._data))) self._data[combo_key] = SimpleStoreEntry( - json_str=managed_entry.to_json(include_metadata=False), expires_at=managed_entry.expires_at, created_at=managed_entry.created_at + json_str=self._serialization_adapter.dump_json(entry=managed_entry), + expires_at=managed_entry.expires_at, + created_at=managed_entry.created_at, ) @override diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/valkey/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/valkey/store.py index 6ec5c0a8..7c649697 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/valkey/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/valkey/store.py @@ -112,7 +112,7 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ for response in responses: if isinstance(response, bytes): decoded_response: str = response.decode("utf-8") - entries.append(ManagedEntry.from_json(json_str=decoded_response)) + entries.append(self._serialization_adapter.load_json(json_str=decoded_response)) else: entries.append(None) @@ -122,7 +122,7 @@ def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: combo_key: str = compound_key(collection=collection, key=key) - json_value: str = managed_entry.to_json() + json_value: str = self._serialization_adapter.dump_json(entry=managed_entry) expiry: ExpirySet | None = ExpirySet(expiry_type=ExpiryType.SEC, value=int(managed_entry.ttl)) if managed_entry.ttl else None diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/vault/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/vault/store.py index 6cea8ada..85237d2c 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/vault/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/vault/store.py @@ -99,13 +99,13 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non return None json_str: str = secret_data["value"] # pyright: ignore[reportUnknownVariableType] - return ManagedEntry.from_json(json_str=json_str) # pyright: ignore[reportUnknownArgumentType] + return self._serialization_adapter.load_json(json_str=json_str) # pyright: ignore[reportUnknownArgumentType] @override def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: combo_key: str = 
compound_key(collection=collection, key=key) - json_str: str = managed_entry.to_json() + json_str: str = self._serialization_adapter.dump_json(entry=managed_entry) # Store the JSON string in a 'value' field secret_data = {"value": json_str} diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/windows_registry/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/windows_registry/store.py index 30bec383..20024851 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/windows_registry/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/windows_registry/store.py @@ -89,14 +89,14 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non if not (json_str := get_reg_sz_value(hive=self._hive, sub_key=registry_path, value_name=sanitized_key)): return None - return ManagedEntry.from_json(json_str=json_str) + return self._serialization_adapter.load_json(json_str=json_str) @override def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None: sanitized_key = self._sanitize_key(key=key) registry_path = self._get_registry_path(collection=collection) - json_str: str = managed_entry.to_json() + json_str: str = self._serialization_adapter.dump_json(entry=managed_entry) set_reg_sz_value(hive=self._hive, sub_key=registry_path, value_name=sanitized_key, value=json_str) diff --git a/key-value/key-value-sync/tests/code_gen/stores/disk/test_disk.py b/key-value/key-value-sync/tests/code_gen/stores/disk/test_disk.py index 994738a2..b904f35f 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/disk/test_disk.py +++ b/key-value/key-value-sync/tests/code_gen/stores/disk/test_disk.py @@ -1,10 +1,14 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'test_disk.py' # DO NOT CHANGE! Change the original file instead. 
+import json import tempfile from collections.abc import Generator import pytest +from dirty_equals import IsDatetime +from diskcache.core import Cache +from inline_snapshot import snapshot from typing_extensions import override from key_value.sync.code_gen.stores.disk import DiskStore @@ -25,3 +29,22 @@ def store(self, disk_store: DiskStore) -> DiskStore: disk_store._cache.clear() # pyright: ignore[reportPrivateUsage] return disk_store + + @pytest.fixture + def disk_cache(self, disk_store: DiskStore) -> Cache: + return disk_store._cache # pyright: ignore[reportPrivateUsage] + + def test_value_stored(self, store: DiskStore, disk_cache: Cache): + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + value = disk_cache.get(key="test::test_key") + value_as_dict = json.loads(value) + assert value_as_dict == snapshot({"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}}) + + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = disk_cache.get(key="test::test_key") + value_as_dict = json.loads(value) + assert value_as_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-sync/tests/code_gen/stores/disk/test_multi_disk.py b/key-value/key-value-sync/tests/code_gen/stores/disk/test_multi_disk.py index e6341075..ed68c109 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/disk/test_multi_disk.py +++ b/key-value/key-value-sync/tests/code_gen/stores/disk/test_multi_disk.py @@ -1,16 +1,23 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'test_multi_disk.py' # DO NOT CHANGE! Change the original file instead. 
+import json import tempfile from collections.abc import Generator from pathlib import Path +from typing import TYPE_CHECKING import pytest +from dirty_equals import IsDatetime +from inline_snapshot import snapshot from typing_extensions import override from key_value.sync.code_gen.stores.disk.multi_store import MultiDiskStore from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin +if TYPE_CHECKING: + from diskcache.core import Cache + TEST_SIZE_LIMIT = 100 * 1024 # 100KB @@ -27,3 +34,19 @@ def store(self, multi_disk_store: MultiDiskStore) -> MultiDiskStore: multi_disk_store._cache[collection].clear() # pyright: ignore[reportPrivateUsage] return multi_disk_store + + def test_value_stored(self, store: MultiDiskStore): + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + disk_cache: Cache = store._cache["test"] # pyright: ignore[reportPrivateUsage] + + value = disk_cache.get(key="test_key") + value_as_dict = json.loads(value) + assert value_as_dict == snapshot({"value": {"name": "Alice", "age": 30}, "created_at": IsDatetime(iso_string=True)}) + + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = disk_cache.get(key="test_key") + value_as_dict = json.loads(value) + assert value_as_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-sync/tests/code_gen/stores/rocksdb/test_rocksdb.py b/key-value/key-value-sync/tests/code_gen/stores/rocksdb/test_rocksdb.py index 2e04cd08..a03509a0 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/rocksdb/test_rocksdb.py +++ b/key-value/key-value-sync/tests/code_gen/stores/rocksdb/test_rocksdb.py @@ -1,11 +1,15 @@ # WARNING: this file is auto-generated by 'build_sync_library.py' # from the original file 'test_rocksdb.py' # DO NOT CHANGE! Change the original file instead. +import json from collections.abc import Generator from pathlib import Path from tempfile import TemporaryDirectory import pytest +from dirty_equals import IsDatetime +from inline_snapshot import snapshot +from rocksdict import Rdict from typing_extensions import override from key_value.sync.code_gen.stores.base import BaseStore @@ -62,3 +66,24 @@ def test_rocksdb_db_connection(self): @pytest.mark.skip(reason="Local disk stores are unbounded") @override def test_not_unbounded(self, store: BaseStore): ... 
+ + @pytest.fixture + def rocksdb_client(self, store: RocksDBStore) -> Rdict: + return store._db # pyright: ignore[reportPrivateUsage] + + def test_value_stored(self, store: RocksDBStore, rocksdb_client: Rdict): + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + value = rocksdb_client.get(key="test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot({"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}}) + + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = rocksdb_client.get(key="test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-sync/tests/code_gen/stores/valkey/test_valkey.py b/key-value/key-value-sync/tests/code_gen/stores/valkey/test_valkey.py index 69349194..14aa575c 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/valkey/test_valkey.py +++ b/key-value/key-value-sync/tests/code_gen/stores/valkey/test_valkey.py @@ -2,13 +2,18 @@ # from the original file 'test_valkey.py' # DO NOT CHANGE! Change the original file instead. import contextlib +import json from collections.abc import Generator import pytest +from dirty_equals import IsDatetime +from glide_sync.glide_client import BaseClient +from inline_snapshot import snapshot from key_value.shared.stores.wait import wait_for_true from typing_extensions import override from key_value.sync.code_gen.stores.base import BaseStore +from key_value.sync.code_gen.stores.valkey import ValkeyStore from tests.code_gen.conftest import detect_on_windows, docker_container, should_skip_docker_tests from tests.code_gen.stores.base import BaseStoreTests, ContextManagerStoreTestMixin @@ -79,6 +84,27 @@ def store(self, setup_valkey: None): return store + @pytest.fixture + def valkey_client(self, store: ValkeyStore): + return store._connected_client # pyright: ignore[reportPrivateUsage] + @pytest.mark.skip(reason="Distributed Caches are unbounded") @override def test_not_unbounded(self, store: BaseStore): ... + + def test_value_stored(self, store: ValkeyStore, valkey_client: BaseClient): + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}) + + value = valkey_client.get(key="test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot({"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}}) + + store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10) + + value = valkey_client.get(key="test::test_key") + assert value is not None + value_as_dict = json.loads(value.decode("utf-8")) + assert value_as_dict == snapshot( + {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)} + ) diff --git a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_limit_size.py b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_limit_size.py index e404249a..74bbba54 100644 --- a/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_limit_size.py +++ b/key-value/key-value-sync/tests/code_gen/stores/wrappers/test_limit_size.py @@ -3,6 +3,7 @@ # DO NOT CHANGE! Change the original file instead. 
import pytest from key_value.shared.errors.wrappers.limit_size import EntryTooLargeError, EntryTooSmallError +from key_value.shared.utils.managed_entry import estimate_serialized_size from typing_extensions import override from key_value.sync.code_gen.stores.memory.store import MemoryStore @@ -140,12 +141,8 @@ def test_put_many_all_too_large_without_raise(self, memory_store: MemoryStore): def test_exact_size_limit(self, memory_store: MemoryStore): # First, determine the exact size of a small value - from key_value.shared.utils.managed_entry import ManagedEntry - test_value = {"test": "value"} - managed_entry = ManagedEntry(value=test_value) - json_str = managed_entry.to_json() - exact_size = len(json_str.encode("utf-8")) + exact_size = estimate_serialized_size(value=test_value) # Create a store with exact size limit limit_size_store: LimitSizeWrapper = LimitSizeWrapper(key_value=memory_store, max_size=exact_size, raise_on_too_large=True) From 3ffd5400d8837c76eb1acc925a88b2e390332874 Mon Sep 17 00:00:00 2001 From: William Easton Date: Thu, 30 Oct 2025 18:25:36 -0500 Subject: [PATCH 10/11] typecheck fixes --- .../key-value-aio/src/key_value/aio/stores/rocksdb/store.py | 2 +- .../key-value-aio/src/key_value/aio/stores/valkey/store.py | 2 +- .../src/key_value/sync/code_gen/stores/rocksdb/store.py | 2 +- .../src/key_value/sync/code_gen/stores/valkey/store.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/key-value/key-value-aio/src/key_value/aio/stores/rocksdb/store.py b/key-value/key-value-aio/src/key_value/aio/stores/rocksdb/store.py index 0a6542b6..2c0829d5 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/rocksdb/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/rocksdb/store.py @@ -109,7 +109,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry return None managed_entry_str: str = value.decode("utf-8") - managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str) + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str) return managed_entry diff --git a/key-value/key-value-aio/src/key_value/aio/stores/valkey/store.py b/key-value/key-value-aio/src/key_value/aio/stores/valkey/store.py index 57398ee1..71e02cfc 100644 --- a/key-value/key-value-aio/src/key_value/aio/stores/valkey/store.py +++ b/key-value/key-value-aio/src/key_value/aio/stores/valkey/store.py @@ -95,7 +95,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry decoded_response: str = response.decode("utf-8") - return ManagedEntry.from_json(json_str=decoded_response) + return self._serialization_adapter.load_json(json_str=decoded_response) @override async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]: diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/rocksdb/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/rocksdb/store.py index 91da8d23..72d510b7 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/rocksdb/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/rocksdb/store.py @@ -106,7 +106,7 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non return None managed_entry_str: str = value.decode("utf-8") - managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str) + managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str) return 
managed_entry diff --git a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/valkey/store.py b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/valkey/store.py index 7c649697..8c8c0a8e 100644 --- a/key-value/key-value-sync/src/key_value/sync/code_gen/stores/valkey/store.py +++ b/key-value/key-value-sync/src/key_value/sync/code_gen/stores/valkey/store.py @@ -97,7 +97,7 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non decoded_response: str = response.decode("utf-8") - return ManagedEntry.from_json(json_str=decoded_response) + return self._serialization_adapter.load_json(json_str=decoded_response) @override def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]: From 194bc9464d8b587f24ab41acc3a8e702814b0fe1 Mon Sep 17 00:00:00 2001 From: William Easton Date: Thu, 30 Oct 2025 18:30:06 -0500 Subject: [PATCH 11/11] Fix typecheck --- .../src/key_value/shared/utils/serialization.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/key-value/key-value-shared/src/key_value/shared/utils/serialization.py b/key-value/key-value-shared/src/key_value/shared/utils/serialization.py index 20b6d2d0..158d1934 100644 --- a/key-value/key-value-shared/src/key_value/shared/utils/serialization.py +++ b/key-value/key-value-shared/src/key_value/shared/utils/serialization.py @@ -145,4 +145,7 @@ def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]: } def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]: - return data.get("value") + if "value" not in data: + msg = "Value field not found" + raise DeserializationError(message=msg) + return verify_dict(obj=data["value"])
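
For reviewers, a minimal usage sketch of the adapter API exercised throughout the hunks above. The imports and call signatures (BasicSerializationAdapter, dump_json(entry=...), load_json(json_str=...), ManagedEntry(value=...), estimate_serialized_size(value=...)) are taken directly from the diffs in this series; the example value and the expectation that a dump/load round trip yields an equivalent ManagedEntry are illustrative assumptions, not something asserted in these patches.

    # Illustrative sketch only, not a hunk in this series.
    from key_value.shared.utils.managed_entry import ManagedEntry, estimate_serialized_size
    from key_value.shared.utils.serialization import BasicSerializationAdapter

    adapter = BasicSerializationAdapter()
    entry = ManagedEntry(value={"name": "Alice", "age": 30})

    json_str = adapter.dump_json(entry=entry)        # what a store writes to its backend
    restored = adapter.load_json(json_str=json_str)  # what a store reconstructs on read
    assert isinstance(restored, ManagedEntry)        # assumed round-trip behaviour

    # Size checks (e.g. the limit_size wrapper test above) no longer need to build
    # a ManagedEntry first:
    size_in_bytes = estimate_serialized_size(value={"name": "Alice", "age": 30})

The stores touched above (keyring, memory, simple, RocksDB, Valkey, Vault, Windows Registry) all route reads and writes through this same pair of calls, differing only in which adapter instance they construct.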