Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: optimize query utils #1610

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
67 changes: 31 additions & 36 deletions tortoise/query_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from copy import copy
from typing import TYPE_CHECKING, List, Optional, Tuple

Expand All @@ -15,9 +17,12 @@
from tortoise.queryset import QuerySet


TableCriterionTuple = Tuple[Table, Criterion]


def _get_joins_for_related_field(
table: Table, related_field: RelationalField, related_field_name: str
) -> List[Tuple[Table, Criterion]]:
) -> List[TableCriterionTuple]:
required_joins = []

related_table: Table = related_field.related_model._meta.basetable
Expand Down Expand Up @@ -70,7 +75,7 @@ def _get_joins_for_related_field(
return required_joins


class EmptyCriterion(Criterion): # type: ignore
class EmptyCriterion(Criterion): # type:ignore[misc]
def __or__(self, other: Criterion) -> Criterion:
return other

Expand Down Expand Up @@ -101,54 +106,44 @@ class QueryModifier:
def __init__(
self,
where_criterion: Optional[Criterion] = None,
joins: Optional[List[Tuple[Table, Criterion]]] = None,
joins: Optional[List[TableCriterionTuple]] = None,
having_criterion: Optional[Criterion] = None,
) -> None:
self.where_criterion: Criterion = where_criterion or EmptyCriterion()
self.joins = joins if joins else []
self.joins = joins or []
self.having_criterion: Criterion = having_criterion or EmptyCriterion()

def __and__(self, other: "QueryModifier") -> "QueryModifier":
return QueryModifier(
def __and__(self, other: QueryModifier) -> QueryModifier:
return self.__class__(
where_criterion=_and(self.where_criterion, other.where_criterion),
joins=self.joins + other.joins,
having_criterion=_and(self.having_criterion, other.having_criterion),
)

def __or__(self, other: "QueryModifier") -> "QueryModifier":
if self.having_criterion or other.having_criterion:
# TODO: This could be optimized?
result_having_criterion = _or(
_and(self.where_criterion, self.having_criterion),
_and(other.where_criterion, other.having_criterion),
)
return QueryModifier(
joins=self.joins + other.joins, having_criterion=result_having_criterion
)
def _and_criterion(self) -> Criterion:
return _and(self.where_criterion, self.having_criterion)

if self.where_criterion and other.where_criterion:
return QueryModifier(
where_criterion=self.where_criterion | other.where_criterion,
joins=self.joins + other.joins,
def __or__(self, other: QueryModifier) -> QueryModifier:
where_criterion = having_criterion = None
if self.having_criterion or other.having_criterion:
having_criterion = _or(self._and_criterion(), other._and_criterion())
else:
where_criterion = (
(self.where_criterion | other.where_criterion)
if self.where_criterion and other.where_criterion
else (self.where_criterion or other.where_criterion)
)
return self.__class__(where_criterion, self.joins + other.joins, having_criterion)

return QueryModifier(
where_criterion=self.where_criterion or other.where_criterion,
joins=self.joins + other.joins,
)

def __invert__(self) -> "QueryModifier":
if not self.where_criterion and not self.having_criterion:
return QueryModifier(joins=self.joins)
def __invert__(self) -> QueryModifier:
where_criterion = having_criterion = None
if self.having_criterion:
# TODO: This could be optimized?
return QueryModifier(
joins=self.joins,
having_criterion=_and(self.where_criterion, self.having_criterion).negate(),
)
return QueryModifier(where_criterion=self.where_criterion.negate(), joins=self.joins)
having_criterion = (self.where_criterion & self.having_criterion).negate()
elif self.where_criterion:
where_criterion = self.where_criterion.negate()
return self.__class__(where_criterion, self.joins, having_criterion)

def get_query_modifiers(self) -> Tuple[Criterion, List[Tuple[Table, Criterion]], Criterion]:
def get_query_modifiers(self) -> Tuple[Criterion, List[TableCriterionTuple], Criterion]:
"""
Returns a tuple of the query criterion.
"""
Expand Down Expand Up @@ -188,7 +183,7 @@ def resolve_for_queryset(self, queryset: "QuerySet") -> None:
)

if forwarded_prefetch:
if first_level_field not in queryset._prefetch_map.keys():
if first_level_field not in queryset._prefetch_map:
queryset._prefetch_map[first_level_field] = set()
queryset._prefetch_map[first_level_field].add(
Prefetch(forwarded_prefetch, self.queryset, to_attr=self.to_attr)
Expand Down