Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 7 additions & 23 deletions key-value/key-value-aio/src/key_value/aio/stores/mongodb/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,28 +38,20 @@


class MongoDBSerializationAdapter(SerializationAdapter):
"""Adapter for MongoDB with support for native and string storage modes."""
"""Adapter for MongoDB with native BSON storage."""

_native_storage: bool

def __init__(self, *, native_storage: bool = True) -> None:
def __init__(self) -> None:
"""Initialize the MongoDB adapter."""
super().__init__()

self._native_storage = native_storage
self._date_format = "datetime"
self._value_format = "dict" if native_storage else "string"
self._value_format = "dict"

@override
def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]:
value = data.pop("value")

data["value"] = {}

if self._native_storage:
data["value"]["object"] = value
else:
data["value"]["string"] = value
data["value"] = {"object": value}

return data

Expand All @@ -69,8 +61,6 @@ def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]:

if "object" in value:
data["value"] = value["object"]
elif "string" in value:
data["value"] = value["string"]
else:
msg = "Value field not found in MongoDB document"
raise DeserializationError(message=msg)
Expand Down Expand Up @@ -121,7 +111,6 @@ def __init__(
client: AsyncMongoClient[dict[str, Any]],
db_name: str | None = None,
coll_name: str | None = None,
native_storage: bool = True,
default_collection: str | None = None,
collection_sanitization_strategy: SanitizationStrategy | None = None,
) -> None:
Expand All @@ -131,7 +120,6 @@ def __init__(
client: The MongoDB client to use.
db_name: The name of the MongoDB database.
coll_name: The name of the MongoDB collection.
native_storage: Whether to use native BSON storage (True, default) or JSON string storage (False).
default_collection: The default collection to use if no collection is provided.
collection_sanitization_strategy: The sanitization strategy to use for collections.
"""
Expand All @@ -143,7 +131,6 @@ def __init__(
url: str,
db_name: str | None = None,
coll_name: str | None = None,
native_storage: bool = True,
default_collection: str | None = None,
collection_sanitization_strategy: SanitizationStrategy | None = None,
) -> None:
Expand All @@ -153,7 +140,6 @@ def __init__(
url: The url of the MongoDB cluster.
db_name: The name of the MongoDB database.
coll_name: The name of the MongoDB collection.
native_storage: Whether to use native BSON storage (True, default) or JSON string storage (False).
default_collection: The default collection to use if no collection is provided.
collection_sanitization_strategy: The sanitization strategy to use for collections.
"""
Expand All @@ -165,20 +151,18 @@ def __init__(
url: str | None = None,
db_name: str | None = None,
coll_name: str | None = None,
native_storage: bool = True,
default_collection: str | None = None,
collection_sanitization_strategy: SanitizationStrategy | None = None,
) -> None:
"""Initialize the MongoDB store.

Values are stored as native BSON dictionaries for better query support and performance.

Args:
client: The MongoDB client to use (mutually exclusive with url).
url: The url of the MongoDB cluster (mutually exclusive with client).
db_name: The name of the MongoDB database.
coll_name: The name of the MongoDB collection.
native_storage: Whether to use native BSON storage (True, default) or JSON string storage (False).
Native storage stores values as BSON dicts for better query support.
Legacy mode stores values as JSON strings for backward compatibility.
default_collection: The default collection to use if no collection is provided.
collection_sanitization_strategy: The sanitization strategy to use for collections.
"""
Expand All @@ -196,7 +180,7 @@ def __init__(

self._db = self._client[db_name]
self._collections_by_name = {}
self._adapter = MongoDBSerializationAdapter(native_storage=native_storage)
self._adapter = MongoDBSerializationAdapter()

super().__init__(
default_collection=default_collection,
Expand Down
119 changes: 7 additions & 112 deletions key-value/key-value-aio/tests/stores/mongodb/test_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ class MongoDBFailedToStartError(Exception):
pass


def test_managed_entry_document_conversion_native_mode():
def test_managed_entry_document_conversion():
"""Test that documents are stored as BSON dicts."""
created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc)
expires_at = created_at + timedelta(seconds=10)

managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at)

adapter = MongoDBSerializationAdapter(native_storage=True)
adapter = MongoDBSerializationAdapter()
document = adapter.dump_dict(entry=managed_entry)

assert document == snapshot(
Expand All @@ -74,31 +75,6 @@ def test_managed_entry_document_conversion_native_mode():
assert round_trip_managed_entry.expires_at == expires_at


def test_managed_entry_document_conversion_legacy_mode():
created_at = datetime(year=2025, month=1, day=1, hour=0, minute=0, second=0, tzinfo=timezone.utc)
expires_at = created_at + timedelta(seconds=10)

managed_entry = ManagedEntry(value={"test": "test"}, created_at=created_at, expires_at=expires_at)
adapter = MongoDBSerializationAdapter(native_storage=False)
document = adapter.dump_dict(entry=managed_entry)

assert document == snapshot(
{
"version": 1,
"value": {"string": '{"test": "test"}'},
"created_at": datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc),
"expires_at": datetime(2025, 1, 1, 0, 0, 10, tzinfo=timezone.utc),
}
)

round_trip_managed_entry = adapter.load_dict(data=document)

assert round_trip_managed_entry.value == managed_entry.value
assert round_trip_managed_entry.created_at == created_at
assert round_trip_managed_entry.ttl == IsFloat(lt=0)
assert round_trip_managed_entry.expires_at == expires_at


async def clean_mongodb_database(store: MongoDBStore) -> None:
with contextlib.suppress(Exception):
_ = await store._client.drop_database(name_or_database=store._db.name) # pyright: ignore[reportPrivateUsage]
Expand Down Expand Up @@ -151,13 +127,13 @@ async def test_mongodb_collection_name_sanitization(self, sanitizing_store: Mong


@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not available")
class TestMongoDBStoreNativeMode(BaseMongoDBStoreTests):
"""Test MongoDBStore with native_storage=True (default)."""
class TestMongoDBStore(BaseMongoDBStoreTests):
"""Test MongoDBStore with native BSON storage."""

@override
@pytest.fixture
async def store(self, setup_mongodb: None) -> MongoDBStore:
store = MongoDBStore(url=f"mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}", db_name=f"{MONGODB_TEST_DB}-native", native_storage=True)
store = MongoDBStore(url=f"mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}", db_name=MONGODB_TEST_DB)

await clean_mongodb_database(store=store)

Expand All @@ -167,8 +143,7 @@ async def store(self, setup_mongodb: None) -> MongoDBStore:
async def sanitizing_store(self, setup_mongodb: None) -> MongoDBStore:
store = MongoDBStore(
url=f"mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}",
db_name=f"{MONGODB_TEST_DB}-native-sanitizing",
native_storage=True,
db_name=f"{MONGODB_TEST_DB}-sanitizing",
collection_sanitization_strategy=MongoDBV1CollectionSanitizationStrategy(),
)

Expand Down Expand Up @@ -196,83 +171,3 @@ async def test_value_stored_as_bson_dict(self, store: MongoDBStore):
"version": 1,
}
)

