Skip to content

Commit

Permalink
Define Filter type object (#1631)
Browse files Browse the repository at this point in the history
* Improve filters type hint

* fix static code analysis complaint
  • Loading branch information
waketzheng authored May 29, 2024
1 parent 88acc8d commit 88f0366
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 40 deletions.
5 changes: 3 additions & 2 deletions tortoise/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
ForeignKeyFieldInstance,
RelationalField,
)
from tortoise.filters import FilterInfoDict
from tortoise.query_utils import QueryModifier, _get_joins_for_related_field

if TYPE_CHECKING: # pragma: nocoverage
Expand Down Expand Up @@ -154,14 +155,14 @@ def __init__(self, *args: "Q", join_type: str = AND, **kwargs: Any) -> None:
#: Contains the sub-Q's that this Q is made up of
self.children: Tuple[Q, ...] = args
#: Contains the filters applied to this Q
self.filters: Dict[str, Any] = kwargs
self.filters: Dict[str, FilterInfoDict] = kwargs
if join_type not in {self.AND, self.OR}:
raise OperationalError("join_type must be AND or OR")
#: Specifies if this Q does an AND or OR on its children
self.join_type = join_type
self._is_negated = False
self._annotations: Dict[str, Any] = {}
self._custom_filters: Dict[str, Dict[str, Any]] = {}
self._custom_filters: Dict[str, FilterInfoDict] = {}

