Skip to content

Commit

Permalink
fix id notation (#772)
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-right committed Nov 9, 2023
1 parent dbfc729 commit 5fcf24c
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 14 deletions.
13 changes: 5 additions & 8 deletions beanie/odm/queries/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,25 +84,22 @@ def __init__(self, document_model: Type["DocType"]):
self.pymongo_kwargs: Dict[str, Any] = {}
self.lazy_parse = False

def prepare_find_expressions(self, for_aggregation: bool = False):
for_aggregation = for_aggregation or self.fetch_links
def prepare_find_expressions(self):
if self.document_model.get_link_fields() is not None:
for i, query in enumerate(self.find_expressions):
self.find_expressions[i] = convert_ids(
query,
doc=self.document_model, # type: ignore
for_aggregation=for_aggregation,
fetch_links=self.fetch_links,
)

def get_filter_query(
self, for_aggregation: bool = False
) -> Mapping[str, Any]:
def get_filter_query(self) -> Mapping[str, Any]:
"""
Returns: MongoDB filter query
"""
self.prepare_find_expressions(for_aggregation=for_aggregation)
self.prepare_find_expressions()
if self.find_expressions:
return Encoder(custom_encoders=self.encoders).encode(
And(*self.find_expressions).query
Expand Down Expand Up @@ -608,7 +605,7 @@ def build_aggregation_pipeline(self, *extra_stages):
] = construct_lookup_queries(self.document_model)
else:
aggregation_pipeline = []
filter_query = self.get_filter_query(for_aggregation=True)
filter_query = self.get_filter_query()

if filter_query:
text_queries, non_text_queries = split_text_query(filter_query)
Expand Down
8 changes: 4 additions & 4 deletions beanie/odm/utils/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


def convert_ids(
query: MappingType[str, Any], doc: "Document", for_aggregation: bool
query: MappingType[str, Any], doc: "Document", fetch_links: bool
) -> Dict[str, Any]:
# TODO add all the cases
new_query = {}
Expand All @@ -27,18 +27,18 @@ def convert_ids(
and k_splitted[0] in doc.get_link_fields().keys() # type: ignore
and k_splitted[1] == "id"
):
if for_aggregation:
if fetch_links:
new_k = f"{k_splitted[0]}._id"
else:
new_k = f"{k_splitted[0]}.$id"
else:
new_k = k
new_v: Any
if isinstance(v, Mapping):
new_v = convert_ids(v, doc, for_aggregation)
new_v = convert_ids(v, doc, fetch_links)
elif isinstance(v, list):
new_v = [
convert_ids(ele, doc, for_aggregation)
convert_ids(ele, doc, fetch_links)
if isinstance(ele, Mapping)
else ele
for ele in v
Expand Down
17 changes: 15 additions & 2 deletions tests/odm/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@

from beanie import Document, init_beanie
from beanie.exceptions import DocumentWasNotSaved
from beanie.odm.fields import BackLink, DeleteRules, Link, WriteRules
from beanie.odm.fields import (
BackLink,
DeleteRules,
Link,
WriteRules,
)
from beanie.odm.utils.pydantic import (
IS_PYDANTIC_V2,
get_model_fields,
Expand Down Expand Up @@ -833,9 +838,11 @@ async def test_find_aggregate_without_fetch_links(self, houses):
]
)
assert aggregation.get_aggregation_pipeline() == [
{"$match": {"door._id": door.id}},
{"$match": {"door.$id": door.id}},
{"$group": {"_id": "$height", "count": {"$sum": 1}}},
]
result = await aggregation.to_list()
assert result == [{"_id": 0, "count": 1}]

async def test_find_aggregate_with_fetch_links(self, houses):
door = await Door.find_one()
Expand All @@ -847,3 +854,9 @@ async def test_find_aggregate_with_fetch_links(self, houses):
]
)
assert len(aggregation.get_aggregation_pipeline()) == 12
assert aggregation.get_aggregation_pipeline()[10:] == [
{"$match": {"door._id": door.id}},
{"$group": {"_id": "$height", "count": {"$sum": 1}}},
]
result = await aggregation.to_list()
assert result == [{"_id": 0, "count": 1}]

0 comments on commit 5fcf24c

Please sign in to comment.