async def test_migration_from_legacy_mode(self, store: MongoDBStore):
"""Verify native mode can read legacy JSON string data."""
await store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage]
sanitized_collection = store._sanitize_collection(collection="test") # pyright: ignore[reportPrivateUsage]
collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage]

await collection.insert_one(
{
"key": "legacy_key",
"value": {"string": '{"legacy": "data"}'},
}
)

result = await store.get(collection="test", key="legacy_key")
assert result == {"legacy": "data"}


@pytest.mark.skipif(should_skip_docker_tests(), reason="Docker is not available")
class TestMongoDBStoreNonNativeMode(BaseMongoDBStoreTests):
"""Test MongoDBStore with native_storage=False (legacy mode) for backward compatibility."""

@override
@pytest.fixture
async def store(self, setup_mongodb: None) -> MongoDBStore:
store = MongoDBStore(url=f"mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}", db_name=MONGODB_TEST_DB, native_storage=False)

await clean_mongodb_database(store=store)

return store

@pytest.fixture
async def sanitizing_store(self, setup_mongodb: None) -> MongoDBStore:
store = MongoDBStore(
url=f"mongodb://{MONGODB_HOST}:{MONGODB_HOST_PORT}",
db_name=f"{MONGODB_TEST_DB}-sanitizing",
native_storage=False,
collection_sanitization_strategy=MongoDBV1CollectionSanitizationStrategy(),
)

await clean_mongodb_database(store=store)

return store

async def test_value_stored_as_json(self, store: MongoDBStore):
"""Verify values are stored as JSON strings."""
await store.put(collection="test", key="test_key", value={"name": "Alice", "age": 30})

# Get the raw MongoDB document
await store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage]
sanitized_collection = store._sanitize_collection(collection="test") # pyright: ignore[reportPrivateUsage]
collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage]
doc = await collection.find_one({"key": "test_key"})

