13 changes: 8 additions & 5 deletions key-value/key-value-aio/src/key_value/aio/stores/base.py
@@ -14,6 +14,7 @@
 from key_value.shared.errors import StoreSetupError
 from key_value.shared.type_checking.bear_spray import bear_enforce
 from key_value.shared.utils.managed_entry import ManagedEntry
+from key_value.shared.utils.serialization import BasicSerializationAdapter, SerializationAdapter
 from key_value.shared.utils.time_to_live import prepare_entry_timestamps
 from typing_extensions import Self, override
 
@@ -67,6 +68,8 @@ class BaseStore(AsyncKeyValueProtocol, ABC):
     _setup_collection_locks: defaultdict[str, Lock]
     _setup_collection_complete: defaultdict[str, bool]
 
+    _serialization_adapter: SerializationAdapter
+
     _seed: FROZEN_SEED_DATA_TYPE
 
     default_collection: str
@@ -91,6 +94,8 @@ def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYP
 
         self.default_collection = default_collection or DEFAULT_COLLECTION_NAME
 
+        self._serialization_adapter = BasicSerializationAdapter()
+
         if not hasattr(self, "_stable_api"):
             self._stable_api = False
 
@@ -286,9 +291,9 @@ async def put(self, key: str, value: Mapping[str, Any], *, collection: str | Non
         collection = collection or self.default_collection
         await self.setup_collection(collection=collection)
 
-        created_at, ttl_seconds, expires_at = prepare_entry_timestamps(ttl=ttl)
+        created_at, _, expires_at = prepare_entry_timestamps(ttl=ttl)
 
-        managed_entry: ManagedEntry = ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at)
+        managed_entry: ManagedEntry = ManagedEntry(value=value, created_at=created_at, expires_at=expires_at)
 
         await self._put_managed_entry(
             collection=collection,
@@ -316,9 +321,7 @@ async def put_many(
 
         created_at, ttl_seconds, expires_at = prepare_entry_timestamps(ttl=ttl)
 
-        managed_entries: list[ManagedEntry] = [
-            ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at) for value in values
-        ]
+        managed_entries: list[ManagedEntry] = [ManagedEntry(value=value, created_at=created_at, expires_at=expires_at) for value in values]
 
         await self._put_managed_entries(
             collection=collection,
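A minimal sketch of what prepare_entry_timestamps must return for these call sites to work; the helper's body is not part of this diff, so the internals below are assumptions inferred from the tuple unpacking above:

    from datetime import datetime, timedelta, timezone

    def prepare_entry_timestamps(ttl: float | None) -> tuple[datetime, float | None, datetime | None]:
        # created_at is captured once so both fields agree on "now".
        created_at = datetime.now(tz=timezone.utc)
        ttl_seconds = float(ttl) if ttl is not None else None
        # expires_at exists only when a TTL was requested.
        expires_at = created_at + timedelta(seconds=ttl_seconds) if ttl_seconds is not None else None
        return created_at, ttl_seconds, expires_at

With ManagedEntry no longer taking a ttl argument, an entry's lifetime is carried entirely by created_at/expires_at.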
(next file: the multi-collection disk store)
@@ -1,10 +1,10 @@
-import time
 from collections.abc import Callable
+from datetime import timezone
 from pathlib import Path
 from typing import overload
 
 from key_value.shared.utils.compound import compound_key
-from key_value.shared.utils.managed_entry import ManagedEntry
+from key_value.shared.utils.managed_entry import ManagedEntry, datetime
+from key_value.shared.utils.serialization import BasicSerializationAdapter
 from typing_extensions import override
 
 from key_value.aio.stores.base import BaseContextManagerStore, BaseStore
@@ -100,6 +100,7 @@ def default_disk_cache_factory(collection: str) -> Cache:
         self._cache = {}
 
         self._stable_api = True
+        self._serialization_adapter = BasicSerializationAdapter()
 
         super().__init__(default_collection=default_collection)
 
@@ -109,18 +110,17 @@ async def _setup_collection(self, *, collection: str) -> None:
 
     @override
     async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
-        combo_key: str = compound_key(collection=collection, key=key)
-
         expire_epoch: float
 
-        managed_entry_str, expire_epoch = self._cache[collection].get(key=combo_key, expire_time=True)  # pyright: ignore[reportAny]
+        managed_entry_str, expire_epoch = self._cache[collection].get(key=key, expire_time=True)  # pyright: ignore[reportAny]
 
         if not isinstance(managed_entry_str, str):
             return None
 
-        ttl = (expire_epoch - time.time()) if expire_epoch else None
+        managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str)
 
-        managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str, ttl=ttl)
+        if expire_epoch:
+            managed_entry.expires_at = datetime.fromtimestamp(expire_epoch, tz=timezone.utc)
 
         return managed_entry
 
@@ -132,15 +132,11 @@ async def _put_managed_entry(
         collection: str,
         managed_entry: ManagedEntry,
     ) -> None:
-        combo_key: str = compound_key(collection=collection, key=key)
-
-        _ = self._cache[collection].set(key=combo_key, value=managed_entry.to_json(include_expiration=False), expire=managed_entry.ttl)
+        _ = self._cache[collection].set(key=key, value=self._serialization_adapter.dump_json(entry=managed_entry), expire=managed_entry.ttl)
 
     @override
     async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
-        combo_key: str = compound_key(collection=collection, key=key)
-
-        return self._cache[collection].delete(key=combo_key, retry=True)
+        return self._cache[collection].delete(key=key, retry=True)
 
     def _sync_close(self) -> None:
         for cache in self._cache.values():
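For context, the diskcache round trip the store now leans on looks roughly like this (a sketch assuming the diskcache package, which the Cache factory above implies; the JSON payload is illustrative):

    import time
    from datetime import datetime, timezone

    from diskcache import Cache

    cache = Cache("/tmp/kv-demo")
    cache.set("my-key", '{"value": {"a": 1}}', expire=60)  # diskcache tracks the expiry itself

    value, expire_epoch = cache.get("my-key", expire_time=True)  # returns (value, expiry epoch)
    if expire_epoch:
        # Rebuild an aware datetime from the epoch, exactly as _get_managed_entry now does.
        expires_at = datetime.fromtimestamp(expire_epoch, tz=timezone.utc)
        remaining = expire_epoch - time.time()  # the old code stored this as a ttl instead

Keeping the expiry only in diskcache (expire=...) and reconstructing expires_at on read avoids serializing the expiration into the JSON a second time.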
(next file: the single-cache disk store)
@@ -1,4 +1,4 @@
-import time
+from datetime import datetime, timezone
 from pathlib import Path
 from typing import overload
 
@@ -90,9 +90,10 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
         if not isinstance(managed_entry_str, str):
             return None
 
-        ttl = (expire_epoch - time.time()) if expire_epoch else None
+        managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str)
 
-        managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str, ttl=ttl)
+        if expire_epoch:
+            managed_entry.expires_at = datetime.fromtimestamp(expire_epoch, tz=timezone.utc)
 
         return managed_entry
 
@@ -106,7 +107,7 @@ async def _put_managed_entry(
     ) -> None:
         combo_key: str = compound_key(collection=collection, key=key)
 
-        _ = self._cache.set(key=combo_key, value=managed_entry.to_json(include_expiration=False), expire=managed_entry.ttl)
+        _ = self._cache.set(key=combo_key, value=self._serialization_adapter.dump_json(entry=managed_entry), expire=managed_entry.ttl)
 
     @override
     async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
(next file: the DynamoDB store)
@@ -199,7 +199,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
         if not json_value:
             return None
 
-        return ManagedEntry.from_json(json_str=json_value)  # pyright: ignore[reportUnknownArgumentType]
+        return self._serialization_adapter.load_json(json_str=json_value)
 
     @override
     async def _put_managed_entry(
@@ -210,7 +210,7 @@ async def _put_managed_entry(
         managed_entry: ManagedEntry,
     ) -> None:
         """Store a managed entry in DynamoDB."""
-        json_value = managed_entry.to_json()
+        json_value = self._serialization_adapter.dump_json(entry=managed_entry)
 
         item: dict[str, Any] = {
             "collection": {"S": collection},
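For orientation, a hedged sketch of where that JSON string ends up in the DynamoDB item; only the "collection" attribute is visible in this hunk, so the "key" and "value" attributes and the put_item call below are assumptions:

    item: dict[str, Any] = {
        "collection": {"S": collection},  # shown in the diff
        "key": {"S": key},                # assumed companion attribute
        "value": {"S": json_value},       # serialized ManagedEntry JSON as a DynamoDB string
    }
    # A boto3-style client would then persist it along the lines of:
    # await client.put_item(TableName=table_name, Item=item)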
(next file: the Elasticsearch store)
@@ -6,14 +6,15 @@
 from elastic_transport import ObjectApiResponse
 from elastic_transport import SerializationError as ElasticsearchSerializationError
 from key_value.shared.errors import DeserializationError, SerializationError
-from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict
+from key_value.shared.utils.managed_entry import ManagedEntry
 from key_value.shared.utils.sanitize import (
     ALPHANUMERIC_CHARACTERS,
     LOWERCASE_ALPHABET,
     NUMBERS,
     sanitize_string,
 )
-from key_value.shared.utils.time_to_live import now_as_epoch, try_parse_datetime_str
+from key_value.shared.utils.serialization import SerializationAdapter
+from key_value.shared.utils.time_to_live import now_as_epoch
 from typing_extensions import override
 
 from key_value.aio.stores.base import (
@@ -84,52 +85,50 @@
 ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "."
 
 
-def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry, *, native_storage: bool = False) -> dict[str, Any]:
-    document: dict[str, Any] = {"collection": collection, "key": key, "value": {}}
-
-    # Store in appropriate field based on mode
-    if native_storage:
-        document["value"]["flattened"] = managed_entry.value_as_dict
-    else:
-        document["value"]["string"] = managed_entry.value_as_json
-
-    if managed_entry.created_at:
-        document["created_at"] = managed_entry.created_at.isoformat()
-    if managed_entry.expires_at:
-        document["expires_at"] = managed_entry.expires_at.isoformat()
-
-    return document
-
-
-def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry:
-    value: dict[str, Any] = {}
-
-    raw_value = source.get("value")
-
-    # Try flattened field first, fall back to string field
-    if not raw_value or not isinstance(raw_value, dict):
-        msg = "Value field not found or invalid type"
-        raise DeserializationError(msg)
-
-    if value_flattened := raw_value.get("flattened"):  # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
-        value = verify_dict(obj=value_flattened)
-    elif value_str := raw_value.get("string"):  # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
-        if not isinstance(value_str, str):
-            msg = "Value in `value` field is not a string"
-            raise DeserializationError(msg)
-        value = load_from_json(value_str)
-    else:
-        msg = "Value field not found or invalid type"
-        raise DeserializationError(msg)
-
-    created_at: datetime | None = try_parse_datetime_str(value=source.get("created_at"))
-    expires_at: datetime | None = try_parse_datetime_str(value=source.get("expires_at"))
-
-    return ManagedEntry(
-        value=value,
-        created_at=created_at,
-        expires_at=expires_at,
-    )
+class ElasticsearchSerializationAdapter(SerializationAdapter):
+    """Adapter for Elasticsearch with support for native and string storage modes."""
+
+    _native_storage: bool
+
+    def __init__(self, *, native_storage: bool = True) -> None:
+        """Initialize the Elasticsearch adapter.
+
+        Args:
+            native_storage: If True (default), store values as flattened dicts.
+                If False, store values as JSON strings.
+        """
+        super().__init__()
+
+        self._native_storage = native_storage
+        self._date_format = "isoformat"
+        self._value_format = "dict" if native_storage else "string"
+
+    @override
+    def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]:
+        value = data.pop("value")
+
+        data["value"] = {}
+
+        if self._native_storage:
+            data["value"]["flattened"] = value
+        else:
+            data["value"]["string"] = value
+
+        return data
+
+    @override
+    def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]:
+        value = data.pop("value")
+
+        if "flattened" in value:
+            data["value"] = value["flattened"]
+        elif "string" in value:
+            data["value"] = value["string"]
+        else:
+            msg = "Value field not found in Elasticsearch document"
+            raise DeserializationError(message=msg)
+
+        return data


class ElasticsearchStore(
@@ -145,6 +144,8 @@ class ElasticsearchStore(
 
     _native_storage: bool
 
+    _adapter: SerializationAdapter
+
     @overload
     def __init__(
         self,
@@ -208,6 +209,7 @@ def __init__(
         self._index_prefix = index_prefix
         self._native_storage = native_storage
         self._is_serverless = False
+        self._adapter = ElasticsearchSerializationAdapter(native_storage=native_storage)
 
         super().__init__(default_collection=default_collection)
 
@@ -260,7 +262,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
             return None
 
         try:
-            return source_to_managed_entry(source=source)
+            return self._adapter.load_dict(data=source)
         except DeserializationError:
             return None
 
@@ -293,7 +295,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
                 continue
 
             try:
-                entries_by_id[doc_id] = source_to_managed_entry(source=source)
+                entries_by_id[doc_id] = self._adapter.load_dict(data=source)
             except DeserializationError as e:
                 logger.error(
                     "Failed to deserialize Elasticsearch document in batch operation",
@@ -324,9 +326,7 @@ async def _put_managed_entry(
         index_name: str = self._sanitize_index_name(collection=collection)
         document_id: str = self._sanitize_document_id(key=key)
 
-        document: dict[str, Any] = managed_entry_to_document(
-            collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage
-        )
+        document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry)
 
         try:
             _ = await self._client.index(
@@ -364,11 +364,10 @@ async def _put_managed_entries(
 
             index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id)
 
-            document: dict[str, Any] = managed_entry_to_document(
-                collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage
-            )
+            document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry)
 
             operations.extend([index_action, document])
 
         try:
             _ = await self._client.bulk(operations=operations, refresh=self._should_refresh_on_put)  # pyright: ignore[reportUnknownMemberType]
         except ElasticsearchSerializationError as e:
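A usage sketch of the new adapter, assuming the SerializationAdapter base class wires prepare_dump/prepare_load into the dump_dict/load_dict calls seen in this file (the sample value is illustrative):

    from datetime import datetime, timezone

    adapter = ElasticsearchSerializationAdapter(native_storage=False)

    entry = ManagedEntry(value={"user": "ada"}, created_at=datetime.now(tz=timezone.utc), expires_at=None)

    document = adapter.dump_dict(entry=entry)    # string mode: JSON lands under document["value"]["string"]
    restored = adapter.load_dict(data=document)  # prepare_load unwraps the nested value field again

    assert restored.value == entry.value

With native_storage=True the same value would instead be stored as a flattened dict under document["value"]["flattened"], which keeps it queryable in Elasticsearch.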
(next file: the keyring store)
@@ -79,7 +79,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
         if json_str is None:
             return None
 
-        return ManagedEntry.from_json(json_str=json_str)
+        return self._serialization_adapter.load_json(json_str=json_str)
 
     @override
     async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None:
@@ -88,7 +88,7 @@ async def _put_managed_entry(self, *, key: str, collection: str, managed_entry:
 
         combo_key: str = compound_key(collection=sanitized_collection, key=sanitized_key)
 
-        json_str: str = managed_entry.to_json()
+        json_str: str = self._serialization_adapter.dump_json(entry=managed_entry)
 
         keyring.set_password(service_name=self._service_name, username=combo_key, password=json_str)
 
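The read side of this round trip, for reference (a sketch; keyring.get_password is the standard counterpart to the set_password call above, and the store's actual getter appears in the earlier hunk):

    json_str = keyring.get_password(service_name=self._service_name, username=combo_key)
    if json_str is not None:
        managed_entry = self._serialization_adapter.load_json(json_str=json_str)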
(next file: key-value/key-value-aio/src/key_value/aio/stores/memcached/store.py)
@@ -4,6 +4,7 @@
 
 from key_value.shared.utils.compound import compound_key
 from key_value.shared.utils.managed_entry import ManagedEntry
+from key_value.shared.utils.serialization import BasicSerializationAdapter
 from typing_extensions import override
 
 from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyStore, BaseStore
@@ -44,9 +45,11 @@ def __init__(
             port: Memcached port. Defaults to 11211.
             default_collection: The default collection to use if no collection is provided.
         """
+        super().__init__(default_collection=default_collection)
+
         self._client = client or Client(host=host, port=port)
 
-        super().__init__(default_collection=default_collection)
+        self._serialization_adapter = BasicSerializationAdapter(value_format="dict")
Comment on lines +48 to +52
Contributor

⚠️ Potential issue | 🔴 Critical

Critical initialization order issue.

The code calls super().__init__(default_collection=default_collection) on line 48 before initializing self._serialization_adapter on line 52. If the parent class initialization triggers any methods that depend on the adapter (e.g., seed data processing, setup hooks), this will cause an AttributeError.

Apply this diff to initialize the adapter before calling the parent constructor:

-        super().__init__(default_collection=default_collection)
-
         self._client = client or Client(host=host, port=port)
 
         self._serialization_adapter = BasicSerializationAdapter(value_format="dict")
+
+        super().__init__(default_collection=default_collection)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
        self._client = client or Client(host=host, port=port)

        self._serialization_adapter = BasicSerializationAdapter(value_format="dict")

        super().__init__(default_collection=default_collection)
🤖 Prompt for AI Agents
In key-value/key-value-aio/src/key_value/aio/stores/memcached/store.py around
lines 48 to 52, the adapter is initialized after calling super().__init__, which
can cause AttributeError if the parent constructor calls methods that use the
adapter; move the line that sets self._serialization_adapter =
BasicSerializationAdapter(value_format="dict") to before the
super().__init__(default_collection=default_collection) call so the adapter is
ready during base-class initialization.
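The hazard the bot describes, reduced to a toy example (class and method names here are illustrative, not the library's):

    class Base:
        def __init__(self) -> None:
            self._process_seed()  # base-class hook that runs during construction

        def _process_seed(self) -> None:
            _ = self._adapter  # AttributeError if the subclass has not set this yet


    class Child(Base):
        def __init__(self) -> None:
            super().__init__()        # the hook fires here...
            self._adapter = object()  # ...before this attribute exists

    Child()  # raises AttributeError; assigning self._adapter before super().__init__() fixes it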
 
     def sanitize_key(self, key: str) -> str:
         if len(key) > MAX_KEY_LENGTH:
@@ -65,7 +68,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
 
         json_str: str = raw_value.decode(encoding="utf-8")
 
-        return ManagedEntry.from_json(json_str=json_str)
+        return self._serialization_adapter.load_json(json_str=json_str)
 
     @override
     async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) -> list[ManagedEntry | None]:
@@ -82,7 +85,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
         for raw_value in raw_values:
             if isinstance(raw_value, (bytes, bytearray)):
                 json_str: str = raw_value.decode(encoding="utf-8")
-                entries.append(ManagedEntry.from_json(json_str=json_str))
+                entries.append(self._serialization_adapter.load_json(json_str=json_str))
             else:
                 entries.append(None)
 
@@ -106,7 +109,7 @@ async def _put_managed_entry(
         else:
             exptime = max(int(managed_entry.ttl), 1)
 
-        json_value: str = managed_entry.to_json()
+        json_value: str = self._serialization_adapter.dump_json(entry=managed_entry)
 
         _ = await self._client.set(
             key=combo_key.encode(encoding="utf-8"),