def __and__(self, other: "Q") -> "Q":
"""
Expand Down
31 changes: 26 additions & 5 deletions tortoise/filters.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import operator
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Optional,
Tuple,
TypedDict,
)

from pypika import Table
from pypika.enums import DatePart, SqlTypes
from pypika.functions import Cast, Extract, Upper
from pypika.terms import BasicCriterion, Criterion, Equality, Term, ValueWrapper
from typing_extensions import NotRequired

from tortoise.fields import Field, JSONField
from tortoise.fields.relational import BackwardFKRelation, ManyToManyFieldInstance
Expand Down Expand Up @@ -206,7 +216,16 @@ def json_filter(field: Term, value: Dict) -> Criterion:
##############################################################################


def get_m2m_filters(field_name: str, field: ManyToManyFieldInstance) -> Dict[str, dict]:
class FilterInfoDict(TypedDict):
field: str
operator: Callable
backward_key: NotRequired[str]
table: NotRequired[Table]
value_encoder: NotRequired[Callable]
source_field: NotRequired[str]


def get_m2m_filters(field_name: str, field: ManyToManyFieldInstance) -> Dict[str, FilterInfoDict]:
target_table_pk = field.related_model._meta.pk
return {
field_name: {
Expand Down Expand Up @@ -240,7 +259,9 @@ def get_m2m_filters(field_name: str, field: ManyToManyFieldInstance) -> Dict[str
}


def get_backward_fk_filters(field_name: str, field: BackwardFKRelation) -> Dict[str, dict]:
def get_backward_fk_filters(
field_name: str, field: BackwardFKRelation
) -> Dict[str, FilterInfoDict]:
target_table_pk = field.related_model._meta.pk
return {
field_name: {
Expand Down Expand Up @@ -286,7 +307,7 @@ def get_backward_fk_filters(field_name: str, field: BackwardFKRelation) -> Dict[
}


def get_json_filter(field_name: str, source_field: str):
def get_json_filter(field_name: str, source_field: str) -> Dict[str, FilterInfoDict]:
actual_field_name = field_name
return {
field_name: {
Expand Down Expand Up @@ -332,7 +353,7 @@ def get_json_filter(field_name: str, source_field: str):

def get_filters_for_field(
field_name: str, field: Optional[Field], source_field: str
) -> Dict[str, dict]:
) -> Dict[str, FilterInfoDict]:
if isinstance(field, ManyToManyFieldInstance):
return get_m2m_filters(field_name, field)
if isinstance(field, BackwardFKRelation):
Expand Down
37 changes: 13 additions & 24 deletions tortoise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
OneToOneFieldInstance,
ReverseRelation,
)
from tortoise.filters import get_filters_for_field
from tortoise.filters import FilterInfoDict, get_filters_for_field
from tortoise.functions import Function
from tortoise.indexes import Index
from tortoise.manager import Manager
Expand All @@ -66,9 +66,6 @@
EMPTY = object()


# TODO: Define Filter type object. Possibly tuple?


def get_together(meta: "Model.Meta", together: str) -> Tuple[Tuple[str, ...], ...]:
_together = getattr(meta, together, ())

Expand Down Expand Up @@ -234,8 +231,8 @@ def __init__(self, meta: "Model.Meta") -> None:
self.fetch_fields: Set[str] = set()
self.fields_db_projection: Dict[str, str] = {}
self.fields_db_projection_reverse: Dict[str, str] = {}
self._filters: Dict[str, Dict[str, dict]] = {}
self.filters: Dict[str, dict] = {}
self._filters: Dict[str, FilterInfoDict] = {}
self.filters: Dict[str, FilterInfoDict] = {}
self.fields_map: Dict[str, Field] = {}
self._inited: bool = False
self.default_connection: Optional[str] = None
Expand Down Expand Up @@ -297,7 +294,7 @@ def ordering(self) -> Tuple[Tuple[str, Order], ...]:
)
return self._default_ordering

def get_filter(self, key: str) -> dict:
def get_filter(self, key: str) -> FilterInfoDict:
return self.filters[key]

def finalise_model(self) -> None:
Expand Down Expand Up @@ -473,12 +470,10 @@ def _generate_db_fields(self) -> None:
def _generate_filters(self) -> None:
get_overridden_filter_func = self.db.executor_class.get_overridden_filter_func
for key, filter_info in self._filters.items():
overridden_operator = get_overridden_filter_func(
filter_func=filter_info["operator"] # type: ignore
)
overridden_operator = get_overridden_filter_func(filter_func=filter_info["operator"])
if overridden_operator:
filter_info = copy(filter_info)
filter_info["operator"] = overridden_operator # type: ignore
filter_info["operator"] = overridden_operator
self.filters[key] = filter_info


Expand All @@ -488,7 +483,7 @@ class ModelMeta(type):
def __new__(mcs, name: str, bases: Tuple[Type, ...], attrs: dict):
fields_db_projection: Dict[str, str] = {}
fields_map: Dict[str, Field] = {}
filters: Dict[str, Dict[str, dict]] = {}
filters: Dict[str, FilterInfoDict] = {}
fk_fields: Set[str] = set()
m2m_fields: Set[str] = set()
o2o_fields: Set[str] = set()
Expand Down Expand Up @@ -582,19 +577,16 @@ def __search_for_field_attributes(base: Type, attrs: dict) -> None:
m2m_fields.add(key)
else:
fields_db_projection[key] = value.source_field or key
field, source_field = fields_map[key], fields_db_projection[key]
filters.update(
get_filters_for_field(
field_name=key,
field=fields_map[key],
source_field=fields_db_projection[key],
field_name=key, field=field, source_field=source_field
)
)
if value.pk:
filters.update(
get_filters_for_field(
field_name="pk",
field=fields_map[key],
source_field=fields_db_projection[key],
field_name="pk", field=field, source_field=source_field
)
)

Expand Down Expand Up @@ -740,12 +732,9 @@ def _init_from_db(cls: Type[MODEL], **kwargs: Any) -> MODEL:
# Fields that don't override .to_python_value() are converted without a call
# as we already know what we will be doing.
for key, model_field, field in meta.db_default_fields:
value = kwargs[key]
setattr(
self,
model_field,
None if value is None else field.field_type(value),
)
if (value := kwargs[key]) is not None:
value = field.field_type(value)
setattr(self, model_field, value)
# These fields need manual .to_python_value()
for key, model_field, field in meta.db_complex_fields:
setattr(self, model_field, field.to_python_value(kwargs[key]))
Expand Down
19 changes: 10 additions & 9 deletions tortoise/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
OneToOneFieldInstance,
RelationalField,
)
from tortoise.filters import FilterInfoDict
from tortoise.functions import Function
from tortoise.query_utils import Prefetch, QueryModifier, _get_joins_for_related_field
from tortoise.router import router
Expand Down Expand Up @@ -122,7 +123,7 @@ def resolve_filters(
model: "Type[Model]",
q_objects: List[Q],
annotations: Dict[str, Any],
custom_filters: Dict[str, Dict[str, Any]],
custom_filters: Dict[str, FilterInfoDict],
) -> None:
"""
Builds the common filters for a QuerySet.
Expand Down Expand Up @@ -318,7 +319,7 @@ def __init__(self, model: Type[MODEL]) -> None:
self._q_objects: List[Q] = []
self._distinct: bool = False
self._having: Dict[str, Any] = {}
self._custom_filters: Dict[str, dict] = {}
self._custom_filters: Dict[str, FilterInfoDict] = {}
self._fields_for_select: Tuple[str, ...] = ()
self._group_bys: Tuple[str, ...] = ()
self._select_for_update: bool = False
Expand Down Expand Up @@ -1094,7 +1095,7 @@ def __init__(
db: BaseDBAsyncClient,
q_objects: List[Q],
annotations: Dict[str, Any],
custom_filters: Dict[str, Dict[str, Any]],
custom_filters: Dict[str, FilterInfoDict],
limit: Optional[int],
orderings: List[Tuple[str, str]],
) -> None:
Expand Down Expand Up @@ -1180,7 +1181,7 @@ def __init__(
db: BaseDBAsyncClient,
q_objects: List[Q],
annotations: Dict[str, Any],
custom_filters: Dict[str, Dict[str, Any]],
custom_filters: Dict[str, FilterInfoDict],
limit: Optional[int],
orderings: List[Tuple[str, str]],
) -> None:
Expand Down Expand Up @@ -1235,7 +1236,7 @@ def __init__(
db: BaseDBAsyncClient,
q_objects: List[Q],
annotations: Dict[str, Any],
custom_filters: Dict[str, Dict[str, Any]],
custom_filters: Dict[str, FilterInfoDict],
force_indexes: Set[str],
use_indexes: Set[str],
) -> None:
Expand Down Expand Up @@ -1293,7 +1294,7 @@ def __init__(
db: BaseDBAsyncClient,
q_objects: List[Q],
annotations: Dict[str, Any],
custom_filters: Dict[str, Dict[str, Any]],
custom_filters: Dict[str, FilterInfoDict],
limit: Optional[int],
offset: Optional[int],
force_indexes: Set[str],
Expand Down Expand Up @@ -1486,7 +1487,7 @@ def __init__(
orderings: List[Tuple[str, str]],
flat: bool,
annotations: Dict[str, Any],
custom_filters: Dict[str, Dict[str, Any]],
custom_filters: Dict[str, FilterInfoDict],
group_bys: Tuple[str, ...],
force_indexes: Set[str],
use_indexes: Set[str],
Expand Down Expand Up @@ -1620,7 +1621,7 @@ def __init__(
distinct: bool,
orderings: List[Tuple[str, str]],
annotations: Dict[str, Any],
custom_filters: Dict[str, Dict[str, Any]],
custom_filters: Dict[str, FilterInfoDict],
group_bys: Tuple[str, ...],
force_indexes: Set[str],
use_indexes: Set[str],
Expand Down Expand Up @@ -1756,7 +1757,7 @@ def __init__(
db: BaseDBAsyncClient,
q_objects: List[Q],
annotations: Dict[str, Any],
custom_filters: Dict[str, Dict[str, Any]],
custom_filters: Dict[str, FilterInfoDict],
limit: Optional[int],
orderings: List[Tuple[str, str]],
objects: Iterable[MODEL],
Expand Down

0 comments on commit 88f0366

Please sign in to comment.