Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Any, Generic, SupportsFloat, TypeVar, get_origin
from typing import Any, Generic, SupportsFloat, TypeVar, get_origin, overload

from key_value.shared.errors import DeserializationError, SerializationError
from key_value.shared.type_checking.bear_spray import bear_spray
Expand Down Expand Up @@ -71,11 +71,22 @@ def _serialize_model(self, value: T) -> dict[str, Any]:
msg = f"Invalid Pydantic model: {e}"
raise SerializationError(msg) from e

async def get(self, key: str, *, collection: str | None = None) -> T | None:
@overload
async def get(self, key: str, *, collection: str | None = None, default: T) -> T: ...

@overload
async def get(self, key: str, *, collection: str | None = None, default: None = None) -> T | None: ...

async def get(self, key: str, *, collection: str | None = None, default: T | None = None) -> T | None:
"""Get and validate a model by key.

Args:
key: The key to retrieve.
collection: The collection to use. If not provided, uses the default collection.
default: The default value to return if the key doesn't exist or validation fails.

Returns:
The parsed model instance, or None if not present.
The parsed model instance if found and valid, or the default value if key doesn't exist or validation fails.

Raises:
DeserializationError if the stored data cannot be validated as the model and the PydanticAdapter is configured to
Expand All @@ -84,15 +95,28 @@ async def get(self, key: str, *, collection: str | None = None) -> T | None:
collection = collection or self._default_collection

if value := await self._key_value.get(key=key, collection=collection):
return self._validate_model(value=value)
validated = self._validate_model(value=value)
if validated is not None:
return validated

return None
return default

async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[T | None]:
@overload
async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T) -> list[T]: ...

@overload
async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: None = None) -> list[T | None]: ...

async def get_many(self, keys: Sequence[str], *, collection: str | None = None, default: T | None = None) -> list[T] | list[T | None]:
"""Batch get and validate models by keys, preserving order.

Args:
keys: The list of keys to retrieve.
collection: The collection to use. If not provided, uses the default collection.
default: The default value to return for keys that don't exist or fail validation.

Returns:
A list of parsed model instances, or None if missing.
A list of parsed model instances, with default values for missing keys or validation failures.

Raises:
DeserializationError if the stored data cannot be validated as the model and the PydanticAdapter is configured to
Expand All @@ -102,7 +126,14 @@ async def get_many(self, keys: Sequence[str], *, collection: str | None = None)

values: list[dict[str, Any] | None] = await self._key_value.get_many(keys=keys, collection=collection)

return [self._validate_model(value=value) if value else None for value in values]
result: list[T | None] = []
for value in values:
if value is None:
result.append(default)
else:
validated = self._validate_model(value=value)
result.append(validated if validated is not None else default)
return result

