Merged
@@ -132,7 +132,11 @@ async def _put_managed_entry(
         collection: str,
         managed_entry: ManagedEntry,
     ) -> None:
-        _ = self._cache[collection].set(key=key, value=self._serialization_adapter.dump_json(entry=managed_entry), expire=managed_entry.ttl)
+        _ = self._cache[collection].set(
+            key=key,
+            value=self._serialization_adapter.dump_json(entry=managed_entry, key=key, collection=collection),
+            expire=managed_entry.ttl,
+        )
 
     @override
     async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
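Every store hunk in this PR makes the same change: `dump_json` / `dump_dict` now take the key and collection alongside the entry, so the persisted document can self-describe. The updated test snapshots further down show the resulting fields ("key", "collection", and a schema "version"). Below is a minimal sketch of the adapter side of the change, assuming the envelope matches those snapshots; the class and method names are illustrative, not the library's actual implementation.

# Sketch of the serialization side these hunks rely on. Assumption: the
# envelope matches the updated test snapshots; the real adapter in
# py-key-value may differ.
import json
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone


@dataclass
class ManagedEntry:
    value: dict
    created_at: datetime | None = None
    expires_at: datetime | None = None


class SerializationAdapter:
    def dump_dict(self, *, entry: ManagedEntry, key: str, collection: str) -> dict:
        # Embed key, collection, and a schema version next to the value.
        document: dict = {
            "version": 1,
            "key": key,
            "collection": collection,
            "value": entry.value,
        }
        if entry.created_at is not None:
            document["created_at"] = entry.created_at.isoformat()
        if entry.expires_at is not None:
            document["expires_at"] = entry.expires_at.isoformat()
        return document

    def dump_json(self, *, entry: ManagedEntry, key: str, collection: str) -> str:
        return json.dumps(self.dump_dict(entry=entry, key=key, collection=collection))


# What a store would now persist for put(collection="test", key="test_key", ...):
now = datetime.now(tz=timezone.utc)
entry = ManagedEntry(value={"name": "Alice", "age": 30}, created_at=now, expires_at=now + timedelta(seconds=10))
print(SerializationAdapter().dump_json(entry=entry, key="test_key", collection="test"))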
@@ -107,7 +107,11 @@ async def _put_managed_entry(
     ) -> None:
         combo_key: str = compound_key(collection=collection, key=key)
 
-        _ = self._cache.set(key=combo_key, value=self._serialization_adapter.dump_json(entry=managed_entry), expire=managed_entry.ttl)
+        _ = self._cache.set(
+            key=combo_key,
+            value=self._serialization_adapter.dump_json(entry=managed_entry, key=key, collection=collection),
+            expire=managed_entry.ttl,
+        )
 
     @override
     async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
@@ -219,7 +219,7 @@ async def _put_managed_entry(
         managed_entry: ManagedEntry,
     ) -> None:
         """Store a managed entry in DynamoDB."""
-        json_value = self._serialization_adapter.dump_json(entry=managed_entry)
+        json_value = self._serialization_adapter.dump_json(entry=managed_entry, key=key, collection=collection)
 
         item: dict[str, Any] = {
             "collection": {"S": collection},
@@ -67,6 +67,9 @@
         "key": {
             "type": "keyword",
         },
+        "version": {
+            "type": "integer",
+        },
         "value": {
             "properties": {
                 "flattened": {
@@ -357,7 +360,7 @@ async def _put_managed_entry(
         index_name: str = self._get_index_name(collection=collection)
         document_id: str = self._get_document_id(key=key)
 
-        document: dict[str, Any] = self._serializer.dump_dict(entry=managed_entry)
+        document: dict[str, Any] = self._serializer.dump_dict(entry=managed_entry, key=key, collection=collection)
 
         try:
             _ = await self._client.index(
@@ -395,7 +398,7 @@ async def _put_managed_entries(
 
             index_action: dict[str, Any] = new_bulk_action(action="index", index=index_name, document_id=document_id)
 
-            document: dict[str, Any] = self._serializer.dump_dict(entry=managed_entry)
+            document: dict[str, Any] = self._serializer.dump_dict(entry=managed_entry, key=key, collection=collection)
 
             operations.extend([index_action, document])
 
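Context for the `_put_managed_entries` hunk: the Elasticsearch bulk API consumes alternating action and source lines, which is why the loop extends `operations` with an action dict followed by the document. Roughly, assuming the official async `elasticsearch` client (`new_bulk_action` is this repo's helper, not part of the client):

# Assumed shape of the bulk call; verify against the client version in use.
from typing import Any

from elasticsearch import AsyncElasticsearch


async def bulk_index(client: AsyncElasticsearch, index: str, docs: dict[str, dict[str, Any]]) -> None:
    operations: list[dict[str, Any]] = []
    for doc_id, document in docs.items():
        operations.append({"index": {"_index": index, "_id": doc_id}})  # action line
        operations.append(document)  # source line
    _ = await client.bulk(operations=operations)  # one request for the whole batch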
@@ -1,5 +1,8 @@
 """Python keyring-based key-value store."""
 
+import os
+
+from key_value.shared.errors.key_value import ValueTooLargeError
 from key_value.shared.utils.compound import compound_key
 from key_value.shared.utils.managed_entry import ManagedEntry
 from key_value.shared.utils.sanitization import HybridSanitizationStrategy, SanitizationStrategy
@@ -17,6 +20,16 @@
 
 DEFAULT_KEYCHAIN_SERVICE = "py-key-value"
 
+
+def is_value_too_large(value: bytes) -> bool:
+    value_length = len(value)
+    if os.name == "nt":
+        return value_length > WINDOWS_MAX_VALUE_LENGTH
+    return False
+
+
+WINDOWS_MAX_VALUE_LENGTH = 2560  # bytes
+
 MAX_KEY_COLLECTION_LENGTH = 256
 ALLOWED_KEY_COLLECTION_CHARACTERS: str = ALPHANUMERIC_CHARACTERS
 
@@ -105,7 +118,11 @@ async def _put_managed_entry(self, *, key: str, collection: str, managed_entry:
 
         combo_key: str = compound_key(collection=sanitized_collection, key=sanitized_key)
 
-        json_str: str = self._serialization_adapter.dump_json(entry=managed_entry)
+        json_str: str = self._serialization_adapter.dump_json(entry=managed_entry, key=key, collection=collection)
+        encoded_json_bytes: bytes = json_str.encode(encoding="utf-8")
+
+        if is_value_too_large(value=encoded_json_bytes):
+            raise ValueTooLargeError(size=len(encoded_json_bytes), max_size=2560, collection=sanitized_collection, key=sanitized_key)
 
         keyring.set_password(service_name=self._service_name, username=combo_key, password=json_str)
 
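The size guard exists because Windows Credential Manager caps credential blobs (wincred.h defines CRED_MAX_CREDENTIAL_BLOB_SIZE as 5 * 512 = 2560 bytes, matching the constant above); oversized values now fail fast with ValueTooLargeError instead of erroring inside the backend, and the check is a no-op on other platforms. A standalone illustration:

# Standalone copy of the guard's logic, for illustration only.
import os

WINDOWS_MAX_VALUE_LENGTH = 2560  # bytes


def is_value_too_large(value: bytes) -> bool:
    if os.name == "nt":  # only Windows enforces the cap
        return len(value) > WINDOWS_MAX_VALUE_LENGTH
    return False


payload = b"x" * 3000
print(is_value_too_large(value=payload))  # True on Windows, False elsewhere

One nit a reviewer might raise: the raise site hard-codes max_size=2560 instead of reusing WINDOWS_MAX_VALUE_LENGTH.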
@@ -125,7 +125,7 @@ async def _put_managed_entry(
         else:
             exptime = max(int(managed_entry.ttl), 1)
 
-        json_value: str = self._serialization_adapter.dump_json(entry=managed_entry)
+        json_value: str = self._serialization_adapter.dump_json(entry=managed_entry, key=key, collection=collection)
 
         _ = await self._client.set(
             key=combo_key.encode(encoding="utf-8"),
@@ -219,7 +219,7 @@ async def _setup_collection(self, *, collection: str) -> None:
         # Ensure index on the unique combo key and supporting queries
         sanitized_collection = self._sanitize_collection(collection=collection)
 
-        collection_filter: dict[str, str] = {"name": collection}
+        collection_filter: dict[str, str] = {"name": sanitized_collection}
         matching_collections: list[str] = await self._db.list_collection_names(filter=collection_filter)
 
         if matching_collections:
@@ -273,7 +273,7 @@ async def _put_managed_entry(
         collection: str,
         managed_entry: ManagedEntry,
     ) -> None:
-        mongo_doc = self._adapter.dump_dict(entry=managed_entry)
+        mongo_doc = self._adapter.dump_dict(entry=managed_entry, key=key, collection=collection)
 
         try:
             # Ensure that the value is serializable to JSON
@@ -308,7 +308,7 @@ async def _put_managed_entries(
 
         operations: list[UpdateOne] = []
         for key, managed_entry in zip(keys, managed_entries, strict=True):
-            mongo_doc = self._adapter.dump_dict(entry=managed_entry)
+            mongo_doc = self._adapter.dump_dict(entry=managed_entry, key=key, collection=collection)
 
             operations.append(
                 UpdateOne(
@@ -346,8 +346,7 @@ async def _delete_collection(self, *, collection: str) -> bool:
 
         _ = await self._db.drop_collection(name_or_collection=collection_name)
 
-        if collection_name in self._collections_by_name:
-            del self._collections_by_name[collection]
+        self._collections_by_name.pop(collection, None)
 
         return True
 
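Two real bug fixes ride along with the serialization change here: `_setup_collection` now filters `list_collection_names` by the sanitized name it actually creates, and `_delete_collection` replaces a check-then-delete pair that tested `collection_name` but deleted `collection` (a latent KeyError whenever the two differ) with `dict.pop` plus a default, which removes the entry if present and is a no-op otherwise:

# dict.pop(key, default) as a safe "remove if present":
collections_by_name: dict[str, object] = {"users": object()}
collections_by_name.pop("users", None)  # removes the entry
collections_by_name.pop("users", None)  # second call returns None, no KeyError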
@@ -132,7 +132,7 @@ async def _put_managed_entry(
     ) -> None:
         combo_key: str = compound_key(collection=collection, key=key)
 
-        json_value: str = self._adapter.dump_json(entry=managed_entry)
+        json_value: str = self._adapter.dump_json(entry=managed_entry, key=key, collection=collection)
 
         if managed_entry.ttl is not None:
             # Redis does not support <= 0 TTLs
@@ -160,7 +160,7 @@ async def _put_managed_entries(
         # If there is no TTL, we can just do a simple mset
         mapping: dict[str, str] = {}
         for key, managed_entry in zip(keys, managed_entries, strict=True):
-            json_value = self._adapter.dump_json(entry=managed_entry)
+            json_value = self._adapter.dump_json(entry=managed_entry, key=key, collection=collection)
             mapping[compound_key(collection=collection, key=key)] = json_value
 
         await self._client.mset(mapping=mapping)
@@ -175,7 +175,7 @@ async def _put_managed_entries(
 
         for key, managed_entry in zip(keys, managed_entries, strict=True):
            combo_key: str = compound_key(collection=collection, key=key)
-            json_value = self._adapter.dump_json(entry=managed_entry)
+            json_value = self._adapter.dump_json(entry=managed_entry, key=key, collection=collection)
 
             pipeline.setex(name=combo_key, time=ttl_seconds, value=json_value)
 
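On the write paths above: without a TTL the store can batch everything into one MSET, but MSET cannot attach expirations, so the TTL path issues one SETEX per key through a pipeline to keep it to a single round trip. A sketch of the pattern, assuming `redis.asyncio` (the store's actual client wiring may differ):

# Sketch of the two batch-write paths, assuming redis.asyncio.
import redis.asyncio as redis


async def put_many(client: redis.Redis, mapping: dict[str, str], ttl_seconds: int | None) -> None:
    if ttl_seconds is None:
        await client.mset(mapping=mapping)  # one round trip, no per-key expiry
        return
    pipeline = client.pipeline()
    for name, value in mapping.items():
        pipeline.setex(name=name, time=ttl_seconds, value=value)  # per-key TTL
    _ = await pipeline.execute()  # queued commands are sent as one batch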
@@ -124,7 +124,7 @@ async def _put_managed_entry(
         self._fail_on_closed_store()
 
         combo_key: str = compound_key(collection=collection, key=key)
-        json_value: str = self._serialization_adapter.dump_json(entry=managed_entry)
+        json_value: str = self._serialization_adapter.dump_json(entry=managed_entry, key=key, collection=collection)
 
         self._db[combo_key] = json_value.encode("utf-8")
 
@@ -147,7 +147,7 @@ async def _put_managed_entries(
         batch = WriteBatch()
         for key, managed_entry in zip(keys, managed_entries, strict=True):
             combo_key: str = compound_key(collection=collection, key=key)
-            json_value: str = self._serialization_adapter.dump_json(entry=managed_entry)
+            json_value: str = self._serialization_adapter.dump_json(entry=managed_entry, key=key, collection=collection)
             batch.put(combo_key, json_value.encode("utf-8"))
 
         self._db.write(batch)
@@ -69,7 +69,7 @@ async def _put_managed_entry(self, *, key: str, collection: str, managed_entry:
             _ = self._data.pop(next(iter(self._data)))
 
         self._data[combo_key] = SimpleStoreEntry(
-            json_str=self._serialization_adapter.dump_json(entry=managed_entry),
+            json_str=self._serialization_adapter.dump_json(entry=managed_entry, key=key, collection=collection),
             expires_at=managed_entry.expires_at,
             created_at=managed_entry.created_at,
         )
@@ -126,7 +126,7 @@ async def _put_managed_entry(
     ) -> None:
         combo_key: str = compound_key(collection=collection, key=key)
 
-        json_value: str = self._serialization_adapter.dump_json(entry=managed_entry)
+        json_value: str = self._serialization_adapter.dump_json(entry=managed_entry, key=key, collection=collection)
 
         expiry: ExpirySet | None = ExpirySet(expiry_type=ExpiryType.SEC, value=int(managed_entry.ttl)) if managed_entry.ttl else None
 
@@ -108,7 +108,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
     async def _put_managed_entry(self, *, key: str, collection: str, managed_entry: ManagedEntry) -> None:
         combo_key: str = compound_key(collection=collection, key=key)
 
-        json_str: str = self._serialization_adapter.dump_json(entry=managed_entry)
+        json_str: str = self._serialization_adapter.dump_json(entry=managed_entry, key=key, collection=collection)
 
         # Store the JSON string in a 'value' field
         secret_data = {"value": json_str}
@@ -102,7 +102,7 @@ async def _put_managed_entry(self, *, key: str, collection: str, managed_entry:
         sanitized_key = self._sanitize_key(key=key)
         registry_path = self._get_registry_path(collection=collection)
 
-        json_str: str = self._serialization_adapter.dump_json(entry=managed_entry)
+        json_str: str = self._serialization_adapter.dump_json(entry=managed_entry, key=key, collection=collection)
 
         set_reg_sz_value(hive=self._hive, sub_key=registry_path, value_name=sanitized_key, value=json_str)
 
19 changes: 17 additions & 2 deletions key-value/key-value-aio/tests/stores/disk/test_disk.py
@@ -36,12 +36,27 @@ async def test_value_stored(self, store: DiskStore, disk_cache: Cache):
 
         value = disk_cache.get(key="test::test_key")
         value_as_dict = json.loads(value)
-        assert value_as_dict == snapshot({"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}})
+        assert value_as_dict == snapshot(
+            {
+                "collection": "test",
+                "created_at": IsDatetime(iso_string=True),
+                "key": "test_key",
+                "value": {"age": 30, "name": "Alice"},
+                "version": 1,
+            }
+        )
 
         await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10)
 
         value = disk_cache.get(key="test::test_key")
         value_as_dict = json.loads(value)
         assert value_as_dict == snapshot(
-            {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)}
+            {
+                "collection": "test",
+                "created_at": IsDatetime(iso_string=True),
+                "value": {"age": 30, "name": "Alice"},
+                "key": "test_key",
+                "expires_at": IsDatetime(iso_string=True),
+                "version": 1,
+            }
         )
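These snapshot assertions mix exact values with matchers so nondeterministic fields still compare equal; `IsDatetime(iso_string=True)` accepts any ISO-8601 datetime string (this appears to be dirty-equals, with `snapshot(...)` from inline-snapshot). A standalone example of the matcher behavior:

# dirty-equals matchers compare by predicate rather than by exact value.
from dirty_equals import IsDatetime

stored = {"created_at": "2025-01-01T00:00:00+00:00", "version": 1}
assert stored == {"created_at": IsDatetime(iso_string=True), "version": 1}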
12 changes: 11 additions & 1 deletion key-value/key-value-aio/tests/stores/disk/test_multi_disk.py
@@ -40,8 +40,11 @@ async def test_value_stored(self, store: MultiDiskStore):
         value_as_dict = json.loads(value)
         assert value_as_dict == snapshot(
             {
+                "collection": "test",
                 "value": {"name": "Alice", "age": 30},
+                "key": "test_key",
                 "created_at": IsDatetime(iso_string=True),
+                "version": 1,
             }
         )
 
@@ -50,5 +53,12 @@ async def test_value_stored(self, store: MultiDiskStore):
         value = disk_cache.get(key="test_key")
         value_as_dict = json.loads(value)
         assert value_as_dict == snapshot(
-            {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)}
+            {
+                "collection": "test",
+                "created_at": IsDatetime(iso_string=True),
+                "value": {"age": 30, "name": "Alice"},
+                "key": "test_key",
+                "expires_at": IsDatetime(iso_string=True),
+                "version": 1,
+            }
         )
17 changes: 15 additions & 2 deletions key-value/key-value-aio/tests/stores/dynamodb/test_dynamodb.py
@@ -125,7 +125,13 @@ async def test_value_stored(self, store: DynamoDBStore):
             TableName=DYNAMODB_TEST_TABLE, Key={"collection": {"S": "test"}, "key": {"S": "test_key"}}
         )
         assert get_value_from_response(response=response) == snapshot(
-            {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}}
+            {
+                "collection": "test",
+                "created_at": IsDatetime(iso_string=True),
+                "key": "test_key",
+                "value": {"age": 30, "name": "Alice"},
+                "version": 1,
+            }
         )
 
         assert "ttl" not in response.get("Item", {})
@@ -136,7 +142,14 @@
             TableName=DYNAMODB_TEST_TABLE, Key={"collection": {"S": "test"}, "key": {"S": "test_key"}}
         )
         assert get_value_from_response(response=response) == snapshot(
-            {"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}, "expires_at": IsDatetime(iso_string=True)}
+            {
+                "collection": "test",
+                "created_at": IsDatetime(iso_string=True),
+                "value": {"age": 30, "name": "Alice"},
+                "key": "test_key",
+                "expires_at": IsDatetime(iso_string=True),
+                "version": 1,
+            }
         )
         # Verify DynamoDB TTL attribute is set for automatic expiration
         assert "ttl" in response.get("Item", {}), "DynamoDB TTL attribute should be set when ttl parameter is provided"
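The `ttl` assertions hinge on DynamoDB's TTL feature: the table designates a numeric attribute holding an epoch-seconds timestamp, and DynamoDB deletes the item some time after that instant. A hypothetical helper showing the convention (the store's real logic lives in the source hunk earlier in this diff):

# Hypothetical helper: DynamoDB TTL expects epoch seconds in a numeric ("N")
# attribute that the table's TTL configuration names.
import time
from typing import Any


def with_ttl(item: dict[str, Any], ttl_seconds: float | None) -> dict[str, Any]:
    if ttl_seconds is not None:
        item["ttl"] = {"N": str(int(time.time() + ttl_seconds))}
    return item


item = with_ttl({"collection": {"S": "test"}, "key": {"S": "test_key"}}, ttl_seconds=10)
print(item["ttl"])  # e.g. {'N': '1767225610'}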
@@ -64,6 +64,7 @@ def test_managed_entry_document_conversion():
 
     assert document == snapshot(
         {
+            "version": 1,
             "value": {"flattened": {"test": "test"}},
             "created_at": "2025-01-01T00:00:00+00:00",
             "expires_at": "2025-01-01T00:00:10+00:00",
@@ -174,6 +175,9 @@ async def test_value_stored_as_flattened_object(self, store: ElasticsearchStore,
         response = await es_client.get(index=index_name, id=doc_id)
         assert response.body["_source"] == snapshot(
             {
+                "version": 1,
+                "key": "test_key",
+                "collection": "test",
                 "value": {"flattened": {"name": "Alice", "age": 30}},
                 "created_at": IsStr(min_length=20, max_length=40),
             }
@@ -184,6 +188,9 @@ async def test_value_stored_as_flattened_object(self, store: ElasticsearchStore,
         response = await es_client.get(index=index_name, id=doc_id)
         assert response.body["_source"] == snapshot(
             {
+                "version": 1,
+                "key": "test_key",
+                "collection": "test",
                 "value": {"flattened": {"name": "Bob", "age": 25}},
                 "created_at": IsStr(min_length=20, max_length=40),
                 "expires_at": IsStr(min_length=20, max_length=40),
5 changes: 3 additions & 2 deletions key-value/key-value-aio/tests/stores/keyring/test_keyring.py
@@ -34,6 +34,7 @@ async def test_get_large_put_get(self, store: BaseStore, data: dict[str, Any], j
 
 
 @pytest.mark.skipif(condition=not detect_on_macos(), reason="Keyrings do not support large values on MacOS")
+@pytest.mark.filterwarnings("ignore:A configured store is unstable and may change in a backwards incompatible way. Use at your own risk.")
 class TestMacOSKeychainStore(BaseTestKeychainStore):
     pass
 
@@ -54,8 +55,8 @@ async def test_long_collection_name(self, store: KeyringStore, sanitizing_store:
         with pytest.raises(Exception):  # noqa: B017, PT011
             await store.put(collection="test_collection" * 100, key="test_key", value={"test": "test"})
 
-        await sanitizing_store.put(collection="test_collection" * 100, key="test_key", value={"test": "test"})
-        assert await sanitizing_store.get(collection="test_collection" * 100, key="test_key") == {"test": "test"}
+        await sanitizing_store.put(collection="test_collection" * 50, key="test_key", value={"test": "test"})
+        assert await sanitizing_store.get(collection="test_collection" * 50, key="test_key") == {"test": "test"}
 
     @override
     async def test_long_key_name(self, store: KeyringStore, sanitizing_store: KeyringStore):  # pyright: ignore[reportIncompatibleMethodOverride]
13 changes: 12 additions & 1 deletion key-value/key-value-aio/tests/stores/memcached/test_memcached.py
@@ -107,7 +107,15 @@ async def test_value_stored(self, store: MemcachedStore, memcached_client: Clien
         value = await memcached_client.get(key=b"test::test_key")
         assert value is not None
         value_as_dict = json.loads(value.decode("utf-8"))
-        assert value_as_dict == snapshot({"created_at": IsDatetime(iso_string=True), "value": {"age": 30, "name": "Alice"}})
+        assert value_as_dict == snapshot(
+            {
+                "collection": "test",
+                "created_at": IsDatetime(iso_string=True),
+                "key": "test_key",
+                "value": {"age": 30, "name": "Alice"},
+                "version": 1,
+            }
+        )
 
         await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30}, ttl=10)
 
@@ -116,8 +124,11 @@
         value_as_dict = json.loads(value.decode("utf-8"))
         assert value_as_dict == snapshot(
             {
+                "collection": "test",
                 "created_at": IsDatetime(iso_string=True),
                 "expires_at": IsDatetime(iso_string=True),
+                "key": "test_key",
                 "value": {"age": 30, "name": "Alice"},
+                "version": 1,
             }
         )