diff --git a/tortoise/__init__.py b/tortoise/__init__.py index 2111f4003..29817a31e 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -23,12 +23,12 @@ ) from tortoise.filters import get_m2m_filters from tortoise.log import logger -from tortoise.models import Model, ModelMeta +from tortoise.models import Model, ModelMeta, MODEL_CLASS, MODEL_INSTANCE from tortoise.utils import generate_schema_for_client class Tortoise: - apps: Dict[str, Dict[str, Type["Model"]]] = {} + apps: Dict[str, Dict[str, MODEL_CLASS]] = {} _inited: bool = False @classmethod @@ -46,7 +46,7 @@ def get_connection(cls, connection_name: str) -> BaseDBAsyncClient: @classmethod def describe_model( - cls, model: Type["Model"], serializable: bool = True + cls, model: MODEL_CLASS, serializable: bool = True ) -> dict: # pragma: nocoverage """ Describes the given list of models or ALL registered models. @@ -64,7 +64,7 @@ def describe_model( This is deprecated, please use :meth:`tortoise.models.Model.describe` instead """ warnings.warn( - "Tortoise.describe_model() is deprecated, please use .describe() instead", + "Tortoise.describe_model() is deprecated, please use .describe() instead", DeprecationWarning, stacklevel=2, ) @@ -72,7 +72,7 @@ def describe_model( @classmethod def describe_models( - cls, models: Optional[List[Type["Model"]]] = None, serializable: bool = True + cls, models: Optional[List[MODEL_CLASS]] = None, serializable: bool = True ) -> Dict[str, dict]: """ Describes the given list of models or ALL registered models. @@ -108,7 +108,7 @@ def describe_models( @classmethod def _init_relations(cls) -> None: - def get_related_model(related_app_name: str, related_model_name: str) -> Type["Model"]: + def get_related_model(related_app_name: str, related_model_name: str) -> MODEL_CLASS: """ Test, if app and model really exist. Throws a ConfigurationError with a hopefully helpful message. If successful, returns the requested model. @@ -345,7 +345,7 @@ def split_reference(reference: str) -> Tuple[str, str]: @classmethod def _discover_models( cls, models_path: Union[ModuleType, str], app_label: str - ) -> List[Type["Model"]]: + ) -> List[MODEL_CLASS]: if isinstance(models_path, ModuleType): module = models_path else: @@ -390,7 +390,7 @@ def init_models( :raises ConfigurationError: If models are invalid. """ - app_models: List[Type[Model]] = [] + app_models: List[MODEL_CLASS] = [] for models_path in models_paths: app_models += cls._discover_models(models_path, app_label) diff --git a/tortoise/backends/asyncpg/executor.py b/tortoise/backends/asyncpg/executor.py index 4468f892a..25d3399d0 100644 --- a/tortoise/backends/asyncpg/executor.py +++ b/tortoise/backends/asyncpg/executor.py @@ -2,12 +2,12 @@ import asyncpg -from tortoise import Model +from tortoise import MODEL_INSTANCE from tortoise.backends.base_postgres.executor import BasePostgresExecutor class AsyncpgExecutor(BasePostgresExecutor): async def _process_insert_result( - self, instance: Model, results: Optional[asyncpg.Record] + self, instance: MODEL_INSTANCE, results: Optional[asyncpg.Record] ) -> None: return await super()._process_insert_result(instance, results) diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 0b5f3b693..455b54ac8 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -37,7 +37,7 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.backends.base.client import BaseDBAsyncClient - from tortoise.models import Model + from tortoise.models import MODEL_CLASS, MODEL_INSTANCE from tortoise.query_utils import Prefetch from tortoise.queryset import QuerySet @@ -55,12 +55,12 @@ class BaseExecutor: def __init__( self, - model: "Type[Model]", + model: "MODEL_CLASS", db: "BaseDBAsyncClient", prefetch_map: "Optional[Dict[str, Set[Union[str, Prefetch]]]]" = None, prefetch_queries: Optional[Dict[str, List[Tuple[Optional[str], "QuerySet"]]]] = None, select_related_idx: Optional[ - List[Tuple["Type[Model]", int, str, "Type[Model]", Iterable[Optional[str]]]] + List[Tuple["MODEL_CLASS", int, str, "MODEL_CLASS", Iterable[Optional[str]]]] ] = None, ) -> None: self.model = model @@ -136,7 +136,7 @@ async def execute_select( dict_row = dict(row) keys = list(dict_row.keys()) values = list(dict_row.values()) - instance: "Model" = self.model._init_from_db( + instance: "MODEL_INSTANCE" = self.model._init_from_db( **dict(zip(keys[:current_idx], values[:current_idx])) ) instances: Dict[Any, Any] = {path: instance} @@ -191,7 +191,7 @@ def _prepare_insert_columns( @classmethod def _field_to_db( - cls, field_object: Field, attr: Any, instance: "Union[Type[Model], Model]" + cls, field_object: Field, attr: Any, instance: Union["MODEL_CLASS", "MODEL_INSTANCE"] ) -> Any: if field_object.__class__ in cls.TO_DB_OVERRIDE: return cls.TO_DB_OVERRIDE[field_object.__class__](field_object, attr, instance) @@ -212,13 +212,13 @@ def _prepare_insert_statement( query = query.on_conflict().do_nothing() return query - async def _process_insert_result(self, instance: "Model", results: Any) -> None: + async def _process_insert_result(self, instance: "MODEL_INSTANCE", results: Any) -> None: raise NotImplementedError() # pragma: nocoverage def parameter(self, pos: int) -> Parameter: raise NotImplementedError() # pragma: nocoverage - async def execute_insert(self, instance: "Model") -> None: + async def execute_insert(self, instance: "MODEL_INSTANCE") -> None: if not instance._custom_generated_pk: values = [ self.column_map[field_name](getattr(instance, field_name), instance) @@ -236,7 +236,7 @@ async def execute_insert(self, instance: "Model") -> None: async def execute_bulk_insert( self, - instances: "Iterable[Model]", + instances: Iterable["MODEL_INSTANCE"], batch_size: Optional[int] = None, ) -> None: for instance_chunk in chunk(instances, batch_size): @@ -300,7 +300,7 @@ def get_update_sql( return sql async def execute_update( - self, instance: "Union[Type[Model], Model]", update_fields: Optional[Iterable[str]] + self, instance: Union["MODEL_CLASS", "MODEL_INSTANCE"], update_fields: Optional[Iterable[str]] ) -> int: values = [] arithmetic_or_function = {} @@ -319,7 +319,7 @@ async def execute_update( ) )[0] - async def execute_delete(self, instance: "Union[Type[Model], Model]") -> int: + async def execute_delete(self, instance: Union["MODEL_CLASS", "MODEL_INSTANCE"]) -> int: return ( await self.db.execute_query( self.delete_query, [self.model._meta.pk.to_db_value(instance.pk, instance)] @@ -328,10 +328,10 @@ async def execute_delete(self, instance: "Union[Type[Model], Model]") -> int: async def _prefetch_reverse_relation( self, - instance_list: "Iterable[Model]", + instance_list: Iterable["MODEL_INSTANCE"], field: str, related_query: Tuple[Optional[str], "QuerySet"], - ) -> "Iterable[Model]": + ) -> Iterable["MODEL_INSTANCE"]: to_attr, related_query = related_query related_objects_for_fetch: Dict[str, list] = {} related_field: BackwardFKRelation = self.model._meta.fields_map[field] # type: ignore @@ -373,10 +373,10 @@ async def _prefetch_reverse_relation( async def _prefetch_reverse_o2o_relation( self, - instance_list: "Iterable[Model]", + instance_list: Iterable["MODEL_INSTANCE"], field: str, related_query: Tuple[Optional[str], "QuerySet"], - ) -> "Iterable[Model]": + ) -> Iterable["MODEL_INSTANCE"]: to_attr, related_query = related_query related_objects_for_fetch: Dict[str, list] = {} related_field: BackwardOneToOneRelation = self.model._meta.fields_map[field] # type: ignore @@ -416,10 +416,10 @@ async def _prefetch_reverse_o2o_relation( async def _prefetch_m2m_relation( self, - instance_list: "Iterable[Model]", + instance_list: Iterable["MODEL_INSTANCE"], field: str, related_query: Tuple[Optional[str], "QuerySet"], - ) -> "Iterable[Model]": + ) -> Iterable["MODEL_INSTANCE"]: to_attr, related_query = related_query instance_id_set: set = { self._field_to_db(instance._meta.pk, instance.pk, instance) @@ -502,10 +502,10 @@ async def _prefetch_m2m_relation( async def _prefetch_direct_relation( self, - instance_list: "Iterable[Model]", + instance_list: Iterable["MODEL_INSTANCE"], field: str, related_query: Tuple[Optional[str], "QuerySet"], - ) -> "Iterable[Model]": + ) -> Iterable["MODEL_INSTANCE"]: # TODO: This will only work if instance_list is all of same type # TODO: If that's the case, then we can optimize the key resolver to_attr, related_query = related_query @@ -539,7 +539,7 @@ def _make_prefetch_queries(self) -> None: to_attr, related_query = self._prefetch_queries[field_name][0] else: relation_field = self.model._meta.fields_map[field_name] - related_model: "Type[Model]" = relation_field.related_model # type: ignore + related_model: "MODEL_CLASS" = relation_field.related_model # type: ignore related_query = related_model.all().using_db(self.db) related_query.query = copy(related_query.model._meta.basequery) if forwarded_prefetches: @@ -548,10 +548,10 @@ def _make_prefetch_queries(self) -> None: async def _do_prefetch( self, - instance_id_list: "Iterable[Model]", + instance_id_list: Iterable["MODEL_INSTANCE"], field: str, related_query: Tuple[Optional[str], "QuerySet"], - ) -> "Iterable[Model]": + ) -> Iterable["MODEL_INSTANCE"]: if field in self.model._meta.backward_fk_fields: return await self._prefetch_reverse_relation(instance_id_list, field, related_query) @@ -563,8 +563,8 @@ async def _do_prefetch( return await self._prefetch_direct_relation(instance_id_list, field, related_query) async def _execute_prefetch_queries( - self, instance_list: "Iterable[Model]" - ) -> "Iterable[Model]": + self, instance_list: Iterable["MODEL_INSTANCE"] + ) -> Iterable["MODEL_INSTANCE"]: if instance_list and (self.prefetch_map or self._prefetch_queries): self._make_prefetch_queries() prefetch_tasks = [] @@ -576,8 +576,8 @@ async def _execute_prefetch_queries( return instance_list async def fetch_for_list( - self, instance_list: "Iterable[Model]", *args: str - ) -> "Iterable[Model]": + self, instance_list: Iterable["MODEL_INSTANCE"], *args: str + ) -> Iterable["MODEL_INSTANCE"]: self.prefetch_map = {} for relation in args: first_level_field, __, forwarded_prefetch = relation.partition("__") diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py index ac07d117f..d6d4f6854 100644 --- a/tortoise/backends/base/schema_generator.py +++ b/tortoise/backends/base/schema_generator.py @@ -1,5 +1,5 @@ from hashlib import sha256 -from typing import TYPE_CHECKING, Any, List, Set, Type, cast +from typing import TYPE_CHECKING, Any, List, Set, cast from tortoise.exceptions import ConfigurationError from tortoise.fields import JSONField, TextField, UUIDField @@ -8,8 +8,7 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.backends.base.client import BaseDBAsyncClient from tortoise.fields.relational import ForeignKeyFieldInstance # noqa - from tortoise.fields.relational import ManyToManyFieldInstance - from tortoise.models import Model + from tortoise.models import MODEL_CLASS # pylint: disable=R0201 @@ -134,7 +133,7 @@ def _make_hash(*args: str, length: int) -> str: return sha256(";".join(args).encode("utf-8")).hexdigest()[:length] def _generate_index_name( - self, prefix: str, model: "Type[Model]", field_names: List[str] + self, prefix: str, model: "MODEL_CLASS", field_names: List[str] ) -> str: # NOTE: for compatibility, index name should not be longer than 30 # characters (Oracle limit). @@ -161,7 +160,7 @@ def _generate_fk_name( ) return index_name - def _get_index_sql(self, model: "Type[Model]", field_names: List[str], safe: bool) -> str: + def _get_index_sql(self, model: "MODEL_CLASS", field_names: List[str], safe: bool) -> str: return self.INDEX_CREATE_TEMPLATE.format( exists="IF NOT EXISTS " if safe else "", index_name=self._generate_index_name("idx", model, field_names), @@ -169,18 +168,18 @@ def _get_index_sql(self, model: "Type[Model]", field_names: List[str], safe: boo fields=", ".join([self.quote(f) for f in field_names]), ) - def _get_unique_constraint_sql(self, model: "Type[Model]", field_names: List[str]) -> str: + def _get_unique_constraint_sql(self, model: "MODEL_CLASS", field_names: List[str]) -> str: return self.UNIQUE_CONSTRAINT_CREATE_TEMPLATE.format( index_name=self._generate_index_name("uid", model, field_names), fields=", ".join([self.quote(f) for f in field_names]), ) - def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: + def _get_table_sql(self, model: "MODEL_CLASS", safe: bool = True) -> dict: fields_to_create = [] fields_with_index = [] m2m_tables_for_create = [] references = set() - models_to_create: "List[Type[Model]]" = [] + models_to_create: List["MODEL_CLASS"] = [] self._get_models_to_create(models_to_create) models_tables = [model._meta.db_table for model in models_to_create] @@ -402,7 +401,7 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: "m2m_tables": m2m_tables_for_create, } - def _get_models_to_create(self, models_to_create: "List[Type[Model]]") -> None: + def _get_models_to_create(self, models_to_create: List["MODEL_CLASS"]) -> None: from tortoise import Tortoise for app in Tortoise.apps.values(): @@ -412,7 +411,7 @@ def _get_models_to_create(self, models_to_create: "List[Type[Model]]") -> None: models_to_create.append(model) def get_create_schema_sql(self, safe: bool = True) -> str: - models_to_create: "List[Type[Model]]" = [] + models_to_create: List["MODEL_CLASS"] = [] self._get_models_to_create(models_to_create) diff --git a/tortoise/backends/base_postgres/executor.py b/tortoise/backends/base_postgres/executor.py index db34a3748..1c22dd88a 100644 --- a/tortoise/backends/base_postgres/executor.py +++ b/tortoise/backends/base_postgres/executor.py @@ -5,7 +5,7 @@ from pypika.dialects import PostgreSQLQueryBuilder from pypika.terms import Term -from tortoise import Model +from tortoise import MODEL_INSTANCE from tortoise.backends.base.executor import BaseExecutor from tortoise.contrib.postgres.json_functions import ( postgres_json_contained_by, @@ -49,7 +49,7 @@ def _prepare_insert_statement( query = query.on_conflict().do_nothing() return query - async def _process_insert_result(self, instance: Model, results: Optional[dict]) -> None: + async def _process_insert_result(self, instance: MODEL_INSTANCE, results: Optional[dict]) -> None: if results: generated_fields = self.model._meta.generated_db_fields db_projection = instance._meta.fields_db_projection_reverse diff --git a/tortoise/backends/mssql/executor.py b/tortoise/backends/mssql/executor.py index bc14c9cfa..598102541 100644 --- a/tortoise/backends/mssql/executor.py +++ b/tortoise/backends/mssql/executor.py @@ -1,15 +1,15 @@ -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union from pypika import Query -from tortoise import Model, fields +from tortoise import MODEL_CLASS, MODEL_INSTANCE, fields from tortoise.backends.odbc.executor import ODBCExecutor from tortoise.exceptions import UnSupportedError from tortoise.fields import BooleanField def to_db_bool( - self: BooleanField, value: Optional[Union[bool, int]], instance: Union[Type[Model], Model] + self: BooleanField, value: Optional[Union[bool, int]], instance: Union[MODEL_CLASS, MODEL_INSTANCE] ) -> Optional[int]: if value is None: return None diff --git a/tortoise/backends/mssql/schema_generator.py b/tortoise/backends/mssql/schema_generator.py index 2706ae583..f36692d31 100644 --- a/tortoise/backends/mssql/schema_generator.py +++ b/tortoise/backends/mssql/schema_generator.py @@ -1,11 +1,11 @@ -from typing import TYPE_CHECKING, Any, List, Type +from typing import TYPE_CHECKING, Any, List from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.converters import encoders if TYPE_CHECKING: # pragma: nocoverage from tortoise.backends.mssql import MSSQLClient - from tortoise.models import Model + from tortoise.models import MODEL_CLASS class MSSQLSchemaGenerator(BaseSchemaGenerator): @@ -59,10 +59,10 @@ def _column_default_generator( def _escape_default_value(self, default: Any): return encoders.get(type(default))(default) # type: ignore - def _get_index_sql(self, model: "Type[Model]", field_names: List[str], safe: bool) -> str: + def _get_index_sql(self, model: "MODEL_CLASS", field_names: List[str], safe: bool) -> str: return super(MSSQLSchemaGenerator, self)._get_index_sql(model, field_names, False) - def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: + def _get_table_sql(self, model: "MODEL_CLASS", safe: bool = True) -> dict: return super(MSSQLSchemaGenerator, self)._get_table_sql(model, False) def _create_fk_string( diff --git a/tortoise/backends/mysql/executor.py b/tortoise/backends/mysql/executor.py index 0144b26b5..1f3e19985 100644 --- a/tortoise/backends/mysql/executor.py +++ b/tortoise/backends/mysql/executor.py @@ -3,7 +3,7 @@ from pypika.terms import Criterion from pypika.utils import format_quotes -from tortoise import Model +from tortoise import MODEL_INSTANCE from tortoise.backends.base.executor import BaseExecutor from tortoise.contrib.mysql.json_functions import ( mysql_json_contained_by, @@ -114,7 +114,7 @@ class MySQLExecutor(BaseExecutor): def parameter(self, pos: int) -> Parameter: return Parameter("%s") - async def _process_insert_result(self, instance: Model, results: int) -> None: + async def _process_insert_result(self, instance: MODEL_INSTANCE, results: int) -> None: pk_field_object = self.model._meta.pk if ( isinstance(pk_field_object, (SmallIntField, IntField, BigIntField)) diff --git a/tortoise/backends/mysql/schema_generator.py b/tortoise/backends/mysql/schema_generator.py index e10b41afc..26c2fadec 100644 --- a/tortoise/backends/mysql/schema_generator.py +++ b/tortoise/backends/mysql/schema_generator.py @@ -1,11 +1,11 @@ -from typing import TYPE_CHECKING, Any, List, Type +from typing import TYPE_CHECKING, Any, List from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.converters import encoders if TYPE_CHECKING: # pragma: nocoverage from tortoise.backends.mysql.client import MySQLClient - from tortoise.models import Model + from tortoise.models import MODEL_CLASS class MySQLSchemaGenerator(BaseSchemaGenerator): @@ -67,7 +67,7 @@ def _column_default_generator( def _escape_default_value(self, default: Any): return encoders.get(type(default))(default) # type: ignore - def _get_index_sql(self, model: "Type[Model]", field_names: List[str], safe: bool) -> str: + def _get_index_sql(self, model: "MODEL_CLASS", field_names: List[str], safe: bool) -> str: """Get index SQLs, but keep them for ourselves""" self._field_indexes.append( self.INDEX_CREATE_TEMPLATE.format( diff --git a/tortoise/backends/odbc/executor.py b/tortoise/backends/odbc/executor.py index 620fbb239..3e9aeabde 100644 --- a/tortoise/backends/odbc/executor.py +++ b/tortoise/backends/odbc/executor.py @@ -1,6 +1,6 @@ from pypika import Parameter -from tortoise import Model +from tortoise import MODEL_INSTANCE from tortoise.backends.base.executor import BaseExecutor from tortoise.fields import BigIntField, IntField, SmallIntField @@ -9,7 +9,7 @@ class ODBCExecutor(BaseExecutor): def parameter(self, pos: int) -> Parameter: return Parameter("?") - async def _process_insert_result(self, instance: Model, results: int) -> None: + async def _process_insert_result(self, instance: MODEL_INSTANCE, results: int) -> None: pk_field_object = self.model._meta.pk if ( isinstance(pk_field_object, (SmallIntField, IntField, BigIntField)) diff --git a/tortoise/backends/oracle/executor.py b/tortoise/backends/oracle/executor.py index 1317f2e8f..1a3e6ff98 100644 --- a/tortoise/backends/oracle/executor.py +++ b/tortoise/backends/oracle/executor.py @@ -1,9 +1,9 @@ -from tortoise import Model +from tortoise import MODEL_INSTANCE from tortoise.backends.odbc.executor import ODBCExecutor class OracleExecutor(ODBCExecutor): - async def _process_insert_result(self, instance: Model, results: int) -> None: + async def _process_insert_result(self, instance: MODEL_INSTANCE, results: int) -> None: sql = "SELECT SEQUENCE_NAME FROM ALL_TAB_IDENTITY_COLS where TABLE_NAME = ? and OWNER = ?" ret = await self.db.execute_query_dict( sql, values=[instance._meta.db_table, self.db.database] # type: ignore diff --git a/tortoise/backends/oracle/schema_generator.py b/tortoise/backends/oracle/schema_generator.py index c08d6e0be..22e11e365 100644 --- a/tortoise/backends/oracle/schema_generator.py +++ b/tortoise/backends/oracle/schema_generator.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, List, Type +from typing import TYPE_CHECKING, Any, List from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.converters import encoders @@ -6,7 +6,7 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.backends.oracle import OracleClient - from tortoise.models import Model + from tortoise.models import MODEL_CLASS class OracleSchemaGenerator(BaseSchemaGenerator): @@ -85,10 +85,10 @@ def _column_default_generator( def _escape_default_value(self, default: Any): return encoders.get(type(default))(default) # type: ignore - def _get_index_sql(self, model: "Type[Model]", field_names: List[str], safe: bool) -> str: + def _get_index_sql(self, model: "MODEL_CLASS", field_names: List[str], safe: bool) -> str: return super(OracleSchemaGenerator, self)._get_index_sql(model, field_names, False) - def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: + def _get_table_sql(self, model: "MODEL_CLASS", safe: bool = True) -> dict: return super(OracleSchemaGenerator, self)._get_table_sql(model, False) def _create_fk_string( diff --git a/tortoise/backends/psycopg/executor.py b/tortoise/backends/psycopg/executor.py index e53492494..079289f0d 100644 --- a/tortoise/backends/psycopg/executor.py +++ b/tortoise/backends/psycopg/executor.py @@ -4,13 +4,13 @@ from pypika import Parameter -from tortoise import Model +from tortoise import MODEL_INSTANCE from tortoise.backends.base_postgres.executor import BasePostgresExecutor class PsycopgExecutor(BasePostgresExecutor): async def _process_insert_result( - self, instance: Model, results: Optional[dict | tuple] + self, instance: MODEL_INSTANCE, results: Optional[dict | tuple] ) -> None: if results: db_projection = instance._meta.fields_db_projection_reverse diff --git a/tortoise/backends/sqlite/executor.py b/tortoise/backends/sqlite/executor.py index c1af22317..cf1dcdb76 100644 --- a/tortoise/backends/sqlite/executor.py +++ b/tortoise/backends/sqlite/executor.py @@ -1,11 +1,11 @@ import datetime from decimal import Decimal -from typing import Optional, Type, Union +from typing import Optional, Union import pytz from pypika import Parameter -from tortoise import Model, fields, timezone +from tortoise import MODEL_CLASS, MODEL_INSTANCE, fields, timezone from tortoise.backends.base.executor import BaseExecutor from tortoise.fields import ( BigIntField, @@ -19,7 +19,7 @@ def to_db_bool( - self: BooleanField, value: Optional[Union[bool, int]], instance: Union[Type[Model], Model] + self: BooleanField, value: Optional[Union[bool, int]], instance: Union[MODEL_CLASS, MODEL_INSTANCE] ) -> Optional[int]: if value is None: return None @@ -29,7 +29,7 @@ def to_db_bool( def to_db_decimal( self: DecimalField, value: Optional[Union[str, float, int, Decimal]], - instance: Union[Type[Model], Model], + instance: Union[MODEL_CLASS, MODEL_INSTANCE], ) -> Optional[str]: if value is None: return None @@ -37,7 +37,7 @@ def to_db_decimal( def to_db_datetime( - self: DatetimeField, value: Optional[datetime.datetime], instance: Union[Type[Model], Model] + self: DatetimeField, value: Optional[datetime.datetime], instance: Union[MODEL_CLASS, MODEL_INSTANCE] ) -> Optional[str]: # Only do this if it is a Model instance, not class. Test for guaranteed instance var if hasattr(instance, "_saved_in_db") and ( @@ -56,7 +56,7 @@ def to_db_datetime( def to_db_time( - self: TimeField, value: Optional[datetime.time], instance: Union[Type[Model], Model] + self: TimeField, value: Optional[datetime.time], instance: Union[MODEL_CLASS, MODEL_INSTANCE] ) -> Optional[str]: if hasattr(instance, "_saved_in_db") and ( self.auto_now @@ -86,7 +86,7 @@ class SqliteExecutor(BaseExecutor): def parameter(self, pos: int) -> Parameter: return Parameter("?") - async def _process_insert_result(self, instance: Model, results: int) -> None: + async def _process_insert_result(self, instance: MODEL_INSTANCE, results: int) -> None: pk_field_object = self.model._meta.pk if ( isinstance(pk_field_object, (SmallIntField, IntField, BigIntField)) diff --git a/tortoise/expressions.py b/tortoise/expressions.py index a3a3df093..2faea321e 100644 --- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -7,7 +7,6 @@ List, Optional, Tuple, - Type, Union, cast, ) @@ -33,7 +32,7 @@ from pypika.queries import Selectable from tortoise.fields.base import Field - from tortoise.models import Model + from tortoise.models import MODEL_CLASS from tortoise.queryset import AwaitableQuery @@ -41,7 +40,7 @@ class F(PypikaField): # type: ignore @classmethod def resolver_arithmetic_expression( cls, - model: "Type[Model]", + model: "MODEL_CLASS", arithmetic_expression_or_field: Term, ) -> Tuple[Term, Optional[PypikaField]]: field_object = None @@ -118,7 +117,7 @@ class Expression: Parent class for expressions """ - def resolve(self, model: "Type[Model]", table: Table) -> Any: + def resolve(self, model: "MODEL_CLASS", table: Table) -> Any: raise NotImplementedError() @@ -198,7 +197,7 @@ def negate(self) -> None: self._is_negated = not self._is_negated def _resolve_nested_filter( - self, model: "Type[Model]", key: str, value: Any, table: Table + self, model: "MODEL_CLASS", key: str, value: Any, table: Table ) -> QueryModifier: related_field_name, __, forwarded_fields = key.partition("__") related_field = cast(RelationalField, model._meta.fields_map[related_field_name]) @@ -213,7 +212,7 @@ def _resolve_nested_filter( return QueryModifier(joins=required_joins) & modifier def _resolve_custom_kwarg( - self, model: "Type[Model]", key: str, value: Any, table: Table + self, model: "MODEL_CLASS", key: str, value: Any, table: Table ) -> QueryModifier: having_info = self._custom_filters[key] annotation = self._annotations[having_info["field"]] @@ -235,7 +234,7 @@ def _resolve_custom_kwarg( return modifier def _process_filter_kwarg( - self, model: "Type[Model]", key: str, value: Any, table: Table + self, model: "MODEL_CLASS", key: str, value: Any, table: Table ) -> Tuple[Criterion, Optional[Tuple[Table, Criterion]]]: join = None @@ -272,7 +271,7 @@ def _process_filter_kwarg( return criterion, join def _resolve_regular_kwarg( - self, model: "Type[Model]", key: str, value: Any, table: Table + self, model: "MODEL_CLASS", key: str, value: Any, table: Table ) -> QueryModifier: if key not in model._meta.filters and key.split("__")[0] in model._meta.fetch_fields: modifier = self._resolve_nested_filter(model, key, value, table) @@ -283,7 +282,7 @@ def _resolve_regular_kwarg( return modifier def _get_actual_filter_params( - self, model: "Type[Model]", key: str, value: Table + self, model: "MODEL_CLASS", key: str, value: Table ) -> Tuple[str, Any]: filter_key = key if key in model._meta.fk_fields or key in model._meta.o2o_fields: @@ -311,7 +310,7 @@ def _get_actual_filter_params( raise FieldError(f"Unknown filter param '{key}'. Allowed base values are {allowed}") return filter_key, filter_value - def _resolve_kwargs(self, model: "Type[Model]", table: Table) -> QueryModifier: + def _resolve_kwargs(self, model: "MODEL_CLASS", table: Table) -> QueryModifier: modifier = QueryModifier() for raw_key, raw_value in self.filters.items(): key, value = self._get_actual_filter_params(model, raw_key, raw_value) @@ -328,7 +327,7 @@ def _resolve_kwargs(self, model: "Type[Model]", table: Table) -> QueryModifier: modifier = ~modifier return modifier - def _resolve_children(self, model: "Type[Model]", table: Table) -> QueryModifier: + def _resolve_children(self, model: "MODEL_CLASS", table: Table) -> QueryModifier: modifier = QueryModifier() for node in self.children: node._annotations = self._annotations @@ -345,7 +344,7 @@ def _resolve_children(self, model: "Type[Model]", table: Table) -> QueryModifier def resolve( self, - model: "Type[Model]", + model: "MODEL_CLASS", table: Table, ) -> QueryModifier: """ @@ -396,7 +395,7 @@ def _get_function_field( ): return self.database_func(field, *default_values) - def _resolve_field_for_model(self, model: "Type[Model]", table: Table, field: str) -> dict: + def _resolve_field_for_model(self, model: "MODEL_CLASS", table: Table, field: str) -> dict: joins = [] fields = field.split("__") @@ -443,14 +442,14 @@ def _resolve_field_for_model(self, model: "Type[Model]", table: Table, field: st return {"joins": joins, "field": field} - def _resolve_default_values(self, model: "Type[Model]", table: Table) -> Iterator[Any]: + def _resolve_default_values(self, model: "MODEL_CLASS", table: Table) -> Iterator[Any]: for default_value in self.default_values: if isinstance(default_value, Function): yield default_value.resolve(model, table)["field"] else: yield default_value - def resolve(self, model: "Type[Model]", table: Table) -> dict: + def resolve(self, model: "MODEL_CLASS", table: Table) -> dict: """ Used to resolve the Function statement for SQL generation. @@ -506,7 +505,7 @@ def _get_function_field( return self.database_func(field, *default_values).distinct() return self.database_func(field, *default_values) - def _resolve_field_for_model(self, model: "Type[Model]", table: Table, field: str) -> dict: + def _resolve_field_for_model(self, model: "MODEL_CLASS", table: Table, field: str) -> dict: ret = super()._resolve_field_for_model(model, table, field) if self.filter: modifier = QueryModifier() @@ -556,7 +555,7 @@ def _resolve_q_objects(self) -> List[Q]: q_objects.append(Q(**{key: value})) return q_objects - def resolve(self, model: "Type[Model]", table: Table) -> tuple: + def resolve(self, model: "MODEL_CLASS", table: Table) -> tuple: q_objects = self._resolve_q_objects() modifier = QueryModifier() @@ -587,7 +586,7 @@ def __init__( self.args = args self.default = default - def resolve(self, model: "Type[Model]", table: Table) -> dict: + def resolve(self, model: "MODEL_CLASS", table: Table) -> dict: case = PypikaCase() for arg in self.args: if not isinstance(arg, When): diff --git a/tortoise/fields/base.py b/tortoise/fields/base.py index cbda76117..05ab961f7 100644 --- a/tortoise/fields/base.py +++ b/tortoise/fields/base.py @@ -20,7 +20,7 @@ from tortoise.validators import Validator if TYPE_CHECKING: # pragma: nocoverage - from tortoise.models import Model + from tortoise.models import MODEL_CLASS, MODEL_INSTANCE VALUE = TypeVar("VALUE") @@ -146,14 +146,14 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "Field[VALUE]": return super().__new__(cls) @overload - def __get__(self, instance: None, owner: Type["Model"]) -> "Field[VALUE]": + def __get__(self, instance: None, owner: "MODEL_CLASS") -> "Field[VALUE]": ... @overload - def __get__(self, instance: "Model", owner: Type["Model"]) -> VALUE: + def __get__(self, instance: "MODEL_INSTANCE", owner: "MODEL_CLASS") -> VALUE: ... - def __get__(self, instance: Optional["Model"], owner: Type["Model"]): + def __get__(self, instance: Optional["MODEL_INSTANCE"], owner: "MODEL_CLASS"): ... def __init__( @@ -166,7 +166,7 @@ def __init__( unique: bool = False, index: bool = False, description: Optional[str] = None, - model: "Optional[Model]" = None, + model: Optional["MODEL_CLASS"] = None, validators: Optional[List[Union[Validator, Callable]]] = None, **kwargs: Any, ) -> None: @@ -193,10 +193,10 @@ def __init__( self.docstring: Optional[str] = None self.validators: List[Union[Validator, Callable]] = validators or [] # TODO: consider making this not be set from constructor - self.model: Type["Model"] = model # type: ignore + self.model: "MODEL_CLASS" = model # type: ignore self.reference: "Optional[Field]" = None - def to_db_value(self, value: Any, instance: "Union[Type[Model], Model]") -> Any: + def to_db_value(self, value: Any, instance: Union["MODEL_CLASS", "MODEL_INSTANCE"]) -> Any: """ Converts from the Python type to the DB type. diff --git a/tortoise/fields/data.py b/tortoise/fields/data.py index cd635924f..36da33d6a 100644 --- a/tortoise/fields/data.py +++ b/tortoise/fields/data.py @@ -25,7 +25,7 @@ parse_datetime = functools.partial(parse_date, default_timezone=None) if TYPE_CHECKING: # pragma: nocoverage - from tortoise.models import Model + from tortoise.models import MODEL_INSTANCE, MODEL_CLASS __all__ = ( "BigIntField", @@ -369,7 +369,7 @@ def to_python_value(self, value: Any) -> Optional[datetime.datetime]: return value def to_db_value( - self, value: Optional[datetime.datetime], instance: "Union[Type[Model], Model]" + self, value: Optional[datetime.datetime], instance: Union["MODEL_CLASS", "MODEL_INSTANCE"] ) -> Optional[datetime.datetime]: # Only do this if it is a Model instance, not class. Test for guaranteed instance var @@ -421,7 +421,7 @@ def to_python_value(self, value: Any) -> Optional[datetime.date]: return value def to_db_value( - self, value: Optional[Union[datetime.date, str]], instance: "Union[Type[Model], Model]" + self, value: Optional[Union[datetime.date, str]], instance: Union["MODEL_CLASS", "MODEL_INSTANCE"] ) -> Optional[datetime.date]: if value is not None and not isinstance(value, datetime.date): @@ -462,7 +462,7 @@ def to_python_value(self, value: Any) -> Optional[Union[datetime.time, datetime. def to_db_value( self, value: Optional[Union[datetime.time, datetime.timedelta]], - instance: "Union[Type[Model], Model]", + instance: Union["MODEL_CLASS", "MODEL_INSTANCE"], ) -> Optional[Union[datetime.time, datetime.timedelta]]: # Only do this if it is a Model instance, not class. Test for guaranteed instance var @@ -512,7 +512,7 @@ def to_python_value(self, value: Any) -> Optional[datetime.timedelta]: return datetime.timedelta(microseconds=value) def to_db_value( - self, value: Optional[datetime.timedelta], instance: "Union[Type[Model], Model]" + self, value: Optional[datetime.timedelta], instance: Union["MODEL_CLASS", "MODEL_INSTANCE"] ) -> Optional[int]: self.validate(value) @@ -575,7 +575,7 @@ def __init__( self.decoder = decoder def to_db_value( - self, value: Optional[Union[dict, list, str]], instance: "Union[Type[Model], Model]" + self, value: Optional[Union[dict, list, str]], instance: Union["MODEL_CLASS", "MODEL_INSTANCE"] ) -> Optional[str]: self.validate(value) @@ -621,7 +621,7 @@ def __init__(self, **kwargs: Any) -> None: kwargs["default"] = uuid4 super().__init__(**kwargs) - def to_db_value(self, value: Any, instance: "Union[Type[Model], Model]") -> Optional[str]: + def to_db_value(self, value: Any, instance: Union["MODEL_CLASS", "MODEL_INSTANCE"]) -> Optional[str]: return value and str(value) def to_python_value(self, value: Any) -> Optional[UUID]: @@ -685,7 +685,7 @@ def to_python_value(self, value: Union[int, None]) -> Union[IntEnum, None]: return value def to_db_value( - self, value: Union[IntEnum, None, int], instance: "Union[Type[Model], Model]" + self, value: Union[IntEnum, None, int], instance: Union["MODEL_CLASS", "MODEL_INSTANCE"] ) -> Union[int, None]: if isinstance(value, IntEnum): @@ -752,7 +752,7 @@ def to_python_value(self, value: Union[str, None]) -> Union[Enum, None]: return self.enum_type(value) if value is not None else None def to_db_value( - self, value: Union[Enum, None, str], instance: "Union[Type[Model], Model]" + self, value: Union[Enum, None, str], instance: Union["MODEL_CLASS", "MODEL_INSTANCE"] ) -> Union[str, None]: self.validate(value) if isinstance(value, Enum): diff --git a/tortoise/fields/relational.py b/tortoise/fields/relational.py index 4db08e9e6..fea387bf2 100644 --- a/tortoise/fields/relational.py +++ b/tortoise/fields/relational.py @@ -7,11 +7,8 @@ Iterator, List, Optional, - Type, - TypeVar, Union, - overload, -) + overload, TypeVar, TypeAlias, Type, ) from pypika import Table from typing_extensions import Literal @@ -21,10 +18,10 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.backends.base.client import BaseDBAsyncClient - from tortoise.models import Model from tortoise.queryset import Q, QuerySet -MODEL = TypeVar("MODEL", bound="Model") +MODEL_INSTANCE = TypeVar("MODEL_INSTANCE", bound="Model") +MODEL_CLASS: TypeAlias = Type[MODEL_INSTANCE] class _NoneAwaitable: @@ -40,16 +37,16 @@ def __bool__(self) -> bool: NoneAwaitable = _NoneAwaitable() -class ReverseRelation(Generic[MODEL]): +class ReverseRelation(Generic[MODEL_INSTANCE]): """ Relation container for :func:`.ForeignKeyField`. """ def __init__( self, - remote_model: Type[MODEL], + remote_model: MODEL_CLASS, relation_field: str, - instance: "Model", + instance: MODEL_INSTANCE, from_field: str, ) -> None: self.remote_model = remote_model @@ -58,10 +55,10 @@ def __init__( self.from_field = from_field self._fetched = False self._custom_query = False - self.related_objects: List[MODEL] = [] + self.related_objects: List[MODEL_INSTANCE] = [] @property - def _query(self) -> "QuerySet[MODEL]": + def _query(self) -> "QuerySet[MODEL_INSTANCE]": if not self.instance._saved_in_db: raise OperationalError( "This objects hasn't been instanced, call .save() before calling related queries" @@ -74,7 +71,7 @@ def __contains__(self, item: Any) -> bool: self._raise_if_not_fetched() return item in self.related_objects - def __iter__(self) -> "Iterator[MODEL]": + def __iter__(self) -> "Iterator[MODEL_INSTANCE]": self._raise_if_not_fetched() return self.related_objects.__iter__() @@ -86,51 +83,51 @@ def __bool__(self) -> bool: self._raise_if_not_fetched() return bool(self.related_objects) - def __getitem__(self, item: int) -> MODEL: + def __getitem__(self, item: int) -> MODEL_INSTANCE: self._raise_if_not_fetched() return self.related_objects[item] - def __await__(self) -> Generator[Any, None, List[MODEL]]: + def __await__(self) -> Generator[Any, None, List[MODEL_INSTANCE]]: return self._query.__await__() - async def __aiter__(self) -> AsyncGenerator[Any, MODEL]: + async def __aiter__(self) -> AsyncGenerator[Any, MODEL_INSTANCE]: if not self._fetched: self._set_result_for_query(await self) for val in self.related_objects: yield val - def filter(self, *args: "Q", **kwargs: Any) -> "QuerySet[MODEL]": + def filter(self, *args: "Q", **kwargs: Any) -> "QuerySet[MODEL_INSTANCE]": """ Returns a QuerySet with related elements filtered by args/kwargs. """ return self._query.filter(*args, **kwargs) - def all(self) -> "QuerySet[MODEL]": + def all(self) -> "QuerySet[MODEL_INSTANCE]": """ Returns a QuerySet with all related elements. """ return self._query - def order_by(self, *orderings: str) -> "QuerySet[MODEL]": + def order_by(self, *orderings: str) -> "QuerySet[MODEL_INSTANCE]": """ Returns a QuerySet related elements in order. """ return self._query.order_by(*orderings) - def limit(self, limit: int) -> "QuerySet[MODEL]": + def limit(self, limit: int) -> "QuerySet[MODEL_INSTANCE]": """ Returns a QuerySet with at most «limit» related elements. """ return self._query.limit(limit) - def offset(self, offset: int) -> "QuerySet[MODEL]": + def offset(self, offset: int) -> "QuerySet[MODEL_INSTANCE]": """ Returns a QuerySet with all related elements offset by «offset». """ return self._query.offset(offset) - def _set_result_for_query(self, sequence: List[MODEL], attr: Optional[str] = None) -> None: + def _set_result_for_query(self, sequence: List[MODEL_INSTANCE], attr: Optional[str] = None) -> None: self._fetched = True self.related_objects = sequence if attr: @@ -143,17 +140,17 @@ def _raise_if_not_fetched(self) -> None: ) -class ManyToManyRelation(ReverseRelation[MODEL]): +class ManyToManyRelation(ReverseRelation[MODEL_INSTANCE]): """ Many-to-many relation container for :func:`.ManyToManyField`. """ - def __init__(self, instance: "Model", m2m_field: "ManyToManyFieldInstance[MODEL]") -> None: + def __init__(self, instance: MODEL_INSTANCE, m2m_field: "ManyToManyFieldInstance[MODEL_INSTANCE]") -> None: super().__init__(m2m_field.related_model, m2m_field.related_name, instance, "pk") self.field = m2m_field self.instance = instance - async def add(self, *instances: MODEL, using_db: "Optional[BaseDBAsyncClient]" = None) -> None: + async def add(self, *instances: MODEL_INSTANCE, using_db: "Optional[BaseDBAsyncClient]" = None) -> None: """ Adds one or more of ``instances`` to the relation. @@ -235,7 +232,7 @@ async def clear(self, using_db: "Optional[BaseDBAsyncClient]" = None) -> None: await db.execute_query(str(query)) async def remove( - self, *instances: MODEL, using_db: "Optional[BaseDBAsyncClient]" = None + self, *instances: MODEL_INSTANCE, using_db: "Optional[BaseDBAsyncClient]" = None ) -> None: """ Removes one or more of ``instances`` from the relation. @@ -251,50 +248,50 @@ async def remove( if len(instances) == 1: condition = ( - through_table[self.field.forward_key] - == related_pk_formatting_func(instances[0].pk, instances[0]) - ) & ( - through_table[self.field.backward_key] - == pk_formatting_func(self.instance.pk, self.instance) - ) + through_table[self.field.forward_key] + == related_pk_formatting_func(instances[0].pk, instances[0]) + ) & ( + through_table[self.field.backward_key] + == pk_formatting_func(self.instance.pk, self.instance) + ) else: condition = ( - through_table[self.field.backward_key] - == pk_formatting_func(self.instance.pk, self.instance) - ) & ( - through_table[self.field.forward_key].isin( - [related_pk_formatting_func(i.pk, i) for i in instances] - ) - ) + through_table[self.field.backward_key] + == pk_formatting_func(self.instance.pk, self.instance) + ) & ( + through_table[self.field.forward_key].isin( + [related_pk_formatting_func(i.pk, i) for i in instances] + ) + ) query = db.query_class.from_(through_table).where(condition).delete() await db.execute_query(str(query)) -class RelationalField(Field[MODEL]): +class RelationalField(Field[MODEL_INSTANCE]): has_db_field = False def __init__( self, - related_model: "Type[MODEL]", + related_model: MODEL_CLASS, to_field: Optional[str] = None, db_constraint: bool = True, **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.related_model: "Type[MODEL]" = related_model + self.related_model: MODEL_CLASS = related_model self.to_field: str = to_field # type: ignore self.to_field_instance: Field = None # type: ignore self.db_constraint = db_constraint @overload - def __get__(self, instance: None, owner: Type["Model"]) -> "RelationalField[MODEL]": + def __get__(self, instance: None, owner: MODEL_CLASS) -> "RelationalField[MODEL_INSTANCE]": ... @overload - def __get__(self, instance: "Model", owner: Type["Model"]) -> MODEL: + def __get__(self, instance: MODEL_INSTANCE, owner: MODEL_CLASS) -> MODEL_INSTANCE: ... - def __get__(self, instance: Optional["Model"], owner: Type["Model"]): + def __get__(self, instance: Optional[MODEL_INSTANCE], owner: MODEL_CLASS): ... def describe(self, serializable: bool) -> dict: @@ -304,7 +301,7 @@ def describe(self, serializable: bool) -> dict: return desc -class ForeignKeyFieldInstance(RelationalField[MODEL]): +class ForeignKeyFieldInstance(RelationalField[MODEL_INSTANCE]): def __init__( self, model_name: str, @@ -330,10 +327,10 @@ def describe(self, serializable: bool) -> dict: return desc -class BackwardFKRelation(RelationalField[MODEL]): +class BackwardFKRelation(RelationalField[MODEL_INSTANCE]): def __init__( self, - field_type: "Type[MODEL]", + field_type: MODEL_CLASS, relation_field: str, relation_source_field: str, null: bool, @@ -346,7 +343,7 @@ def __init__( self.description: Optional[str] = description -class OneToOneFieldInstance(ForeignKeyFieldInstance[MODEL]): +class OneToOneFieldInstance(ForeignKeyFieldInstance[MODEL_INSTANCE]): def __init__( self, model_name: str, @@ -359,11 +356,11 @@ def __init__( super().__init__(model_name, related_name, on_delete, unique=True, **kwargs) -class BackwardOneToOneRelation(BackwardFKRelation[MODEL]): +class BackwardOneToOneRelation(BackwardFKRelation[MODEL_INSTANCE]): pass -class ManyToManyFieldInstance(RelationalField[MODEL]): +class ManyToManyFieldInstance(RelationalField[MODEL_INSTANCE]): field_type = ManyToManyRelation def __init__( @@ -374,7 +371,7 @@ def __init__( backward_key: str = "", related_name: str = "", on_delete: str = CASCADE, - field_type: "Type[MODEL]" = None, # type: ignore + field_type: MODEL_CLASS = None, # type: ignore **kwargs: Any, ) -> None: # TODO: rename through to through_table @@ -408,7 +405,7 @@ def OneToOneField( on_delete: str = CASCADE, db_constraint: bool = True, **kwargs: Any, -) -> "OneToOneRelation[MODEL]": +) -> "OneToOneRelation[MODEL_INSTANCE]": """ OneToOne relation field. @@ -457,7 +454,7 @@ def ForeignKeyField( on_delete: str = CASCADE, db_constraint: bool = True, **kwargs: Any, -) -> "ForeignKeyRelation[MODEL]": +) -> "ForeignKeyRelation[MODEL_INSTANCE]": """ ForeignKey relation field. @@ -509,7 +506,7 @@ def ManyToManyField( on_delete: str = CASCADE, db_constraint: bool = True, **kwargs: Any, -) -> "ManyToManyRelation[MODEL]": +) -> "ManyToManyRelation[MODEL_INSTANCE]": """ ManyToMany relation field. @@ -565,24 +562,24 @@ def ManyToManyField( ) -OneToOneNullableRelation = Optional[OneToOneFieldInstance[MODEL]] +OneToOneNullableRelation = Optional[OneToOneFieldInstance[MODEL_INSTANCE]] """ Type hint for the result of accessing the :func:`.OneToOneField` field in the model when obtained model can be nullable. """ -OneToOneRelation = OneToOneFieldInstance[MODEL] +OneToOneRelation = OneToOneFieldInstance[MODEL_INSTANCE] """ Type hint for the result of accessing the :func:`.OneToOneField` field in the model. """ -ForeignKeyNullableRelation = Optional[ForeignKeyFieldInstance[MODEL]] +ForeignKeyNullableRelation = Optional[ForeignKeyFieldInstance[MODEL_INSTANCE]] """ Type hint for the result of accessing the :func:`.ForeignKeyField` field in the model when obtained model can be nullable. """ -ForeignKeyRelation = ForeignKeyFieldInstance[MODEL] +ForeignKeyRelation = ForeignKeyFieldInstance[MODEL_INSTANCE] """ Type hint for the result of accessing the :func:`.ForeignKeyField` field in the model. """ diff --git a/tortoise/filters.py b/tortoise/filters.py index b38f1d345..23f8520b4 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -11,7 +11,7 @@ from tortoise.fields.relational import BackwardFKRelation, ManyToManyFieldInstance if TYPE_CHECKING: # pragma: nocoverage - from tortoise.models import Model + from tortoise.models import MODEL_INSTANCE ############################################################################## @@ -43,31 +43,31 @@ def escape_like(val: str) -> str: ############################################################################## # Encoders -# Should be type: (Any, instance: "Model", field: Field) -> type: +# Should be type: (Any, instance: "MODEL_INSTANCE", field: Field) -> type: ############################################################################## -def list_encoder(values: Iterable[Any], instance: "Model", field: Field) -> list: +def list_encoder(values: Iterable[Any], instance: "MODEL_INSTANCE", field: Field) -> list: """Encodes an iterable of a given field into a database-compatible format.""" return [field.to_db_value(element, instance) for element in values] -def related_list_encoder(values: Iterable[Any], instance: "Model", field: Field) -> list: +def related_list_encoder(values: Iterable[Any], instance: "MODEL_INSTANCE", field: Field) -> list: return [ field.to_db_value(element.pk if hasattr(element, "pk") else element, instance) for element in values ] -def bool_encoder(value: Any, instance: "Model", field: Field) -> bool: +def bool_encoder(value: Any, instance: "MODEL_INSTANCE", field: Field) -> bool: return bool(value) -def string_encoder(value: Any, instance: "Model", field: Field) -> str: +def string_encoder(value: Any, instance: "MODEL_INSTANCE", field: Field) -> str: return str(value) -def json_encoder(value: Any, instance: "Model", field: Field) -> Dict: +def json_encoder(value: Any, instance: "MODEL_INSTANCE", field: Field) -> Dict: return value diff --git a/tortoise/indexes.py b/tortoise/indexes.py index 7ccf49fbb..6e25aeb66 100644 --- a/tortoise/indexes.py +++ b/tortoise/indexes.py @@ -1,9 +1,9 @@ -from typing import TYPE_CHECKING, Optional, Tuple, Type +from typing import TYPE_CHECKING, Optional, Tuple from pypika.terms import Term, ValueWrapper if TYPE_CHECKING: - from tortoise import Model + from tortoise import MODEL_CLASS from tortoise.backends.base.schema_generator import BaseSchemaGenerator @@ -38,7 +38,7 @@ def __init__( self.expressions = expressions self.extra = "" - def get_sql(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool): + def get_sql(self, schema_generator: "BaseSchemaGenerator", model: "MODEL_CLASS", safe: bool): if self.fields: return self.INDEX_CREATE_TEMPLATE.format( exists="IF NOT EXISTS " if safe else "", @@ -61,7 +61,7 @@ def get_sql(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", extra=self.extra, ) - def index_name(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]"): + def index_name(self, schema_generator: "BaseSchemaGenerator", model: "MODEL_CLASS"): return self.name or schema_generator._generate_index_name("idx", model, self.fields) diff --git a/tortoise/models.py b/tortoise/models.py index d195fc0b7..d16e7b50f 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -16,7 +16,7 @@ Tuple, Type, TypeVar, - Union, + Union, TypeAlias, ) from pypika import Order, Query, Table @@ -62,14 +62,16 @@ from tortoise.signals import Signals from tortoise.transactions import in_transaction -MODEL = TypeVar("MODEL", bound="Model") +MODEL_INSTANCE = TypeVar("MODEL_INSTANCE", bound="Model") +MODEL_CLASS: TypeAlias = Type[MODEL_INSTANCE] +MODEL_META_INSTANCE = TypeVar("MODEL_META_INSTANCE", bound="ModelMeta") EMPTY = object() # TODO: Define Filter type object. Possibly tuple? -def get_together(meta: "Model.Meta", together: str) -> Tuple[Tuple[str, ...], ...]: +def get_together(meta: MODEL_META_INSTANCE, together: str) -> Tuple[Tuple[str, ...], ...]: _together = getattr(meta, together, ()) if _together and isinstance(_together, (list, tuple)) and isinstance(_together[0], str): @@ -79,7 +81,7 @@ def get_together(meta: "Model.Meta", together: str) -> Tuple[Tuple[str, ...], .. return _together -def prepare_default_ordering(meta: "Model.Meta") -> Tuple[Tuple[str, Order], ...]: +def prepare_default_ordering(meta: MODEL_META_INSTANCE) -> Tuple[Tuple[str, Order], ...]: ordering_list = getattr(meta, "ordering", ()) parsed_ordering = tuple( @@ -90,8 +92,8 @@ def prepare_default_ordering(meta: "Model.Meta") -> Tuple[Tuple[str, Order], ... def _fk_setter( - self: "Model", - value: "Optional[Model]", + self: MODEL_INSTANCE, + value: Optional[MODEL_INSTANCE], _key: str, relation_field: str, to_field: str, @@ -101,7 +103,7 @@ def _fk_setter( def _fk_getter( - self: "Model", _key: str, ftype: "Type[Model]", relation_field: str, to_field: str + self: MODEL_INSTANCE, _key: str, ftype: MODEL_CLASS, relation_field: str, to_field: str ) -> Awaitable: try: return getattr(self, _key) @@ -113,7 +115,7 @@ def _fk_getter( def _rfk_getter( - self: "Model", _key: str, ftype: "Type[Model]", frelfield: str, from_field: str + self: MODEL_INSTANCE, _key: str, ftype: MODEL_CLASS, frelfield: str, from_field: str ) -> ReverseRelation: val = getattr(self, _key, None) if val is None: @@ -123,8 +125,8 @@ def _rfk_getter( def _ro2o_getter( - self: "Model", _key: str, ftype: "Type[Model]", frelfield: str, from_field: str -) -> "QuerySetSingle[Optional[Model]]": + self: MODEL_INSTANCE, _key: str, ftype: MODEL_CLASS, frelfield: str, from_field: str +) -> QuerySetSingle[Optional[MODEL_INSTANCE]]: if hasattr(self, _key): return getattr(self, _key) @@ -134,7 +136,7 @@ def _ro2o_getter( def _m2m_getter( - self: "Model", _key: str, field_object: ManyToManyFieldInstance + self: MODEL_INSTANCE, _key: str, field_object: ManyToManyFieldInstance ) -> ManyToManyRelation: val = getattr(self, _key, None) if val is None: @@ -143,7 +145,7 @@ def _m2m_getter( return val -def _get_comments(cls: "Type[Model]") -> Dict[str, str]: +def _get_comments(cls: MODEL_CLASS) -> Dict[str, str]: """ Get comments exactly before attributes @@ -214,7 +216,7 @@ class MetaInfo: "_ordering_validated", ) - def __init__(self, meta: "Model.Meta") -> None: + def __init__(self, meta: MODEL_META_INSTANCE) -> None: self.abstract: bool = getattr(meta, "abstract", False) self.manager: Manager = getattr(meta, "manager", Manager()) self.db_table: str = getattr(meta, "table", "") @@ -244,7 +246,7 @@ def __init__(self, meta: "Model.Meta") -> None: self.basetable: Table = Table("") self.pk_attr: str = getattr(meta, "pk_attr", "") self.generated_db_fields: Tuple[str] = None # type: ignore - self._model: Type["Model"] = None # type: ignore + self._model: MODEL_CLASS = None # type: ignore self.table_description: str = getattr(meta, "table_description", "") self.pk: Field = None # type: ignore self.db_pk_column: str = "" @@ -494,7 +496,7 @@ def __new__(mcs, name: str, bases: Tuple[Type, ...], attrs: dict): fk_fields: Set[str] = set() m2m_fields: Set[str] = set() o2o_fields: Set[str] = set() - meta_class: "Model.Meta" = attrs.get("Meta", type("Meta", (), {})) + meta_class: MODEL_META_INSTANCE = attrs.get("Meta", type("Meta", (), {})) pk_attr: str = "id" # Searching for Field attributes in the class hierarchy @@ -642,7 +644,7 @@ def __search_for_field_attributes(base: Type, attrs: dict) -> None: meta.finalise_fields() return new_class - def __getitem__(cls: Type[MODEL], key: Any) -> QuerySetSingle[MODEL]: # type: ignore + def __getitem__(cls: MODEL_CLASS, key: Any) -> QuerySetSingle[MODEL_INSTANCE]: # type: ignore return cls._getbypk(key) # type: ignore @@ -653,7 +655,7 @@ class Model(metaclass=ModelMeta): # I don' like this here, but it makes auto completion and static analysis much happier _meta = MetaInfo(None) # type: ignore - _listeners: Dict[Signals, Dict[Type[MODEL], List[Callable]]] = { # type: ignore + _listeners: Dict[Signals, Dict[MODEL_CLASS, List[Callable]]] = { # type: ignore Signals.pre_save: {}, Signals.post_save: {}, Signals.pre_delete: {}, @@ -713,7 +715,7 @@ def _set_kwargs(self, kwargs: dict) -> Set[str]: return passed_fields @classmethod - def _init_from_db(cls: Type[MODEL], **kwargs: Any) -> MODEL: + def _init_from_db(cls: MODEL_CLASS, **kwargs: Any) -> MODEL_INSTANCE: self = cls.__new__(cls) self._partial = False self._saved_in_db = True @@ -781,13 +783,13 @@ def _set_pk_val(self, value: Any) -> None: """ @classmethod - async def _getbypk(cls: Type[MODEL], key: Any) -> MODEL: + async def _getbypk(cls: MODEL_CLASS, key: Any) -> MODEL_INSTANCE: try: return await cls.get(pk=key) except (DoesNotExist, ValueError): raise KeyError(f"{cls._meta.full_name} has no object {repr(key)}") - def clone(self: MODEL, pk: Any = EMPTY) -> MODEL: + def clone(self: MODEL_INSTANCE, pk: Any = EMPTY) -> MODEL_INSTANCE: """ Create a new clone of the object that when you do a ``.save()`` will create a new record. @@ -810,7 +812,7 @@ def clone(self: MODEL, pk: Any = EMPTY) -> MODEL: obj._saved_in_db = False return obj - def update_from_dict(self: MODEL, data: dict) -> MODEL: + def update_from_dict(self: MODEL_INSTANCE, data: dict) -> MODEL_INSTANCE: """ Updates the current model with the provided dict. This can allow mass-updating a model from a dict, also ensuring that datatype conversions happen. @@ -1035,11 +1037,11 @@ def _choose_db(cls, for_write: bool = False): @classmethod async def get_or_create( - cls: Type[MODEL], + cls: MODEL_CLASS, defaults: Optional[dict] = None, using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any, - ) -> Tuple[MODEL, bool]: + ) -> Tuple[MODEL_INSTANCE, bool]: """ Fetches the object if exists (filtering on the provided parameters), else creates an instance with any unspecified parameters as default values. @@ -1072,7 +1074,7 @@ def select_for_update( skip_locked: bool = False, of: Tuple[str, ...] = (), using_db: Optional[BaseDBAsyncClient] = None, - ) -> QuerySet[MODEL]: + ) -> QuerySet[MODEL_INSTANCE]: """ Make QuerySet select for update. @@ -1086,11 +1088,11 @@ def select_for_update( @classmethod async def update_or_create( - cls: Type[MODEL], + cls: MODEL_CLASS, defaults: Optional[dict] = None, using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any, - ) -> Tuple[MODEL, bool]: + ) -> Tuple[MODEL_INSTANCE, bool]: """ A convenience method for updating an object with the given kwargs, creating a new one if necessary. @@ -1114,8 +1116,8 @@ async def update_or_create( @classmethod async def create( - cls: Type[MODEL], using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any - ) -> MODEL: + cls: MODEL_CLASS, using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any + ) -> MODEL_INSTANCE: """ Create a record in the DB and returns the object. @@ -1141,12 +1143,12 @@ async def create( @classmethod def bulk_update( - cls: Type[MODEL], - objects: Iterable[MODEL], + cls: MODEL_CLASS, + objects: Iterable[MODEL_INSTANCE], fields: Iterable[str], batch_size: Optional[int] = None, using_db: Optional[BaseDBAsyncClient] = None, - ) -> "BulkUpdateQuery": + ) -> BulkUpdateQuery: """ Update the given fields in each of the given objects in the database. This method efficiently updates the given fields on the provided model instances, generally with one query. @@ -1174,11 +1176,11 @@ def bulk_update( @classmethod async def in_bulk( - cls: Type[MODEL], + cls: MODEL_CLASS, id_list: Iterable[Union[str, int]], field_name: str = "pk", using_db: Optional[BaseDBAsyncClient] = None, - ) -> Dict[str, MODEL]: + ) -> Dict[str, MODEL_INSTANCE]: """ Return a dictionary mapping each of the given IDs to the object with that ID. If `id_list` isn't provided, evaluate the entire QuerySet. @@ -1192,14 +1194,14 @@ async def in_bulk( @classmethod def bulk_create( - cls: Type[MODEL], - objects: Iterable[MODEL], + cls: MODEL_CLASS, + objects: Iterable[MODEL_INSTANCE], batch_size: Optional[int] = None, ignore_conflicts: bool = False, update_fields: Optional[Iterable[str]] = None, on_conflict: Optional[Iterable[str]] = None, using_db: Optional[BaseDBAsyncClient] = None, - ) -> "BulkCreateQuery": + ) -> BulkCreateQuery: """ Bulk insert operation: @@ -1236,8 +1238,8 @@ def bulk_create( @classmethod def first( - cls: Type[MODEL], using_db: Optional[BaseDBAsyncClient] = None - ) -> QuerySetSingle[Optional[MODEL]]: + cls: MODEL_CLASS, using_db: Optional[BaseDBAsyncClient] = None + ) -> QuerySetSingle[Optional[MODEL_INSTANCE]]: """ Generates a QuerySet that returns the first record. """ @@ -1245,7 +1247,7 @@ def first( return cls._meta.manager.get_queryset().using_db(db).first() @classmethod - def filter(cls: Type[MODEL], *args: Q, **kwargs: Any) -> QuerySet[MODEL]: + def filter(cls: MODEL_CLASS, *args: Q, **kwargs: Any) -> QuerySet[MODEL_INSTANCE]: """ Generates a QuerySet with the filter applied. @@ -1255,7 +1257,7 @@ def filter(cls: Type[MODEL], *args: Q, **kwargs: Any) -> QuerySet[MODEL]: return cls._meta.manager.get_queryset().filter(*args, **kwargs) @classmethod - def exclude(cls: Type[MODEL], *args: Q, **kwargs: Any) -> QuerySet[MODEL]: + def exclude(cls: MODEL_CLASS, *args: Q, **kwargs: Any) -> QuerySet[MODEL_INSTANCE]: """ Generates a QuerySet with the exclude applied. @@ -1265,7 +1267,7 @@ def exclude(cls: Type[MODEL], *args: Q, **kwargs: Any) -> QuerySet[MODEL]: return cls._meta.manager.get_queryset().exclude(*args, **kwargs) @classmethod - def annotate(cls: Type[MODEL], **kwargs: Union[Function, Term]) -> QuerySet[MODEL]: + def annotate(cls: MODEL_CLASS, **kwargs: Union[Function, Term]) -> QuerySet[MODEL_INSTANCE]: """ Annotates the result set with extra Functions/Aggregations/Expressions. @@ -1274,7 +1276,7 @@ def annotate(cls: Type[MODEL], **kwargs: Union[Function, Term]) -> QuerySet[MODE return cls._meta.manager.get_queryset().annotate(**kwargs) @classmethod - def all(cls: Type[MODEL], using_db: Optional[BaseDBAsyncClient] = None) -> QuerySet[MODEL]: + def all(cls: MODEL_CLASS, using_db: Optional[BaseDBAsyncClient] = None) -> QuerySet[MODEL_INSTANCE]: """ Returns the complete QuerySet. """ @@ -1283,8 +1285,8 @@ def all(cls: Type[MODEL], using_db: Optional[BaseDBAsyncClient] = None) -> Query @classmethod def get( - cls: Type[MODEL], *args: Q, using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any - ) -> QuerySetSingle[MODEL]: + cls: MODEL_CLASS, *args: Q, using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any + ) -> QuerySetSingle[MODEL_INSTANCE]: """ Fetches a single record for a Model type using the provided filter parameters. @@ -1319,7 +1321,7 @@ def raw(cls, sql: str, using_db: Optional[BaseDBAsyncClient] = None) -> "RawSQLQ @classmethod def exists( - cls: Type[MODEL], *args: Q, using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any + cls: MODEL_CLASS, *args: Q, using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any ) -> ExistsQuery: """ Return True/False whether record exists with the provided filter parameters. @@ -1337,8 +1339,8 @@ def exists( @classmethod def get_or_none( - cls: Type[MODEL], *args: Q, using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any - ) -> QuerySetSingle[Optional[MODEL]]: + cls: MODEL_CLASS, *args: Q, using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any + ) -> QuerySetSingle[Optional[MODEL_INSTANCE]]: """ Fetches a single record for a Model type using the provided filter parameters or None. @@ -1356,7 +1358,7 @@ def get_or_none( @classmethod async def fetch_for_list( cls, - instance_list: "Iterable[Model]", + instance_list: Iterable[MODEL_INSTANCE], *args: Any, using_db: Optional[BaseDBAsyncClient] = None, ) -> None: @@ -1492,8 +1494,8 @@ def describe(cls, serializable: bool = True) -> dict: ], } - def __await__(self: MODEL) -> Generator[Any, None, MODEL]: - async def _self() -> MODEL: + def __await__(self: MODEL_INSTANCE) -> Generator[Any, None, MODEL_INSTANCE]: + async def _self() -> MODEL_INSTANCE: return self return _self().__await__() diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 17c42f5c6..b1a39163f 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -17,7 +17,7 @@ TypeVar, Union, cast, - overload, + overload, TypeAlias, ) from pypika import JoinType, Order, Table @@ -50,9 +50,12 @@ QUERY: QueryBuilder = QueryBuilder() if TYPE_CHECKING: # pragma: nocoverage - from tortoise.models import Model + pass + +MODEL_INSTANCE = TypeVar("MODEL_INSTANCE", bound="Model") +MODEL_CLASS: TypeAlias = Type[MODEL_INSTANCE] +MODEL_META_INSTANCE = TypeVar("MODEL_META_INSTANCE", bound="ModelMeta") -MODEL = TypeVar("MODEL", bound="Model") T_co = TypeVar("T_co", covariant=True) SINGLE = TypeVar("SINGLE", bound=bool) @@ -85,12 +88,12 @@ def values(self, *args: str, **kwargs: str) -> "ValuesQuery[Literal[True]]": ... # pragma: nocoverage -class AwaitableQuery(Generic[MODEL]): +class AwaitableQuery(Generic[MODEL_INSTANCE]): __slots__ = ("_joined_tables", "query", "model", "_db", "capabilities", "_annotations") - def __init__(self, model: Type[MODEL]) -> None: + def __init__(self, model: MODEL_CLASS) -> None: self._joined_tables: List[Table] = [] - self.model: "Type[Model]" = model + self.model: MODEL_CLASS = model self.query: QueryBuilder = QUERY self._db: BaseDBAsyncClient = None # type: ignore self.capabilities: Capabilities = model._meta.db.capabilities @@ -112,7 +115,7 @@ def _choose_db(self, for_write: bool = False) -> BaseDBAsyncClient: def resolve_filters( self, - model: "Type[Model]", + model: MODEL_CLASS, q_objects: List[Q], annotations: Dict[str, Any], custom_filters: Dict[str, Dict[str, Any]], @@ -170,7 +173,7 @@ def _resolve_ordering_string(ordering: str) -> Tuple[str, Order]: def resolve_ordering( self, - model: "Type[Model]", + model: MODEL_CLASS, table: Table, orderings: Iterable[Tuple[str, str]], annotations: Dict[str, Any], @@ -267,7 +270,7 @@ async def _execute(self) -> Any: raise NotImplementedError() # pragma: nocoverage -class QuerySet(AwaitableQuery[MODEL]): +class QuerySet(AwaitableQuery[MODEL_INSTANCE]): __slots__ = ( "fields", "_prefetch_map", @@ -295,7 +298,7 @@ class QuerySet(AwaitableQuery[MODEL]): "_force_indexes", ) - def __init__(self, model: Type[MODEL]) -> None: + def __init__(self, model: MODEL_CLASS) -> None: super().__init__(model) self.fields: Set[str] = model._meta.db_fields self._prefetch_map: Dict[str, Set[Union[str, Prefetch]]] = {} @@ -318,12 +321,12 @@ def __init__(self, model: Type[MODEL]) -> None: self._select_for_update_of: Set[str] = set() self._select_related: Set[str] = set() self._select_related_idx: List[ - Tuple["Type[Model]", int, str, "Type[Model]", Iterable[Optional[str]]] + Tuple[MODEL_CLASS, int, str, MODEL_CLASS, Iterable[Optional[str]]] ] = [] # format with: model,idx,model_name,parent_model self._force_indexes: Set[str] = set() self._use_indexes: Set[str] = set() - def _clone(self) -> "QuerySet[MODEL]": + def _clone(self) -> "QuerySet[MODEL_INSTANCE]": queryset = self.__class__.__new__(self.__class__) queryset.fields = self.fields queryset.model = self.model @@ -356,7 +359,7 @@ def _clone(self) -> "QuerySet[MODEL]": queryset._use_indexes = self._use_indexes return queryset - def _filter_or_exclude(self, *args: Q, negate: bool, **kwargs: Any) -> "QuerySet[MODEL]": + def _filter_or_exclude(self, *args: Q, negate: bool, **kwargs: Any) -> "QuerySet[MODEL_INSTANCE]": queryset = self._clone() for arg in args: if not isinstance(arg, Q): @@ -374,7 +377,7 @@ def _filter_or_exclude(self, *args: Q, negate: bool, **kwargs: Any) -> "QuerySet return queryset - def filter(self, *args: Q, **kwargs: Any) -> "QuerySet[MODEL]": + def filter(self, *args: Q, **kwargs: Any) -> "QuerySet[MODEL_INSTANCE]": """ Filters QuerySet by given kwargs. You can filter by related objects like this: @@ -386,13 +389,13 @@ def filter(self, *args: Q, **kwargs: Any) -> "QuerySet[MODEL]": """ return self._filter_or_exclude(negate=False, *args, **kwargs) - def exclude(self, *args: Q, **kwargs: Any) -> "QuerySet[MODEL]": + def exclude(self, *args: Q, **kwargs: Any) -> "QuerySet[MODEL_INSTANCE]": """ Same as .filter(), but with appends all args with NOT """ return self._filter_or_exclude(negate=True, *args, **kwargs) - def order_by(self, *orderings: str) -> "QuerySet[MODEL]": + def order_by(self, *orderings: str) -> "QuerySet[MODEL_INSTANCE]": """ Accept args to filter by in format like this: @@ -419,7 +422,7 @@ def order_by(self, *orderings: str) -> "QuerySet[MODEL]": queryset._orderings = new_ordering return queryset - def limit(self, limit: int) -> "QuerySet[MODEL]": + def limit(self, limit: int) -> "QuerySet[MODEL_INSTANCE]": """ Limits QuerySet to given length. @@ -432,7 +435,7 @@ def limit(self, limit: int) -> "QuerySet[MODEL]": queryset._limit = limit return queryset - def offset(self, offset: int) -> "QuerySet[MODEL]": + def offset(self, offset: int) -> "QuerySet[MODEL_INSTANCE]": """ Query offset for QuerySet. @@ -447,7 +450,7 @@ def offset(self, offset: int) -> "QuerySet[MODEL]": queryset._limit = 1000000 return queryset - def distinct(self) -> "QuerySet[MODEL]": + def distinct(self) -> "QuerySet[MODEL_INSTANCE]": """ Make QuerySet distinct. @@ -460,7 +463,7 @@ def distinct(self) -> "QuerySet[MODEL]": def select_for_update( self, nowait: bool = False, skip_locked: bool = False, of: Tuple[str, ...] = () - ) -> "QuerySet[MODEL]": + ) -> "QuerySet[MODEL_INSTANCE]": """ Make QuerySet select for update. @@ -476,7 +479,7 @@ def select_for_update( return queryset return self - def annotate(self, **kwargs: Union[Expression, Term]) -> "QuerySet[MODEL]": + def annotate(self, **kwargs: Union[Expression, Term]) -> "QuerySet[MODEL_INSTANCE]": """ Annotate result with aggregation or function result. @@ -492,7 +495,7 @@ def annotate(self, **kwargs: Union[Expression, Term]) -> "QuerySet[MODEL]": queryset._custom_filters.update(get_filters_for_field(key, None, key)) return queryset - def group_by(self, *fields: str) -> "QuerySet[MODEL]": + def group_by(self, *fields: str) -> "QuerySet[MODEL_INSTANCE]": """ Make QuerySet returns list of dict or tuple with group by. @@ -655,7 +658,7 @@ def exists(self) -> "ExistsQuery": use_indexes=self._use_indexes, ) - def all(self) -> "QuerySet[MODEL]": + def all(self) -> "QuerySet[MODEL_INSTANCE]": """ Return the whole QuerySet. Essentially a no-op except as the only operation. @@ -668,7 +671,7 @@ def raw(self, sql: str) -> "RawSQLQuery": """ return RawSQLQuery(model=self.model, db=self._db, sql=sql) - def first(self) -> QuerySetSingle[Optional[MODEL]]: + def first(self) -> QuerySetSingle[Optional[MODEL_INSTANCE]]: """ Limit queryset to one object and return one object instead of list. """ @@ -677,7 +680,7 @@ def first(self) -> QuerySetSingle[Optional[MODEL]]: queryset._single = True return queryset # type: ignore - def get(self, *args: Q, **kwargs: Any) -> QuerySetSingle[MODEL]: + def get(self, *args: Q, **kwargs: Any) -> QuerySetSingle[MODEL_INSTANCE]: """ Fetch exactly one object matching the parameters. """ @@ -689,7 +692,7 @@ def get(self, *args: Q, **kwargs: Any) -> QuerySetSingle[MODEL]: async def in_bulk( self, id_list: Iterable[Union[str, int]], field_name: str - ) -> Dict[str, MODEL]: + ) -> Dict[str, MODEL_INSTANCE]: """ Return a dictionary mapping each of the given IDs to the object with that ID. If `id_list` isn't provided, evaluate the entire QuerySet. @@ -702,7 +705,7 @@ async def in_bulk( def bulk_create( self, - objects: Iterable[MODEL], + objects: Iterable[MODEL_INSTANCE], batch_size: Optional[int] = None, ignore_conflicts: bool = False, update_fields: Optional[Iterable[str]] = None, @@ -740,7 +743,7 @@ def bulk_create( def bulk_update( self, - objects: Iterable[MODEL], + objects: Iterable[MODEL_INSTANCE], fields: Iterable[str], batch_size: Optional[int] = None, ) -> "BulkUpdateQuery": @@ -768,7 +771,7 @@ def bulk_update( batch_size=batch_size, ) - def get_or_none(self, *args: Q, **kwargs: Any) -> QuerySetSingle[Optional[MODEL]]: + def get_or_none(self, *args: Q, **kwargs: Any) -> QuerySetSingle[Optional[MODEL_INSTANCE]]: """ Fetch exactly one object matching the parameters. """ @@ -777,7 +780,7 @@ def get_or_none(self, *args: Q, **kwargs: Any) -> QuerySetSingle[Optional[MODEL] queryset._single = True return queryset # type: ignore - def only(self, *fields_for_select: str) -> "QuerySet[MODEL]": + def only(self, *fields_for_select: str) -> "QuerySet[MODEL_INSTANCE]": """ Fetch ONLY the specified fields to create a partial model. @@ -799,7 +802,7 @@ def only(self, *fields_for_select: str) -> "QuerySet[MODEL]": queryset._fields_for_select = fields_for_select return queryset - def select_related(self, *fields: str) -> "QuerySet[MODEL]": + def select_related(self, *fields: str) -> "QuerySet[MODEL_INSTANCE]": """ Return a new QuerySet instance that will select related objects. @@ -812,7 +815,7 @@ def select_related(self, *fields: str) -> "QuerySet[MODEL]": queryset._select_related.add(field) return queryset - def force_index(self, *index_names: str) -> "QuerySet[MODEL]": + def force_index(self, *index_names: str) -> "QuerySet[MODEL_INSTANCE]": """ The FORCE INDEX hint acts like USE INDEX (index_list), with the addition that a table scan is assumed to be very expensive. @@ -824,7 +827,7 @@ def force_index(self, *index_names: str) -> "QuerySet[MODEL]": return queryset return self - def use_index(self, *index_names: str) -> "QuerySet[MODEL]": + def use_index(self, *index_names: str) -> "QuerySet[MODEL_INSTANCE]": """ The USE INDEX (index_list) hint tells MySQL to use only one of the named indexes to find rows in the table. """ @@ -835,7 +838,7 @@ def use_index(self, *index_names: str) -> "QuerySet[MODEL]": return queryset return self - def prefetch_related(self, *args: Union[str, Prefetch]) -> "QuerySet[MODEL]": + def prefetch_related(self, *args: Union[str, Prefetch]) -> "QuerySet[MODEL_INSTANCE]": """ Like ``.fetch_related()`` on instance, but works on all objects in QuerySet. @@ -886,7 +889,7 @@ async def explain(self) -> Any: self.query ) - def using_db(self, _db: Optional[BaseDBAsyncClient]) -> "QuerySet[MODEL]": + def using_db(self, _db: Optional[BaseDBAsyncClient]) -> "QuerySet[MODEL_INSTANCE]": """ Executes query in provided db client. Useful for transactions workaround. @@ -897,7 +900,7 @@ def using_db(self, _db: Optional[BaseDBAsyncClient]) -> "QuerySet[MODEL]": def _join_table_with_select_related( self, - model: "Type[Model]", + model: MODEL_CLASS, table: Table, field: str, forwarded_fields: str, @@ -994,17 +997,17 @@ def _make_query(self) -> None: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - def __await__(self) -> Generator[Any, None, List[MODEL]]: + def __await__(self) -> Generator[Any, None, List[MODEL_INSTANCE]]: if self._db is None: self._db = self._choose_db(self._select_for_update) # type: ignore self._make_query() return self._execute().__await__() - async def __aiter__(self) -> AsyncIterator[MODEL]: + async def __aiter__(self) -> AsyncIterator[MODEL_INSTANCE]: for val in await self: yield val - async def _execute(self) -> List[MODEL]: + async def _execute(self) -> List[MODEL_INSTANCE]: instance_list = await self._db.executor_class( model=self.model, db=self._db, @@ -1036,7 +1039,7 @@ class UpdateQuery(AwaitableQuery): def __init__( self, - model: Type[MODEL], + model: MODEL_CLASS, update_kwargs: Dict[str, Any], db: BaseDBAsyncClient, q_objects: List[Q], @@ -1122,7 +1125,7 @@ class DeleteQuery(AwaitableQuery): def __init__( self, - model: Type[MODEL], + model: MODEL_CLASS, db: BaseDBAsyncClient, q_objects: List[Q], annotations: Dict[str, Any], @@ -1174,7 +1177,7 @@ class ExistsQuery(AwaitableQuery): def __init__( self, - model: Type[MODEL], + model: MODEL_CLASS, db: BaseDBAsyncClient, q_objects: List[Q], annotations: Dict[str, Any], @@ -1232,7 +1235,7 @@ class CountQuery(AwaitableQuery): def __init__( self, - model: Type[MODEL], + model: MODEL_CLASS, db: BaseDBAsyncClient, q_objects: List[Q], annotations: Dict[str, Any], @@ -1287,12 +1290,12 @@ class FieldSelectQuery(AwaitableQuery): # pylint: disable=W0223 __slots__ = ("annotations",) - def __init__(self, model: Type[MODEL], annotations: Dict[str, Any]) -> None: + def __init__(self, model: MODEL_CLASS, annotations: Dict[str, Any]) -> None: super().__init__(model) self.annotations = annotations def _join_table_with_forwarded_fields( - self, model: Type[MODEL], table: Table, field: str, forwarded_fields: str + self, model: MODEL_CLASS, table: Table, field: str, forwarded_fields: str ) -> Tuple[Table, str]: if field in model._meta.fields_db_projection and not forwarded_fields: return table, model._meta.fields_db_projection[field] @@ -1351,7 +1354,7 @@ def add_field_to_select_query(self, field: str, return_as: str) -> None: raise FieldError(f'Unknown field "{field}" for model "{self.model.__name__}"') - def resolve_to_python_value(self, model: Type[MODEL], field: str) -> Callable: + def resolve_to_python_value(self, model: MODEL_CLASS, field: str) -> Callable: if field in model._meta.fetch_fields: # return as is to get whole model objects return lambda x: x @@ -1415,7 +1418,7 @@ class ValuesListQuery(FieldSelectQuery, Generic[SINGLE]): def __init__( self, - model: Type[MODEL], + model: MODEL_CLASS, db: BaseDBAsyncClient, q_objects: List[Q], single: bool, @@ -1547,7 +1550,7 @@ class ValuesQuery(FieldSelectQuery, Generic[SINGLE]): def __init__( self, - model: Type[MODEL], + model: MODEL_CLASS, db: BaseDBAsyncClient, q_objects: List[Q], single: bool, @@ -1658,7 +1661,7 @@ async def _execute(self) -> Union[List[dict], Dict]: class RawSQLQuery(AwaitableQuery): __slots__ = ("_sql", "_db") - def __init__(self, model: Type[MODEL], db: BaseDBAsyncClient, sql: str): + def __init__(self, model: MODEL_CLASS, db: BaseDBAsyncClient, sql: str): super().__init__(model) self._sql = sql self._db = db @@ -1673,7 +1676,7 @@ async def _execute(self) -> Any: ).execute_select(self.query) return instance_list - def __await__(self) -> Generator[Any, None, List[MODEL]]: + def __await__(self) -> Generator[Any, None, List[MODEL_INSTANCE]]: if self._db is None: self._db = self._choose_db() # type: ignore self._make_query() @@ -1685,14 +1688,14 @@ class BulkUpdateQuery(UpdateQuery): def __init__( self, - model: Type[MODEL], + model: MODEL_CLASS, db: BaseDBAsyncClient, q_objects: List[Q], annotations: Dict[str, Any], custom_filters: Dict[str, Dict[str, Any]], limit: Optional[int], orderings: List[Tuple[str, str]], - objects: Iterable[MODEL], + objects: Iterable[MODEL_INSTANCE], fields: Iterable[str], batch_size: Optional[int] = None, ): @@ -1768,9 +1771,9 @@ class BulkCreateQuery(AwaitableQuery): def __init__( self, - model: Type[MODEL], + model: MODEL_CLASS, db: BaseDBAsyncClient, - objects: Iterable[MODEL], + objects: Iterable[MODEL_INSTANCE], batch_size: Optional[int] = None, ignore_conflicts: bool = False, update_fields: Optional[Iterable[str]] = None, @@ -1816,7 +1819,7 @@ def _make_query(self) -> None: ) self.insert_query = self.insert_query.do_update(update_field) # type:ignore - async def _execute(self) -> List[MODEL]: + async def _execute(self) -> List[MODEL_INSTANCE]: for instance_chunk in chunk(self.objects, self.batch_size): values_lists_all = [] values_lists = [] @@ -1845,7 +1848,7 @@ async def _execute(self) -> List[MODEL]: await self._db.execute_many(str(self.insert_query), values_lists) return self.objects - def __await__(self) -> Generator[Any, None, List[MODEL]]: + def __await__(self) -> Generator[Any, None, List[MODEL_INSTANCE]]: if self._db is None: self._db = self._choose_db(True) # type: ignore self._make_query() diff --git a/tortoise/router.py b/tortoise/router.py index 7ce5f8503..0c89f1454 100644 --- a/tortoise/router.py +++ b/tortoise/router.py @@ -4,7 +4,7 @@ from tortoise.exceptions import ConfigurationError if TYPE_CHECKING: - from tortoise import BaseDBAsyncClient, Model + from tortoise import BaseDBAsyncClient, MODEL_CLASS class ConnectionRouter: @@ -26,16 +26,16 @@ def _router_func(self, model: Type["Model"], action: str): if chosen_db: return chosen_db - def _db_route(self, model: Type["Model"], action: str): + def _db_route(self, model: "MODEL_CLASS", action: str): try: return connections.get(self._router_func(model, action)) except ConfigurationError: return None - def db_for_read(self, model: Type["Model"]) -> Optional["BaseDBAsyncClient"]: + def db_for_read(self, model: "MODEL_CLASS") -> Optional["BaseDBAsyncClient"]: return self._db_route(model, "db_for_read") - def db_for_write(self, model: Type["Model"]) -> Optional["BaseDBAsyncClient"]: + def db_for_write(self, model: "MODEL_CLASS") -> Optional["BaseDBAsyncClient"]: return self._db_route(model, "db_for_write")