assert doc == snapshot(
{
"_id": IsInstance(expected_type=ObjectId),
"key": "test_key",
"collection": "test",
"created_at": IsDatetime(),
"value": {"string": '{"age": 30, "name": "Alice"}'},
"version": 1,
}
)

async def test_migration_from_native_mode(self, store: MongoDBStore):
"""Verify non-native mode can read native mode data."""
await store._setup_collection(collection="test") # pyright: ignore[reportPrivateUsage]
sanitized_collection = store._sanitize_collection(collection="test") # pyright: ignore[reportPrivateUsage]
collection = store._collections_by_name[sanitized_collection] # pyright: ignore[reportPrivateUsage]

await collection.insert_one(
{
"key": "legacy_key",
"value": {"object": {"name": "Alice", "age": 30}},
}
)

result = await store.get(collection="test", key="legacy_key")
assert result == {"name": "Alice", "age": 30}
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,20 @@


class MongoDBSerializationAdapter(SerializationAdapter):
"""Adapter for MongoDB with support for native and string storage modes."""
"""Adapter for MongoDB with native BSON storage."""

_native_storage: bool

def __init__(self, *, native_storage: bool = True) -> None:
def __init__(self) -> None:
"""Initialize the MongoDB adapter."""
super().__init__()

self._native_storage = native_storage
self._date_format = "datetime"
self._value_format = "dict" if native_storage else "string"
self._value_format = "dict"

@override
def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]:
value = data.pop("value")

data["value"] = {}

if self._native_storage:
data["value"]["object"] = value
else:
data["value"]["string"] = value
data["value"] = {"object": value}

return data

Expand All @@ -71,8 +63,6 @@ def prepare_load(self, data: dict[str, Any]) -> dict[str, Any]:

if "object" in value:
data["value"] = value["object"]
elif "string" in value:
data["value"] = value["string"]
else:
msg = "Value field not found in MongoDB document"
raise DeserializationError(message=msg)
Expand Down Expand Up @@ -119,7 +109,6 @@ def __init__(
client: MongoClient[dict[str, Any]],
db_name: str | None = None,
coll_name: str | None = None,
native_storage: bool = True,
default_collection: str | None = None,
collection_sanitization_strategy: SanitizationStrategy | None = None,
) -> None:
Expand All @@ -129,7 +118,6 @@ def __init__(
client: The MongoDB client to use.
db_name: The name of the MongoDB database.
coll_name: The name of the MongoDB collection.
native_storage: Whether to use native BSON storage (True, default) or JSON string storage (False).
default_collection: The default collection to use if no collection is provided.
collection_sanitization_strategy: The sanitization strategy to use for collections.
"""
Expand All @@ -141,7 +129,6 @@ def __init__(
url: str,
db_name: str | None = None,
coll_name: str | None = None,
native_storage: bool = True,
default_collection: str | None = None,
collection_sanitization_strategy: SanitizationStrategy | None = None,
) -> None:
Expand All @@ -151,7 +138,6 @@ def __init__(
url: The url of the MongoDB cluster.
db_name: The name of the MongoDB database.
coll_name: The name of the MongoDB collection.
native_storage: Whether to use native BSON storage (True, default) or JSON string storage (False).
default_collection: The default collection to use if no collection is provided.
collection_sanitization_strategy: The sanitization strategy to use for collections.
"""
Expand All @@ -163,20 +149,18 @@ def __init__(
url: str | None = None,
db_name: str | None = None,
coll_name: str | None = None,
native_storage: bool = True,
default_collection: str | None = None,
collection_sanitization_strategy: SanitizationStrategy | None = None,
) -> None:
"""Initialize the MongoDB store.

Values are stored as native BSON dictionaries for better query support and performance.

Args:
client: The MongoDB client to use (mutually exclusive with url).
url: The url of the MongoDB cluster (mutually exclusive with client).
db_name: The name of the MongoDB database.
coll_name: The name of the MongoDB collection.
native_storage: Whether to use native BSON storage (True, default) or JSON string storage (False).
Native storage stores values as BSON dicts for better query support.
Legacy mode stores values as JSON strings for backward compatibility.
default_collection: The default collection to use if no collection is provided.
collection_sanitization_strategy: The sanitization strategy to use for collections.
"""
Expand All @@ -194,7 +178,7 @@ def __init__(

self._db = self._client[db_name]
self._collections_by_name = {}
self._adapter = MongoDBSerializationAdapter(native_storage=native_storage)
self._adapter = MongoDBSerializationAdapter()

super().__init__(default_collection=default_collection, collection_sanitization_strategy=collection_sanitization_strategy)

Expand Down
Loading