Skip to content

Commit

Permalink
feat: handle annotated indexes (#762)
Browse files Browse the repository at this point in the history
  • Loading branch information
MrEarle committed Nov 5, 2023
1 parent 82661f1 commit 0add350
Show file tree
Hide file tree
Showing 9 changed files with 186 additions and 14 deletions.
25 changes: 23 additions & 2 deletions beanie/odm/fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import sys
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum
from typing import (
TYPE_CHECKING,
Expand All @@ -20,6 +21,7 @@
from typing_extensions import get_args

from typing import OrderedDict as OrderedDictType
from typing import Tuple

from bson import DBRef, ObjectId
from bson.errors import InvalidId
Expand Down Expand Up @@ -66,13 +68,32 @@
from beanie.odm.documents import DocType


def Indexed(typ, index_type=ASCENDING, **kwargs):
@dataclass(frozen=True)
class IndexedAnnotation:
_indexed: Tuple[int, Dict[str, Any]]


def Indexed(typ=None, index_type=ASCENDING, **kwargs):
"""
Returns a subclass of `typ` with an extra attribute `_indexed` as a tuple:
If `typ` is defined, returns a subclass of `typ` with an extra attribute
`_indexed` as a tuple:
- Index 0: `index_type` such as `pymongo.ASCENDING`
- Index 1: `kwargs` passed to `IndexModel`
When instantiated the type of the result will actually be `typ`.
When `typ` is not defined, returns an `IndexedAnnotation` instance, to be
used as metadata in `Annotated` fields.
Example:
```py
# Both fields would have the same behavior
class MyModel(BaseModel):
field1: Indexed(str, unique=True)
field2: Annotated[str, Indexed(unique=True)]
```
"""
if typ is None:
return IndexedAnnotation(_indexed=(index_type, kwargs))

class NewType(typ):
_indexed = (index_type, kwargs)
Expand Down
15 changes: 9 additions & 6 deletions beanie/odm/utils/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from beanie.odm.utils.pydantic import (
IS_PYDANTIC_V2,
get_extra_field_info,
get_field_type,
get_model_fields,
parse_model,
)
from beanie.odm.utils.typing import get_index_attributes

if sys.version_info >= (3, 8):
from typing import get_args, get_origin
Expand Down Expand Up @@ -458,21 +458,24 @@ async def init_indexes(self, cls, allow_index_dropping: bool = False):
new_indexes = []

# Indexed field wrapped with Indexed()
indexed_fields = (
(k, fvalue, get_index_attributes(fvalue))
for k, fvalue in get_model_fields(cls).items()
)
found_indexes = [
IndexModelField(
IndexModel(
[
(
fvalue.alias or k,
fvalue.annotation._indexed[0],
indexed_attrs[0],
)
],
**fvalue.annotation._indexed[1],
**indexed_attrs[1],
)
)
for k, fvalue in get_model_fields(cls).items()
if hasattr(get_field_type(fvalue), "_indexed")
and get_field_type(fvalue)._indexed
for k, fvalue, indexed_attrs in indexed_fields
if indexed_attrs is not None
]

if document_settings.merge_indexes:
Expand Down
58 changes: 55 additions & 3 deletions beanie/odm/utils/typing.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import inspect
import sys
from typing import Any, Dict, Optional, Tuple, Type

from beanie.odm.fields import IndexedAnnotation

from .pydantic import IS_PYDANTIC_V2, get_field_type

if sys.version_info >= (3, 8):
from typing import get_args, get_origin
else:
from typing_extensions import get_args, get_origin

import inspect
from typing import Any, Type


def extract_id_class(annotation) -> Type[Any]:
if get_origin(annotation) is not None:
Expand All @@ -20,3 +23,52 @@ def extract_id_class(annotation) -> Type[Any]:
if inspect.isclass(annotation):
return annotation
raise ValueError("Unknown annotation: {}".format(annotation))


def get_index_attributes(field) -> Optional[Tuple[int, Dict[str, Any]]]:
"""Gets the index attributes from the field, if it is indexed.
:param field: The field to get the index attributes from.
:return: The index attributes, if the field is indexed. Otherwise, None.
"""
# For fields that are directly typed with `Indexed()`, the type will have
# an `_indexed` attribute.
field_type = get_field_type(field)
if hasattr(field_type, "_indexed"):
return getattr(field_type, "_indexed", None)

# For fields that are use `Indexed` within `Annotated`, the field will have
# metadata that might contain an `IndexedAnnotation` instance.
if IS_PYDANTIC_V2:
# In Pydantic 2, the field has a `metadata` attribute with
# the annotations.
metadata = getattr(field, "metadata", None)
elif hasattr(field, "annotation") and hasattr(
field.annotation, "__metadata__"
):
# In Pydantic 1, the field has an `annotation` attribute with the
# type assigned to the field. If the type is annotated, it will
# have a `__metadata__` attribute with the annotations.
metadata = field.annotation.__metadata__
else:
return None

if metadata is None:
return None

try:
iter(metadata)
except TypeError:
return None

indexed_annotation = next(
(
annotation
for annotation in metadata
if isinstance(annotation, IndexedAnnotation)
),
None,
)

return getattr(indexed_annotation, "_indexed", None)
14 changes: 12 additions & 2 deletions docs/tutorial/indexes.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ from beanie import Document, Indexed


class Sample(Document):
num: Indexed(int)
num: Annotated[int, Indexed()]
description: str
```

Expand All @@ -25,7 +25,7 @@ from beanie import Document, Indexed


class Sample(Document):
description: Indexed(str, index_type=pymongo.TEXT)
description: Annotated[str, Indexed(index_type=pymongo.TEXT)]
```

The `Indexed` function also supports PyMongo's `IndexModel` kwargs arguments (see the [PyMongo Documentation](https://pymongo.readthedocs.io/en/stable/api/pymongo/operations.html#pymongo.operations.IndexModel) for details).
Expand All @@ -36,6 +36,16 @@ For example, to create a `unique` index:
from beanie import Document, Indexed


class Sample(Document):
name: Annotated[str, Indexed(unique=True)]
```

The `Indexed` function can also be used directly in the type annotation, by giving it the wrapped type as the first argument. Note that this might not work with some Pydantic V2 types, such as `UUID4` or `EmailStr`.

```python
from beanie import Document, Indexed


class Sample(Document):
name: Indexed(str, unique=True)
```
Expand Down
2 changes: 2 additions & 0 deletions tests/odm/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
DocumentMultiModelTwo,
DocumentTestModel,
DocumentTestModelFailInspection,
DocumentTestModelIndexFlagsAnnotated,
DocumentTestModelWithComplexIndex,
DocumentTestModelWithCustomCollectionName,
DocumentTestModelWithIndexFlags,
Expand Down Expand Up @@ -197,6 +198,7 @@ async def init(db):
DocumentTestModelWithSimpleIndex,
DocumentTestModelWithIndexFlags,
DocumentTestModelWithIndexFlagsAliases,
DocumentTestModelIndexFlagsAnnotated,
DocumentTestModelWithComplexIndex,
DocumentTestModelFailInspection,
DocumentWithBsonEncodersFiledsTypes,
Expand Down
28 changes: 28 additions & 0 deletions tests/odm/documents/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tests.odm.models import (
Color,
DocumentTestModel,
DocumentTestModelIndexFlagsAnnotated,
DocumentTestModelStringImport,
DocumentTestModelWithComplexIndex,
DocumentTestModelWithCustomCollectionName,
Expand Down Expand Up @@ -112,6 +113,33 @@ async def test_flagged_index_creation_with_alias():
}


async def test_annotated_index_creation():
collection: AsyncIOMotorCollection = (
DocumentTestModelIndexFlagsAnnotated.get_motor_collection()
)
index_info = await collection.index_information()
assert index_info["str_index_text"]["key"] == [
("_fts", "text"),
("_ftsx", 1),
]
assert index_info["str_index_annotated_1"] == {
"key": [("str_index_annotated", 1)],
"v": 2,
}

assert index_info["uuid_index_annotated_1"] == {
"key": [("uuid_index_annotated", 1)],
"unique": True,
"v": 2,
}
if "uuid_index" in index_info:
assert index_info["uuid_index"] == {
"key": [("uuid_index", 1)],
"unique": True,
"v": 2,
}


async def test_complex_index_creation():
collection: AsyncIOMotorCollection = (
DocumentTestModelWithComplexIndex.get_motor_collection()
Expand Down
13 changes: 13 additions & 0 deletions tests/odm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import pymongo
from pydantic import (
UUID4,
BaseModel,
ConfigDict,
Field,
Expand All @@ -35,6 +36,7 @@
from pydantic.fields import FieldInfo
from pydantic_core import core_schema
from pymongo import IndexModel
from typing_extensions import Annotated

from beanie import (
DecimalAnnotation,
Expand Down Expand Up @@ -193,6 +195,17 @@ class DocumentTestModelWithIndexFlagsAliases(Document):
)


class DocumentTestModelIndexFlagsAnnotated(Document):
str_index: Indexed(str, index_type=pymongo.TEXT)
str_index_annotated: Indexed(str, index_type=pymongo.ASCENDING)
uuid_index_annotated: Annotated[UUID4, Indexed(unique=True)]

if not IS_PYDANTIC_V2:
# The UUID4 type raises a ValueError with the current
# implementation of Indexed when using Pydantic v2.
uuid_index: Indexed(UUID4, unique=True)


class DocumentTestModelWithComplexIndex(Document):
test_int: int
test_list: List[SubDocument]
Expand Down
16 changes: 16 additions & 0 deletions tests/odm/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from decimal import Decimal
from pathlib import Path
from typing import AbstractSet, Mapping
from uuid import uuid4

import pytest
from pydantic import BaseModel, ValidationError
Expand All @@ -14,6 +15,7 @@
from beanie.odm.utils.pydantic import IS_PYDANTIC_V2
from tests.odm.models import (
DocumentTestModel,
DocumentTestModelIndexFlagsAnnotated,
DocumentWithBsonEncodersFiledsTypes,
DocumentWithCustomFiledsTypes,
DocumentWithDeprecatedHiddenField,
Expand Down Expand Up @@ -167,3 +169,17 @@ async def test_param_exclude(document, exclude):
def test_expression_fields():
assert Sample.nested.integer == "nested.integer"
assert Sample.nested["integer"] == "nested.integer"


def test_indexed_field() -> None:
"""Test that fields can be declared and instantiated with Indexed()
and Annotated[..., Indexed()]."""

# No error should be raised the document is properly initialized
# and `Indexed` is implemented correctly.
DocumentTestModelIndexFlagsAnnotated(
str_index="test",
str_index_annotated="test",
uuid_index=uuid4(),
uuid_index_annotated=uuid4(),
)
29 changes: 28 additions & 1 deletion tests/odm/test_typing_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from typing import Optional, Union

import pytest
from pydantic import BaseModel
from typing_extensions import Annotated

from beanie import Document, Link
from beanie.odm.utils.typing import extract_id_class
from beanie.odm.fields import Indexed
from beanie.odm.utils.pydantic import get_model_fields
from beanie.odm.utils.typing import extract_id_class, get_index_attributes


class Lock(Document):
Expand All @@ -18,3 +24,24 @@ def test_extract_id_class(self):
assert extract_id_class(Optional[str]) == str
# Link
assert extract_id_class(Link[Lock]) == Lock

@pytest.mark.parametrize(
"type,result",
(
(str, None),
(Indexed(str), (1, {})),
(Indexed(str, "text", unique=True), ("text", {"unique": True})),
(Annotated[str, Indexed()], (1, {})),
(
Annotated[str, "other metadata", Indexed(unique=True)],
(1, {"unique": True}),
),
(Annotated[str, "other metadata"], None),
),
)
def test_get_index_attributes(self, type, result):
class Foo(BaseModel):
bar: type

field = get_model_fields(Foo)["bar"]
assert get_index_attributes(field) == result

0 comments on commit 0add350

Please sign in to comment.