MODEL_INSTANCE TypeVar and MODEL_CLASS TypeAlias #1352

Closed
wants to merge 3 commits
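The PR title names the two new typing helpers, but their definitions live in tortoise/models.py, which is not part of this excerpt. A minimal sketch of what the imported names are assumed to look like (the exact bound and alias spelling are assumptions, not taken from the diff):

# Assumed definitions of the helpers exported by tortoise.models;
# the actual models.py hunk is not shown in this diff excerpt.
from typing import Type, TypeVar

from tortoise.models import Model

# TypeVar bound to Model: annotates concrete model *instances* generically.
MODEL_INSTANCE = TypeVar("MODEL_INSTANCE", bound=Model)

# Alias for a Model subclass itself (a TypeAlias per the PR title),
# i.e. what the old Type["Model"] spelling expressed.
MODEL_CLASS = Type[Model]

Throughout the diff below, MODEL_CLASS replaces Type["Model"] and MODEL_INSTANCE replaces bare Model in annotations; no runtime behaviour changes.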
16 changes: 8 additions & 8 deletions tortoise/__init__.py
@@ -23,12 +23,12 @@
)
from tortoise.filters import get_m2m_filters
from tortoise.log import logger
from tortoise.models import Model, ModelMeta
from tortoise.models import Model, ModelMeta, MODEL_CLASS, MODEL_INSTANCE
from tortoise.utils import generate_schema_for_client


class Tortoise:
apps: Dict[str, Dict[str, Type["Model"]]] = {}
apps: Dict[str, Dict[str, MODEL_CLASS]] = {}
_inited: bool = False

@classmethod
@@ -46,7 +46,7 @@ def get_connection(cls, connection_name: str) -> BaseDBAsyncClient:

@classmethod
def describe_model(
cls, model: Type["Model"], serializable: bool = True
cls, model: MODEL_CLASS, serializable: bool = True
) -> dict: # pragma: nocoverage
"""
Describes the given list of models or ALL registered models.
@@ -64,15 +64,15 @@ def describe_model(
This is deprecated, please use :meth:`tortoise.models.Model.describe` instead
"""
warnings.warn(
"Tortoise.describe_model(<MODEL>) is deprecated, please use <MODEL>.describe() instead",
"Tortoise.describe_model(<MODEL_INSTANCE>) is deprecated, please use <MODEL_INSTANCE>.describe() instead",
DeprecationWarning,
stacklevel=2,
)
return model.describe(serializable=serializable)

@classmethod
def describe_models(
cls, models: Optional[List[Type["Model"]]] = None, serializable: bool = True
cls, models: Optional[List[MODEL_CLASS]] = None, serializable: bool = True
) -> Dict[str, dict]:
"""
Describes the given list of models or ALL registered models.
@@ -108,7 +108,7 @@ def describe_models(

@classmethod
def _init_relations(cls) -> None:
def get_related_model(related_app_name: str, related_model_name: str) -> Type["Model"]:
def get_related_model(related_app_name: str, related_model_name: str) -> MODEL_CLASS:
"""
Test, if app and model really exist. Throws a ConfigurationError with a hopefully
helpful message. If successful, returns the requested model.
@@ -345,7 +345,7 @@ def split_reference(reference: str) -> Tuple[str, str]:
@classmethod
def _discover_models(
cls, models_path: Union[ModuleType, str], app_label: str
) -> List[Type["Model"]]:
) -> List[MODEL_CLASS]:
if isinstance(models_path, ModuleType):
module = models_path
else:
@@ -390,7 +390,7 @@ def init_models(

:raises ConfigurationError: If models are invalid.
"""
app_models: List[Type[Model]] = []
app_models: List[MODEL_CLASS] = []
for models_path in models_paths:
app_models += cls._discover_models(models_path, app_label)

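For the tortoise/__init__.py changes above, the swap is purely an annotation shorthand: MODEL_CLASS stands in for Type["Model"], so attributes such as Tortoise.apps keep their runtime shape. A small usage sketch under that assumption (the function name here is illustrative only):

from typing import Dict, Type

from tortoise import Tortoise
from tortoise.models import Model


def models_in_app(app_label: str) -> Dict[str, Type[Model]]:
    # Tortoise.apps is annotated as Dict[str, Dict[str, MODEL_CLASS]];
    # at runtime it is still a dict mapping model names to model classes.
    return dict(Tortoise.apps.get(app_label, {}))

A type checker treats this exactly like the previous Type["Model"] spelling; only the annotation text gets shorter.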
4 changes: 2 additions & 2 deletions tortoise/backends/asyncpg/executor.py
@@ -2,12 +2,12 @@

import asyncpg

from tortoise import Model
from tortoise import MODEL_INSTANCE
from tortoise.backends.base_postgres.executor import BasePostgresExecutor


class AsyncpgExecutor(BasePostgresExecutor):
async def _process_insert_result(
self, instance: Model, results: Optional[asyncpg.Record]
self, instance: MODEL_INSTANCE, results: Optional[asyncpg.Record]
) -> None:
return await super()._process_insert_result(instance, results)
50 changes: 25 additions & 25 deletions tortoise/backends/base/executor.py
@@ -37,7 +37,7 @@

if TYPE_CHECKING: # pragma: nocoverage
from tortoise.backends.base.client import BaseDBAsyncClient
from tortoise.models import Model
from tortoise.models import MODEL_CLASS, MODEL_INSTANCE
from tortoise.query_utils import Prefetch
from tortoise.queryset import QuerySet

@@ -55,12 +55,12 @@ class BaseExecutor:

def __init__(
self,
model: "Type[Model]",
model: "MODEL_CLASS",
db: "BaseDBAsyncClient",
prefetch_map: "Optional[Dict[str, Set[Union[str, Prefetch]]]]" = None,
prefetch_queries: Optional[Dict[str, List[Tuple[Optional[str], "QuerySet"]]]] = None,
select_related_idx: Optional[
List[Tuple["Type[Model]", int, str, "Type[Model]", Iterable[Optional[str]]]]
List[Tuple["MODEL_CLASS", int, str, "MODEL_CLASS", Iterable[Optional[str]]]]
] = None,
) -> None:
self.model = model
@@ -136,7 +136,7 @@ async def execute_select(
dict_row = dict(row)
keys = list(dict_row.keys())
values = list(dict_row.values())
instance: "Model" = self.model._init_from_db(
instance: "MODEL_INSTANCE" = self.model._init_from_db(
**dict(zip(keys[:current_idx], values[:current_idx]))
)
instances: Dict[Any, Any] = {path: instance}
@@ -191,7 +191,7 @@ def _prepare_insert_columns(

@classmethod
def _field_to_db(
cls, field_object: Field, attr: Any, instance: "Union[Type[Model], Model]"
cls, field_object: Field, attr: Any, instance: Union["MODEL_CLASS", "MODEL_INSTANCE"]
) -> Any:
if field_object.__class__ in cls.TO_DB_OVERRIDE:
return cls.TO_DB_OVERRIDE[field_object.__class__](field_object, attr, instance)
@@ -212,13 +212,13 @@ def _prepare_insert_statement(
query = query.on_conflict().do_nothing()
return query

async def _process_insert_result(self, instance: "Model", results: Any) -> None:
async def _process_insert_result(self, instance: "MODEL_INSTANCE", results: Any) -> None:
raise NotImplementedError() # pragma: nocoverage

def parameter(self, pos: int) -> Parameter:
raise NotImplementedError() # pragma: nocoverage

async def execute_insert(self, instance: "Model") -> None:
async def execute_insert(self, instance: "MODEL_INSTANCE") -> None:
if not instance._custom_generated_pk:
values = [
self.column_map[field_name](getattr(instance, field_name), instance)
@@ -236,7 +236,7 @@ async def execute_insert(self, instance: "Model") -> None:

async def execute_bulk_insert(
self,
instances: "Iterable[Model]",
instances: Iterable["MODEL_INSTANCE"],
batch_size: Optional[int] = None,
) -> None:
for instance_chunk in chunk(instances, batch_size):
@@ -300,7 +300,7 @@ def get_update_sql(
return sql

async def execute_update(
self, instance: "Union[Type[Model], Model]", update_fields: Optional[Iterable[str]]
self, instance: Union["MODEL_CLASS", "MODEL_INSTANCE"], update_fields: Optional[Iterable[str]]
) -> int:
values = []
arithmetic_or_function = {}
@@ -319,7 +319,7 @@ async def execute_update(
)
)[0]

async def execute_delete(self, instance: "Union[Type[Model], Model]") -> int:
async def execute_delete(self, instance: Union["MODEL_CLASS", "MODEL_INSTANCE"]) -> int:
return (
await self.db.execute_query(
self.delete_query, [self.model._meta.pk.to_db_value(instance.pk, instance)]
@@ -328,10 +328,10 @@ async def execute_delete(self, instance: "Union[Type[Model], Model]") -> int:

async def _prefetch_reverse_relation(
self,
instance_list: "Iterable[Model]",
instance_list: Iterable["MODEL_INSTANCE"],
field: str,
related_query: Tuple[Optional[str], "QuerySet"],
) -> "Iterable[Model]":
) -> Iterable["MODEL_INSTANCE"]:
to_attr, related_query = related_query
related_objects_for_fetch: Dict[str, list] = {}
related_field: BackwardFKRelation = self.model._meta.fields_map[field] # type: ignore
@@ -373,10 +373,10 @@ async def _prefetch_reverse_relation(

async def _prefetch_reverse_o2o_relation(
self,
instance_list: "Iterable[Model]",
instance_list: Iterable["MODEL_INSTANCE"],
field: str,
related_query: Tuple[Optional[str], "QuerySet"],
) -> "Iterable[Model]":
) -> Iterable["MODEL_INSTANCE"]:
to_attr, related_query = related_query
related_objects_for_fetch: Dict[str, list] = {}
related_field: BackwardOneToOneRelation = self.model._meta.fields_map[field] # type: ignore
@@ -416,10 +416,10 @@ async def _prefetch_reverse_o2o_relation(

async def _prefetch_m2m_relation(
self,
instance_list: "Iterable[Model]",
instance_list: Iterable["MODEL_INSTANCE"],
field: str,
related_query: Tuple[Optional[str], "QuerySet"],
) -> "Iterable[Model]":
) -> Iterable["MODEL_INSTANCE"]:
to_attr, related_query = related_query
instance_id_set: set = {
self._field_to_db(instance._meta.pk, instance.pk, instance)
@@ -502,10 +502,10 @@ async def _prefetch_m2m_relation(

async def _prefetch_direct_relation(
self,
instance_list: "Iterable[Model]",
instance_list: Iterable["MODEL_INSTANCE"],
field: str,
related_query: Tuple[Optional[str], "QuerySet"],
) -> "Iterable[Model]":
) -> Iterable["MODEL_INSTANCE"]:
# TODO: This will only work if instance_list is all of same type
# TODO: If that's the case, then we can optimize the key resolver
to_attr, related_query = related_query
@@ -539,7 +539,7 @@ def _make_prefetch_queries(self) -> None:
to_attr, related_query = self._prefetch_queries[field_name][0]
else:
relation_field = self.model._meta.fields_map[field_name]
related_model: "Type[Model]" = relation_field.related_model # type: ignore
related_model: "MODEL_CLASS" = relation_field.related_model # type: ignore
related_query = related_model.all().using_db(self.db)
related_query.query = copy(related_query.model._meta.basequery)
if forwarded_prefetches:
@@ -548,10 +548,10 @@ async def _do_prefetch(

async def _do_prefetch(
self,
instance_id_list: "Iterable[Model]",
instance_id_list: Iterable["MODEL_INSTANCE"],
field: str,
related_query: Tuple[Optional[str], "QuerySet"],
) -> "Iterable[Model]":
) -> Iterable["MODEL_INSTANCE"]:
if field in self.model._meta.backward_fk_fields:
return await self._prefetch_reverse_relation(instance_id_list, field, related_query)

@@ -563,8 +563,8 @@ async def _do_prefetch(
return await self._prefetch_direct_relation(instance_id_list, field, related_query)

async def _execute_prefetch_queries(
self, instance_list: "Iterable[Model]"
) -> "Iterable[Model]":
self, instance_list: Iterable["MODEL_INSTANCE"]
) -> Iterable["MODEL_INSTANCE"]:
if instance_list and (self.prefetch_map or self._prefetch_queries):
self._make_prefetch_queries()
prefetch_tasks = []
@@ -576,8 +576,8 @@ async def _execute_prefetch_queries(
return instance_list

async def fetch_for_list(
self, instance_list: "Iterable[Model]", *args: str
) -> "Iterable[Model]":
self, instance_list: Iterable["MODEL_INSTANCE"], *args: str
) -> Iterable["MODEL_INSTANCE"]:
self.prefetch_map = {}
for relation in args:
first_level_field, __, forwarded_prefetch = relation.partition("__")
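The pattern running through base/executor.py is that hooks handed a concrete row object (execute_insert, _process_insert_result, the prefetch helpers) are annotated with MODEL_INSTANCE, hooks that only need class-level metadata take MODEL_CLASS, and helpers such as _field_to_db and execute_update accept either. A minimal sketch of that split, with made-up helper names and the aliases assumed to mirror the PR's definitions:

from typing import Type, TypeVar, Union

from tortoise.models import Model

# Assumed to mirror the names this PR imports from tortoise.models.
MODEL_INSTANCE = TypeVar("MODEL_INSTANCE", bound=Model)
MODEL_CLASS = Type[Model]


def table_name(model: MODEL_CLASS) -> str:
    # Class-level metadata only: no instance required.
    return model._meta.db_table


def pk_db_value(instance: MODEL_INSTANCE):
    # Instance-level data, as in execute_delete above, which calls
    # pk.to_db_value(instance.pk, instance) on a concrete row object.
    return instance._meta.pk.to_db_value(instance.pk, instance)


def table_of(target: Union[MODEL_CLASS, MODEL_INSTANCE]) -> str:
    # Mirrors _field_to_db / execute_update, which may be called with
    # either a model class or a model instance.
    return target._meta.db_table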
19 changes: 9 additions & 10 deletions tortoise/backends/base/schema_generator.py
@@ -1,5 +1,5 @@
from hashlib import sha256
from typing import TYPE_CHECKING, Any, List, Set, Type, cast
from typing import TYPE_CHECKING, Any, List, Set, cast

from tortoise.exceptions import ConfigurationError
from tortoise.fields import JSONField, TextField, UUIDField
@@ -8,8 +8,7 @@
if TYPE_CHECKING: # pragma: nocoverage
from tortoise.backends.base.client import BaseDBAsyncClient
from tortoise.fields.relational import ForeignKeyFieldInstance # noqa
from tortoise.fields.relational import ManyToManyFieldInstance
from tortoise.models import Model
from tortoise.models import MODEL_CLASS

# pylint: disable=R0201

@@ -134,7 +133,7 @@ def _make_hash(*args: str, length: int) -> str:
return sha256(";".join(args).encode("utf-8")).hexdigest()[:length]

def _generate_index_name(
self, prefix: str, model: "Type[Model]", field_names: List[str]
self, prefix: str, model: "MODEL_CLASS", field_names: List[str]
) -> str:
# NOTE: for compatibility, index name should not be longer than 30
# characters (Oracle limit).
@@ -161,26 +160,26 @@ def _generate_fk_name(
)
return index_name

def _get_index_sql(self, model: "Type[Model]", field_names: List[str], safe: bool) -> str:
def _get_index_sql(self, model: "MODEL_CLASS", field_names: List[str], safe: bool) -> str:
return self.INDEX_CREATE_TEMPLATE.format(
exists="IF NOT EXISTS " if safe else "",
index_name=self._generate_index_name("idx", model, field_names),
table_name=model._meta.db_table,
fields=", ".join([self.quote(f) for f in field_names]),
)

def _get_unique_constraint_sql(self, model: "Type[Model]", field_names: List[str]) -> str:
def _get_unique_constraint_sql(self, model: "MODEL_CLASS", field_names: List[str]) -> str:
return self.UNIQUE_CONSTRAINT_CREATE_TEMPLATE.format(
index_name=self._generate_index_name("uid", model, field_names),
fields=", ".join([self.quote(f) for f in field_names]),
)

def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
def _get_table_sql(self, model: "MODEL_CLASS", safe: bool = True) -> dict:
fields_to_create = []
fields_with_index = []
m2m_tables_for_create = []
references = set()
models_to_create: "List[Type[Model]]" = []
models_to_create: List["MODEL_CLASS"] = []

self._get_models_to_create(models_to_create)
models_tables = [model._meta.db_table for model in models_to_create]
@@ -402,7 +401,7 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
"m2m_tables": m2m_tables_for_create,
}

def _get_models_to_create(self, models_to_create: "List[Type[Model]]") -> None:
def _get_models_to_create(self, models_to_create: List["MODEL_CLASS"]) -> None:
from tortoise import Tortoise

for app in Tortoise.apps.values():
@@ -412,7 +411,7 @@ def _get_models_to_create(self, models_to_create: "List[Type[Model]]") -> None:
models_to_create.append(model)

def get_create_schema_sql(self, safe: bool = True) -> str:
models_to_create: "List[Type[Model]]" = []
models_to_create: List["MODEL_CLASS"] = []

self._get_models_to_create(models_to_create)

4 changes: 2 additions & 2 deletions tortoise/backends/base_postgres/executor.py
@@ -5,7 +5,7 @@
from pypika.dialects import PostgreSQLQueryBuilder
from pypika.terms import Term

from tortoise import Model
from tortoise import MODEL_INSTANCE
from tortoise.backends.base.executor import BaseExecutor
from tortoise.contrib.postgres.json_functions import (
postgres_json_contained_by,
@@ -49,7 +49,7 @@ def _prepare_insert_statement(
query = query.on_conflict().do_nothing()
return query

async def _process_insert_result(self, instance: Model, results: Optional[dict]) -> None:
async def _process_insert_result(self, instance: MODEL_INSTANCE, results: Optional[dict]) -> None:
if results:
generated_fields = self.model._meta.generated_db_fields
db_projection = instance._meta.fields_db_projection_reverse
6 changes: 3 additions & 3 deletions tortoise/backends/mssql/executor.py
@@ -1,15 +1,15 @@
from typing import Any, Optional, Type, Union
from typing import Any, Optional, Union

from pypika import Query

from tortoise import Model, fields
from tortoise import MODEL_CLASS, MODEL_INSTANCE, fields
from tortoise.backends.odbc.executor import ODBCExecutor
from tortoise.exceptions import UnSupportedError
from tortoise.fields import BooleanField


def to_db_bool(
self: BooleanField, value: Optional[Union[bool, int]], instance: Union[Type[Model], Model]
self: BooleanField, value: Optional[Union[bool, int]], instance: Union[MODEL_CLASS, MODEL_INSTANCE]
) -> Optional[int]:
if value is None:
return None