async def put(self, key: str, value: T, *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None:
"""Serialize and store a model.
Expand Down
31 changes: 29 additions & 2 deletions key-value/key-value-aio/src/key_value/aio/stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from asyncio.locks import Lock
from collections import defaultdict
from collections.abc import Mapping, Sequence
from types import TracebackType
from types import MappingProxyType, TracebackType
from typing import Any, SupportsFloat

from key_value.shared.constants import DEFAULT_COLLECTION_NAME
Expand All @@ -24,6 +24,16 @@
AsyncKeyValueProtocol,
)

SEED_DATA_TYPE = Mapping[str, Mapping[str, Mapping[str, Any]]]
FROZEN_SEED_DATA_TYPE = MappingProxyType[str, MappingProxyType[str, MappingProxyType[str, Any]]]
DEFAULT_SEED_DATA: FROZEN_SEED_DATA_TYPE = MappingProxyType({})


def _seed_to_frozen_seed_data(seed: SEED_DATA_TYPE) -> FROZEN_SEED_DATA_TYPE:
return MappingProxyType(
{collection: MappingProxyType({key: MappingProxyType(value) for key, value in items.items()}) for collection, items in seed.items()}
)


class BaseStore(AsyncKeyValueProtocol, ABC):
"""An opinionated Abstract base class for managed key-value stores using ManagedEntry objects.
Expand All @@ -43,21 +53,28 @@ class BaseStore(AsyncKeyValueProtocol, ABC):
_setup_collection_locks: defaultdict[str, Lock]
_setup_collection_complete: defaultdict[str, bool]

_seed: FROZEN_SEED_DATA_TYPE

default_collection: str

def __init__(self, *, default_collection: str | None = None) -> None:
def __init__(self, *, default_collection: str | None = None, seed: SEED_DATA_TYPE | None = None) -> None:
"""Initialize the managed key-value store.

Args:
default_collection: The default collection to use if no collection is provided.
Defaults to "default_collection".
seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}.
Seeding occurs once during store initialization (when the store is first entered or when the
first operation is performed on the store).
"""

self._setup_complete = False
self._setup_lock = Lock()
self._setup_collection_locks = defaultdict(Lock)
self._setup_collection_complete = defaultdict(bool)

self._seed = _seed_to_frozen_seed_data(seed=seed or {})

self.default_collection = default_collection or DEFAULT_COLLECTION_NAME

if not hasattr(self, "_stable_api"):
Expand All @@ -74,6 +91,13 @@ async def _setup(self) -> None:
async def _setup_collection(self, *, collection: str) -> None:
"""Initialize the collection (called once before first use of the collection)."""

async def _seed_store(self) -> None:
"""Seed the store with the data from the seed."""
for collection, items in self._seed.items():
await self.setup_collection(collection=collection)
for key, value in items.items():
await self.put(key=key, value=dict(value), collection=collection)

async def setup(self) -> None:
if not self._setup_complete:
async with self._setup_lock:
Expand All @@ -84,8 +108,11 @@ async def setup(self) -> None:
raise StoreSetupError(
message=f"Failed to setup key value store: {e}", extra_info={"store": self.__class__.__name__}
) from e

self._setup_complete = True

await self._seed_store()

async def setup_collection(self, *, collection: str) -> None:
await self.setup()

Expand Down
30 changes: 27 additions & 3 deletions key-value/key-value-aio/src/key_value/aio/stores/memory/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing_extensions import Self, override

from key_value.aio.stores.base import (
SEED_DATA_TYPE,
BaseDestroyCollectionStore,
BaseDestroyStore,
BaseEnumerateCollectionsStore,
Expand Down Expand Up @@ -109,12 +110,21 @@ class MemoryStore(BaseDestroyStore, BaseDestroyCollectionStore, BaseEnumerateCol

_cache: dict[str, MemoryCollection]

def __init__(self, *, max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION, default_collection: str | None = None):
def __init__(
self,
*,
max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_COLLECTION,
default_collection: str | None = None,
seed: SEED_DATA_TYPE | None = None,
):
"""Initialize a fixed-size in-memory store.

Args:
max_entries_per_collection: The maximum number of entries per collection. Defaults to 10000.
default_collection: The default collection to use if no collection is provided.
seed: Optional seed data to pre-populate the store. Format: {collection: {key: {field: value, ...}}}.
Each value must be a mapping (dict) that will be stored as the entry's value.
Seeding occurs lazily when each collection is first accessed.
"""

self.max_entries_per_collection = max_entries_per_collection
Expand All @@ -123,11 +133,25 @@ def __init__(self, *, max_entries_per_collection: int = DEFAULT_MAX_ENTRIES_PER_

self._stable_api = True

super().__init__(default_collection=default_collection)
super().__init__(default_collection=default_collection, seed=seed)

@override
async def _setup(self) -> None:
for collection in self._seed:
await self._setup_collection(collection=collection)

@override
async def _setup_collection(self, *, collection: str) -> None:
self._cache[collection] = MemoryCollection(max_entries=self.max_entries_per_collection)
"""Set up a collection, creating it and seeding it if seed data is available.

Args:
collection: The collection name.
"""
if collection in self._cache:
return

collection_cache = MemoryCollection(max_entries=self.max_entries_per_collection)
self._cache[collection] = collection_cache

@override
async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Default value wrapper for returning fallback values when keys are not found."""

from key_value.aio.wrappers.default_value.wrapper import DefaultValueWrapper

__all__ = ["DefaultValueWrapper"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from collections.abc import Mapping, Sequence
from typing import Any, SupportsFloat

from key_value.shared.utils.managed_entry import dump_to_json, load_from_json
from typing_extensions import override

from key_value.aio.protocols.key_value import AsyncKeyValue
from key_value.aio.wrappers.base import BaseWrapper


class DefaultValueWrapper(BaseWrapper):
"""A wrapper that returns a default value when a key is not found.

This wrapper provides dict.get(key, default) behavior for the key-value store,
allowing you to specify a default value to return instead of None when a key doesn't exist.

It does not store the default value in the underlying key-value store and the TTL returned with the default
value is hard-coded based on the default_ttl parameter. Picking a default_ttl requires careful consideration
of how the value will be used and if any other wrappers will be used that may rely on the TTL.
"""

key_value: AsyncKeyValue # Alias for BaseWrapper compatibility
_default_ttl: float | None
_default_value_json: str

def __init__(
self,
key_value: AsyncKeyValue,
default_value: Mapping[str, Any],
default_ttl: SupportsFloat | None = None,
) -> None:
"""Initialize the DefaultValueWrapper.

Args:
key_value: The underlying key-value store to wrap.
default_value: The default value to return when a key is not found.
default_ttl: The TTL to return to the caller for default values. Defaults to None.
"""
self.key_value = key_value
self._default_value_json = dump_to_json(obj=dict(default_value))
self._default_ttl = None if default_ttl is None else float(default_ttl)

def _new_default_value(self) -> dict[str, Any]:
return load_from_json(json_str=self._default_value_json)

@override
async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None:
result = await self.key_value.get(key=key, collection=collection)
return result if result is not None else self._new_default_value()

@override
async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]:
results = await self.key_value.get_many(keys=keys, collection=collection)
return [result if result is not None else self._new_default_value() for result in results]

@override
async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]:
result, ttl_value = await self.key_value.ttl(key=key, collection=collection)
if result is None:
return (self._new_default_value(), self._default_ttl)
return (result, ttl_value)

@override
async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]:
results = await self.key_value.ttl_many(keys=keys, collection=collection)
return [
(result, ttl_value) if result is not None else (self._new_default_value(), self._default_ttl) for result, ttl_value in results
]
24 changes: 16 additions & 8 deletions key-value/key-value-aio/tests/adapters/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ class Order(BaseModel):
FIXED_UPDATED_AT: datetime = datetime(year=2021, month=1, day=1, hour=15, minute=0, second=0, tzinfo=timezone.utc)

SAMPLE_USER: User = User(name="John Doe", email="john.doe@example.com", age=30)
SAMPLE_USER_2: User = User(name="Jane Doe", email="jane.doe@example.com", age=25)
SAMPLE_PRODUCT: Product = Product(name="Widget", price=29.99, quantity=10, url=AnyHttpUrl(url="https://example.com"))
SAMPLE_ORDER: Order = Order(created_at=datetime.now(), updated_at=datetime.now(), user=SAMPLE_USER, product=SAMPLE_PRODUCT, paid=False)

TEST_COLLECTION: str = "test_collection"
TEST_KEY: str = "test_key"
TEST_KEY_2: str = "test_key_2"


class TestPydanticAdapter:
Expand Down Expand Up @@ -77,8 +79,17 @@ async def test_simple_adapter(self, user_adapter: PydanticAdapter[User]):

assert await user_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY)

cached_user = await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY)
assert cached_user is None
assert await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None

async def test_simple_adapter_with_default(self, user_adapter: PydanticAdapter[User]):
assert await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER

await user_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_USER_2)
assert await user_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY, default=SAMPLE_USER) == SAMPLE_USER_2

assert await user_adapter.get_many(collection=TEST_COLLECTION, keys=[TEST_KEY, TEST_KEY_2], default=SAMPLE_USER) == snapshot(
[SAMPLE_USER_2, SAMPLE_USER]
)

async def test_simple_adapter_with_validation_error_ignore(
self, user_adapter: PydanticAdapter[User], updated_user_adapter: PydanticAdapter[UpdatedUser]
Expand All @@ -98,12 +109,10 @@ async def test_simple_adapter_with_validation_error_raise(

async def test_complex_adapter(self, order_adapter: PydanticAdapter[Order]):
await order_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=SAMPLE_ORDER, ttl=10)
cached_order: Order | None = await order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY)
assert cached_order == SAMPLE_ORDER
assert await order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) == SAMPLE_ORDER

assert await order_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY)
cached_order = await order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY)
assert cached_order is None
assert await order_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None

async def test_complex_adapter_with_list(self, product_list_adapter: PydanticAdapter[list[Product]], store: MemoryStore):
await product_list_adapter.put(collection=TEST_COLLECTION, key=TEST_KEY, value=[SAMPLE_PRODUCT, SAMPLE_PRODUCT], ttl=10)
Expand All @@ -127,5 +136,4 @@ async def test_complex_adapter_with_list(self, product_list_adapter: PydanticAda
)

assert await product_list_adapter.delete(collection=TEST_COLLECTION, key=TEST_KEY)
cached_products = await product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY)
assert cached_products is None
assert await product_list_adapter.get(collection=TEST_COLLECTION, key=TEST_KEY) is None
4 changes: 4 additions & 0 deletions key-value/key-value-aio/tests/stores/memory/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,7 @@ class TestMemoryStore(BaseStoreTests):
@pytest.fixture
async def store(self) -> MemoryStore:
return MemoryStore(max_entries_per_collection=500)

async def test_seed(self):
store = MemoryStore(max_entries_per_collection=500, seed={"test_collection": {"test_key": {"obj_key": "obj_value"}}})
assert await store.get(key="test_key", collection="test_collection") == {"obj_key": "obj_value"}
Loading