Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Build aggregation pipeline from find query without fetch #770

Merged
Merged 2 commits on Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 15 additions & 9 deletions beanie/odm/queries/find.py
Expand Up @@ -84,22 +84,25 @@ def __init__(self, document_model: Type["DocType"]):
self.pymongo_kwargs: Dict[str, Any] = {}
self.lazy_parse = False

def prepare_find_expressions(self, for_aggregation: bool = False) -> None:
    """Rewrite link-id keys in the accumulated find expressions in place.

    Keys of the form ``<link_field>.id`` are converted to the form MongoDB
    actually stores/exposes: ``<link_field>._id`` when the query will run
    inside an aggregation pipeline (linked documents joined via $lookup),
    or ``<link_field>.$id`` for a plain find against DBRef-encoded links.

    :param for_aggregation: force the aggregation-style rewrite even when
        ``fetch_links`` is off (used when building a pipeline without
        fetching links).
    """
    # fetch_links always implies the query runs inside an aggregation
    # pipeline, so it forces the aggregation-style key rewrite.
    for_aggregation = for_aggregation or self.fetch_links
    if self.document_model.get_link_fields() is not None:
        for i, query in enumerate(self.find_expressions):
            self.find_expressions[i] = convert_ids(
                query,
                doc=self.document_model,  # type: ignore
                for_aggregation=for_aggregation,
            )

def get_filter_query(self) -> Mapping[str, Any]:
def get_filter_query(
self, for_aggregation: bool = False
) -> Mapping[str, Any]:
"""

Returns: MongoDB filter query

"""
self.prepare_find_expressions()
self.prepare_find_expressions(for_aggregation=for_aggregation)
if self.find_expressions:
return Encoder(custom_encoders=self.encoders).encode(
And(*self.find_expressions).query
Expand Down Expand Up @@ -599,10 +602,13 @@ def _set_cache(self, data):
)

def build_aggregation_pipeline(self, *extra_stages):
aggregation_pipeline: List[Dict[str, Any]] = construct_lookup_queries(
self.document_model
)
filter_query = self.get_filter_query()
if self.fetch_links:
aggregation_pipeline: List[
Dict[str, Any]
] = construct_lookup_queries(self.document_model)
else:
aggregation_pipeline = []
filter_query = self.get_filter_query(for_aggregation=True)

if filter_query:
text_queries, non_text_queries = split_text_query(filter_query)
Expand Down
8 changes: 4 additions & 4 deletions beanie/odm/utils/relations.py
Expand Up @@ -14,7 +14,7 @@


def convert_ids(
query: MappingType[str, Any], doc: "Document", fetch_links: bool
query: MappingType[str, Any], doc: "Document", for_aggregation: bool
) -> Dict[str, Any]:
# TODO add all the cases
new_query = {}
Expand All @@ -27,18 +27,18 @@ def convert_ids(
and k_splitted[0] in doc.get_link_fields().keys() # type: ignore
and k_splitted[1] == "id"
):
if fetch_links:
if for_aggregation:
new_k = f"{k_splitted[0]}._id"
else:
new_k = f"{k_splitted[0]}.$id"
else:
new_k = k
new_v: Any
if isinstance(v, Mapping):
new_v = convert_ids(v, doc, fetch_links)
new_v = convert_ids(v, doc, for_aggregation)
elif isinstance(v, list):
new_v = [
convert_ids(ele, doc, fetch_links)
convert_ids(ele, doc, for_aggregation)
if isinstance(ele, Mapping)
else ele
for ele in v
Expand Down
25 changes: 25 additions & 0 deletions tests/odm/test_relations.py
Expand Up @@ -822,3 +822,28 @@ async def test_init_reversed_order(self, db):
PersonForReversedOrderInit,
],
)


class TestBuildAggregations:
    """Pipeline construction from find queries, with and without link fetching."""

    async def test_find_aggregate_without_fetch_links(self, houses):
        door = await Door.find_one()
        find_query = House.find(House.door.id == door.id)
        aggregation = find_query.aggregate(
            [{"$group": {"_id": "$height", "count": {"$sum": 1}}}]
        )
        # Without fetch_links the pipeline is just the converted $match
        # (the link id key rewritten to door._id) plus the user stages.
        assert aggregation.get_aggregation_pipeline() == [
            {"$match": {"door._id": door.id}},
            {"$group": {"_id": "$height", "count": {"$sum": 1}}},
        ]

    async def test_find_aggregate_with_fetch_links(self, houses):
        door = await Door.find_one()
        find_query = House.find(House.door.id == door.id, fetch_links=True)
        aggregation = find_query.aggregate(
            [{"$group": {"_id": "$height", "count": {"$sum": 1}}}]
        )
        # fetch_links prepends the $lookup stages; only the total stage
        # count is pinned here, not the exact stage contents.
        assert len(aggregation.get_aggregation_pipeline()) == 12