Merged

29 commits
796e8ec  refactor: implement serialization adapter pattern for MongoDB and Red… (github-actions[bot], Oct 29, 2025)
0ab6766  refactor: consolidate serialization adapters into shared module (github-actions[bot], Oct 29, 2025)
49cf605  refactor: move store-specific serialization adapters into store modules (github-actions[bot], Oct 29, 2025)
714b0cb  Refactor serialization (strawgate, Oct 30, 2025)
17743fe  Updates for tests (strawgate, Oct 30, 2025)
0dbf7d6  More clean-up (strawgate, Oct 30, 2025)
8c4b6d3  PR Cleanup (strawgate, Oct 30, 2025)
59cbaee  fix redis type checks (strawgate, Oct 30, 2025)
2cb6d1a  Finish refactor (strawgate, Oct 30, 2025)
3ffd540  typecheck fixes (strawgate, Oct 30, 2025)
e519537  Merge branch 'main' into claude/issue-165-20251029-0155 (strawgate, Oct 30, 2025)
194bc94  Fix typecheck (strawgate, Oct 30, 2025)
d62a400  Fixes for store tests (strawgate, Oct 30, 2025)
7b9229b  Fix for dynamo tests (strawgate, Oct 31, 2025)
935222b  Small store fixes (strawgate, Oct 31, 2025)
216eaad  fix codegen (strawgate, Oct 31, 2025)
9feeba6  fix valkey tests (strawgate, Oct 31, 2025)
30aebe2  Remove ValueOnlySerializer (strawgate, Oct 31, 2025)
ea90880  Fix windows tests (strawgate, Oct 31, 2025)
65a365b  PR Feedback (strawgate, Oct 31, 2025)
e648e68  codegen (strawgate, Oct 31, 2025)
5d41ae7  small pr updates (strawgate, Oct 31, 2025)
cabd1b4  cleanup test warnings (strawgate, Oct 31, 2025)
39124b7  Fewer errors for Vault (strawgate, Oct 31, 2025)
eb94c25  block incompatible adapter settings (strawgate, Oct 31, 2025)
5ba24e5  fix: use expires_at directly for DynamoDB TTL timestamp (github-actions[bot], Oct 31, 2025)
bdebfda  Add message to assert (strawgate, Oct 31, 2025)
26d4421  Widen allowed gap for TTL response from store (strawgate, Oct 31, 2025)
c6166aa  codegen too (strawgate, Oct 31, 2025)
22 changes: 16 additions & 6 deletions key-value/key-value-aio/src/key_value/aio/stores/base.py
@@ -14,6 +14,7 @@
from key_value.shared.errors import StoreSetupError
from key_value.shared.type_checking.bear_spray import bear_enforce
from key_value.shared.utils.managed_entry import ManagedEntry
from key_value.shared.utils.serialization import BasicSerializationAdapter, SerializationAdapter
from key_value.shared.utils.time_to_live import prepare_entry_timestamps
from typing_extensions import Self, override

@@ -67,14 +68,23 @@ class BaseStore(AsyncKeyValueProtocol, ABC):
_setup_collection_locks: defaultdict[str, Lock]
_setup_collection_complete: defaultdict[str, bool]

_serialization_adapter: SerializationAdapter

_seed: FROZEN_SEED_DATA_TYPE

default_collection: str

def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYPE | None = None) -> None:
def __init__(
self,
*,
serialization_adapter: SerializationAdapter | None = None,
default_collection: str | None = None,
seed: SEED_DATA_TYPE | None = None,
) -> None:
"""Initialize the managed key-value store.

Args:
serialization_adapter: The serialization adapter to use for the store.
default_collection: The default collection to use if no collection is provided.
Defaults to "default_collection".
seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}.
@@ -91,6 +101,8 @@ def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYP

self.default_collection = default_collection or DEFAULT_COLLECTION_NAME

self._serialization_adapter = serialization_adapter or BasicSerializationAdapter()

if not hasattr(self, "_stable_api"):
self._stable_api = False

@@ -286,9 +298,9 @@ async def put(self, key: str, value: Mapping[str, Any], *, collection: str | Non
collection = collection or self.default_collection
await self.setup_collection(collection=collection)

created_at, ttl_seconds, expires_at = prepare_entry_timestamps(ttl=ttl)
created_at, _, expires_at = prepare_entry_timestamps(ttl=ttl)

managed_entry: ManagedEntry = ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at)
managed_entry: ManagedEntry = ManagedEntry(value=value, created_at=created_at, expires_at=expires_at)

await self._put_managed_entry(
collection=collection,
@@ -316,9 +328,7 @@ async def put_many(

created_at, ttl_seconds, expires_at = prepare_entry_timestamps(ttl=ttl)

managed_entries: list[ManagedEntry] = [
ManagedEntry(value=value, ttl=ttl_seconds, created_at=created_at, expires_at=expires_at) for value in values
]
managed_entries: list[ManagedEntry] = [ManagedEntry(value=value, created_at=created_at, expires_at=expires_at) for value in values]

await self._put_managed_entries(
collection=collection,
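
A note on what the new `serialization_adapter` hook enables, with a minimal sketch of a custom adapter. `PrefixedAdapter` is hypothetical; only the pieces visible in this PR are assumed (a `SerializationAdapter` base with `prepare_dump`/`prepare_load` hooks, and the new `BaseStore.__init__` kwarg):

```python
from typing import Any

from key_value.shared.utils.serialization import SerializationAdapter
from typing_extensions import override


class PrefixedAdapter(SerializationAdapter):
    """Hypothetical adapter that namespaces value fields before storage."""

    @override
    def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]:
        # Runs on the entry's dict form just before serialization.
        data["value"] = {f"app:{k}": v for k, v in data["value"].items()}
        return data

    @override
    def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]:
        # Mirror image: strip the namespace after deserialization.
        data["value"] = {k.removeprefix("app:"): v for k, v in data["value"].items()}
        return data
```

A store that forwards the kwarg (for example, `SomeStore(serialization_adapter=PrefixedAdapter())`) would round-trip every value through these hooks; stores that construct their own adapter, like the disk store below, do not expose it.
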
@@ -1,10 +1,10 @@
import time
from collections.abc import Callable
from datetime import timezone
from pathlib import Path
from typing import overload

from key_value.shared.utils.compound import compound_key
from key_value.shared.utils.managed_entry import ManagedEntry
from key_value.shared.utils.managed_entry import ManagedEntry, datetime
from key_value.shared.utils.serialization import BasicSerializationAdapter
from typing_extensions import override

from key_value.aio.stores.base import BaseContextManagerStore, BaseStore
@@ -100,6 +100,7 @@ def default_disk_cache_factory(collection: str) -> Cache:
self._cache = {}

self._stable_api = True
self._serialization_adapter = BasicSerializationAdapter()

super().__init__(default_collection=default_collection)

@@ -109,18 +110,17 @@ async def _setup_collection(self, *, collection: str) -> None:

@override
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
combo_key: str = compound_key(collection=collection, key=key)

expire_epoch: float

managed_entry_str, expire_epoch = self._cache[collection].get(key=combo_key, expire_time=True) # pyright: ignore[reportAny]
managed_entry_str, expire_epoch = self._cache[collection].get(key=key, expire_time=True) # pyright: ignore[reportAny]

if not isinstance(managed_entry_str, str):
return None

ttl = (expire_epoch - time.time()) if expire_epoch else None
managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str)

managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str, ttl=ttl)
if expire_epoch:
managed_entry.expires_at = datetime.fromtimestamp(expire_epoch, tz=timezone.utc)

return managed_entry

@@ -132,15 +132,11 @@ async def _put_managed_entry(
collection: str,
managed_entry: ManagedEntry,
) -> None:
combo_key: str = compound_key(collection=collection, key=key)

_ = self._cache[collection].set(key=combo_key, value=managed_entry.to_json(include_expiration=False), expire=managed_entry.ttl)
_ = self._cache[collection].set(key=key, value=self._serialization_adapter.dump_json(entry=managed_entry), expire=managed_entry.ttl)

@override
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
combo_key: str = compound_key(collection=collection, key=key)

return self._cache[collection].delete(key=combo_key, retry=True)
return self._cache[collection].delete(key=key, retry=True)

def _sync_close(self) -> None:
for cache in self._cache.values():
@@ -1,4 +1,4 @@
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import overload

@@ -90,9 +90,10 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
if not isinstance(managed_entry_str, str):
return None

ttl = (expire_epoch - time.time()) if expire_epoch else None
managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=managed_entry_str)

managed_entry: ManagedEntry = ManagedEntry.from_json(json_str=managed_entry_str, ttl=ttl)
if expire_epoch:
managed_entry.expires_at = datetime.fromtimestamp(expire_epoch, tz=timezone.utc)

return managed_entry

Expand All @@ -106,7 +107,7 @@ async def _put_managed_entry(
) -> None:
combo_key: str = compound_key(collection=collection, key=key)

_ = self._cache.set(key=combo_key, value=managed_entry.to_json(include_expiration=False), expire=managed_entry.ttl)
_ = self._cache.set(key=combo_key, value=self._serialization_adapter.dump_json(entry=managed_entry), expire=managed_entry.ttl)

@override
async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
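
Both disk-store variants above make the same change on the read path: instead of reconstructing a relative TTL with `expire_epoch - time.time()`, they deserialize first and then overwrite `expires_at` with the absolute expiry that diskcache itself tracks. A standalone sketch of that flow, using only public `diskcache` APIs (the payload string stands in for the adapter's JSON):

```python
from datetime import datetime, timezone

from diskcache import Cache

cache = Cache("./kv-demo")
_ = cache.set("demo", '{"value": {"n": 1}}', expire=60)

# With expire_time=True, diskcache returns (value, absolute_expiry_epoch).
payload, expire_epoch = cache.get("demo", expire_time=True)

# As in the diff: deserialize first, then let the cache's own expiry win
# over whatever the stored JSON recorded.
expires_at = datetime.fromtimestamp(expire_epoch, tz=timezone.utc) if expire_epoch else None
print(payload, expires_at)
```

This also explains why the first variant drops `compound_key`: each collection now owns a separate `Cache`, so keys no longer need a collection prefix.
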
23 changes: 16 additions & 7 deletions key-value/key-value-aio/src/key_value/aio/stores/dynamodb/store.py
@@ -1,3 +1,4 @@
from datetime import datetime, timezone
from types import TracebackType
from typing import TYPE_CHECKING, Any, overload

@@ -183,23 +184,31 @@ async def _setup(self) -> None:
@override
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
"""Retrieve a managed entry from DynamoDB."""
response = await self._connected_client.get_item( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
response = await self._connected_client.get_item(
TableName=self._table_name,
Key={
"collection": {"S": collection},
"key": {"S": key},
},
)

item = response.get("Item") # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
item = response.get("Item")
if not item:
return None

json_value = item.get("value", {}).get("S") # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
json_value = item.get("value", {}).get("S")
if not json_value:
return None

return ManagedEntry.from_json(json_str=json_value) # pyright: ignore[reportUnknownArgumentType]
managed_entry: ManagedEntry = self._serialization_adapter.load_json(json_str=json_value)

expires_at_epoch = item.get("ttl", {}).get("N")

# Our managed entry may carry a TTL, but the TTL in DynamoDB takes precedence.
if expires_at_epoch:
managed_entry.expires_at = datetime.fromtimestamp(int(expires_at_epoch), tz=timezone.utc)

return managed_entry

@override
async def _put_managed_entry(
@@ -210,7 +219,7 @@ async def _put_managed_entry(
managed_entry: ManagedEntry,
) -> None:
"""Store a managed entry in DynamoDB."""
json_value = managed_entry.to_json()
json_value = self._serialization_adapter.dump_json(entry=managed_entry)

item: dict[str, Any] = {
"collection": {"S": collection},
@@ -219,9 +228,9 @@ }
}

# Add TTL if present
if managed_entry.ttl is not None and managed_entry.created_at is not None:
if managed_entry.expires_at is not None:
# DynamoDB TTL expects a Unix timestamp
ttl_timestamp = int(managed_entry.created_at.timestamp() + managed_entry.ttl)
ttl_timestamp = int(managed_entry.expires_at.timestamp())
item["ttl"] = {"N": str(ttl_timestamp)}

await self._connected_client.put_item( # pyright: ignore[reportUnknownMemberType]
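
The TTL change here (commit 5ba24e5) is easiest to see side by side: the old code rebuilt the expiry as `created_at + ttl`, two fields that can drift apart, while the new code reads the single authoritative `expires_at`. A self-contained sketch:

```python
from datetime import datetime, timedelta, timezone

created_at = datetime.now(tz=timezone.utc)
expires_at = created_at + timedelta(seconds=300)

old_ttl = int(created_at.timestamp() + 300)  # old: recomputed from two fields
new_ttl = int(expires_at.timestamp())        # new: single source of truth

# Equal here only because expires_at was derived from created_at above;
# the refactor removes the need for that invariant to hold.
assert old_ttl == new_ttl

item_ttl = {"N": str(new_ttl)}  # DynamoDB TTL attributes are epoch seconds, as a number string
```

It also matches the read path added above, which prefers the DynamoDB `ttl` attribute over whatever the serialized entry carried.
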
@@ -6,14 +6,15 @@
from elastic_transport import ObjectApiResponse
from elastic_transport import SerializationError as ElasticsearchSerializationError
from key_value.shared.errors import DeserializationError, SerializationError
from key_value.shared.utils.managed_entry import ManagedEntry, load_from_json, verify_dict
from key_value.shared.utils.managed_entry import ManagedEntry
from key_value.shared.utils.sanitize import (
ALPHANUMERIC_CHARACTERS,
LOWERCASE_ALPHABET,
NUMBERS,
sanitize_string,
)
from key_value.shared.utils.time_to_live import now_as_epoch, try_parse_datetime_str
from key_value.shared.utils.serialization import SerializationAdapter
from key_value.shared.utils.time_to_live import now_as_epoch
from typing_extensions import override

from key_value.aio.stores.base import (
@@ -84,52 +85,50 @@
ALLOWED_INDEX_CHARACTERS: str = LOWERCASE_ALPHABET + NUMBERS + "_" + "-" + "."


def managed_entry_to_document(collection: str, key: str, managed_entry: ManagedEntry, *, native_storage: bool = False) -> dict[str, Any]:
document: dict[str, Any] = {"collection": collection, "key": key, "value": {}}
class ElasticsearchSerializationAdapter(SerializationAdapter):
"""Adapter for Elasticsearch with support for native and string storage modes."""

# Store in appropriate field based on mode
if native_storage:
document["value"]["flattened"] = managed_entry.value_as_dict
else:
document["value"]["string"] = managed_entry.value_as_json
_native_storage: bool

if managed_entry.created_at:
document["created_at"] = managed_entry.created_at.isoformat()
if managed_entry.expires_at:
document["expires_at"] = managed_entry.expires_at.isoformat()
def __init__(self, *, native_storage: bool = True) -> None:
"""Initialize the Elasticsearch adapter.

return document
Args:
native_storage: If True (default), store values as flattened dicts.
If False, store values as JSON strings.
"""
super().__init__()

self._native_storage = native_storage
self._date_format = "isoformat"
self._value_format = "dict" if native_storage else "string"

def source_to_managed_entry(source: dict[str, Any]) -> ManagedEntry:
value: dict[str, Any] = {}
@override
def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]:
value = data.pop("value")

data["value"] = {}

raw_value = source.get("value")
if self._native_storage:
data["value"]["flattened"] = value
else:
data["value"]["string"] = value

# Try flattened field first, fall back to string field
if not raw_value or not isinstance(raw_value, dict):
msg = "Value field not found or invalid type"
raise DeserializationError(msg)
return data

if value_flattened := raw_value.get("flattened"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
value = verify_dict(obj=value_flattened)
elif value_str := raw_value.get("string"): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
if not isinstance(value_str, str):
msg = "Value in `value` field is not a string"
raise DeserializationError(msg)
value = load_from_json(value_str)
else:
msg = "Value field not found or invalid type"
raise DeserializationError(msg)
@override
def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]:
value = data.pop("value")

created_at: datetime | None = try_parse_datetime_str(value=source.get("created_at"))
expires_at: datetime | None = try_parse_datetime_str(value=source.get("expires_at"))
if "flattened" in value:
data["value"] = value["flattened"]
elif "string" in value:
data["value"] = value["string"]
else:
msg = "Value field not found in Elasticsearch document"
raise DeserializationError(message=msg)

return ManagedEntry(
value=value,
created_at=created_at,
expires_at=expires_at,
)
return data


class ElasticsearchStore(
Expand All @@ -145,6 +144,8 @@ class ElasticsearchStore(

_native_storage: bool

_adapter: SerializationAdapter

@overload
def __init__(
self,
@@ -208,6 +209,7 @@ def __init__(
self._index_prefix = index_prefix
self._native_storage = native_storage
self._is_serverless = False
self._adapter = ElasticsearchSerializationAdapter(native_storage=native_storage)

super().__init__(default_collection=default_collection)

@@ -260,7 +262,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
return None

try:
return source_to_managed_entry(source=source)
return self._adapter.load_dict(data=source)
except DeserializationError:
return None

@@ -293,7 +295,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
continue

try:
entries_by_id[doc_id] = source_to_managed_entry(source=source)
entries_by_id[doc_id] = self._adapter.load_dict(data=source)
except DeserializationError as e:
logger.error(
"Failed to deserialize Elasticsearch document in batch operation",
@@ -324,9 +326,7 @@ async def _put_managed_entry(
index_name: str = self._sanitize_index_name(collection=collection)
document_id: str = self._sanitize_document_id(key=key)

document: dict[str, Any] = managed_entry_to_document(
collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage
)
document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry)

try:
_ = await self._client.index(
@@ -364,11 +364,10 @@ async def _put_managed_entries(

index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id)

document: dict[str, Any] = managed_entry_to_document(
collection=collection, key=key, managed_entry=managed_entry, native_storage=self._native_storage
)
document: dict[str, Any] = self._adapter.dump_dict(entry=managed_entry)

operations.extend([index_action, document])

try:
_ = await self._client.bulk(operations=operations, refresh=self._should_refresh_on_put) # pyright: ignore[reportUnknownMemberType]
except ElasticsearchSerializationError as e:
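
To close the loop on the new adapter, here is what its two modes produce, end to end. This sketch assumes `dump_dict` routes through `prepare_dump` (as the `@override` usage above implies) and that `ManagedEntry`'s timestamps are optional:

```python
from key_value.shared.utils.managed_entry import ManagedEntry

entry = ManagedEntry(value={"user": "ada", "visits": 3})

native = ElasticsearchSerializationAdapter(native_storage=True)
stringy = ElasticsearchSerializationAdapter(native_storage=False)

# native_storage=True keeps the value queryable as a flattened object:
#   {"value": {"flattened": {"user": "ada", "visits": 3}}, ...}
doc_native = native.dump_dict(entry=entry)

# native_storage=False stores a single opaque JSON string instead:
#   {"value": {"string": '{"user": "ada", "visits": 3}'}, ...}
doc_string = stringy.dump_dict(entry=entry)
```

Because `prepare_load` checks the `flattened` field first and falls back to `string`, documents written in either mode remain readable after switching `native_storage`.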