diff --git a/tortoise/query_utils.py b/tortoise/query_utils.py index b0f8873b9..68d4084bd 100644 --- a/tortoise/query_utils.py +++ b/tortoise/query_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from copy import copy from typing import TYPE_CHECKING, List, Optional, Tuple @@ -15,9 +17,12 @@ from tortoise.queryset import QuerySet +TableCriterionTuple = Tuple[Table, Criterion] + + def _get_joins_for_related_field( table: Table, related_field: RelationalField, related_field_name: str -) -> List[Tuple[Table, Criterion]]: +) -> List[TableCriterionTuple]: required_joins = [] related_table: Table = related_field.related_model._meta.basetable @@ -70,7 +75,7 @@ def _get_joins_for_related_field( return required_joins -class EmptyCriterion(Criterion): # type: ignore +class EmptyCriterion(Criterion): # type:ignore[misc] def __or__(self, other: Criterion) -> Criterion: return other @@ -101,54 +106,44 @@ class QueryModifier: def __init__( self, where_criterion: Optional[Criterion] = None, - joins: Optional[List[Tuple[Table, Criterion]]] = None, + joins: Optional[List[TableCriterionTuple]] = None, having_criterion: Optional[Criterion] = None, ) -> None: self.where_criterion: Criterion = where_criterion or EmptyCriterion() - self.joins = joins if joins else [] + self.joins = joins or [] self.having_criterion: Criterion = having_criterion or EmptyCriterion() - def __and__(self, other: "QueryModifier") -> "QueryModifier": - return QueryModifier( + def __and__(self, other: QueryModifier) -> QueryModifier: + return self.__class__( where_criterion=_and(self.where_criterion, other.where_criterion), joins=self.joins + other.joins, having_criterion=_and(self.having_criterion, other.having_criterion), ) - def __or__(self, other: "QueryModifier") -> "QueryModifier": - if self.having_criterion or other.having_criterion: - # TODO: This could be optimized? - result_having_criterion = _or( - _and(self.where_criterion, self.having_criterion), - _and(other.where_criterion, other.having_criterion), - ) - return QueryModifier( - joins=self.joins + other.joins, having_criterion=result_having_criterion - ) + def _and_criterion(self) -> Criterion: + return _and(self.where_criterion, self.having_criterion) - if self.where_criterion and other.where_criterion: - return QueryModifier( - where_criterion=self.where_criterion | other.where_criterion, - joins=self.joins + other.joins, + def __or__(self, other: QueryModifier) -> QueryModifier: + where_criterion = having_criterion = None + if self.having_criterion or other.having_criterion: + having_criterion = _or(self._and_criterion(), other._and_criterion()) + else: + where_criterion = ( + (self.where_criterion | other.where_criterion) + if self.where_criterion and other.where_criterion + else (self.where_criterion or other.where_criterion) ) + return self.__class__(where_criterion, self.joins + other.joins, having_criterion) - return QueryModifier( - where_criterion=self.where_criterion or other.where_criterion, - joins=self.joins + other.joins, - ) - - def __invert__(self) -> "QueryModifier": - if not self.where_criterion and not self.having_criterion: - return QueryModifier(joins=self.joins) + def __invert__(self) -> QueryModifier: + where_criterion = having_criterion = None if self.having_criterion: - # TODO: This could be optimized? - return QueryModifier( - joins=self.joins, - having_criterion=_and(self.where_criterion, self.having_criterion).negate(), - ) - return QueryModifier(where_criterion=self.where_criterion.negate(), joins=self.joins) + having_criterion = (self.where_criterion & self.having_criterion).negate() + elif self.where_criterion: + where_criterion = self.where_criterion.negate() + return self.__class__(where_criterion, self.joins, having_criterion) - def get_query_modifiers(self) -> Tuple[Criterion, List[Tuple[Table, Criterion]], Criterion]: + def get_query_modifiers(self) -> Tuple[Criterion, List[TableCriterionTuple], Criterion]: """ Returns a tuple of the query criterion. """ @@ -188,7 +183,7 @@ def resolve_for_queryset(self, queryset: "QuerySet") -> None: ) if forwarded_prefetch: - if first_level_field not in queryset._prefetch_map.keys(): + if first_level_field not in queryset._prefetch_map: queryset._prefetch_map[first_level_field] = set() queryset._prefetch_map[first_level_field].add( Prefetch(forwarded_prefetch, self.queryset, to_attr=self.to_attr)