diff --git a/Makefile b/Makefile index 6f3b4403..02e8af73 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ check: deps build ifneq ($(shell which black),) black --check $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false) endif - ruff $(checkfiles) + ruff check $(checkfiles) mypy $(checkfiles) #pylint -d C,W,R $(checkfiles) #bandit -r $(checkfiles)make @@ -34,7 +34,7 @@ lint: deps build ifneq ($(shell which black),) black --check $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false) endif - ruff $(checkfiles) + ruff check $(checkfiles) mypy $(checkfiles) #pylint $(checkfiles) bandit -r $(checkfiles) diff --git a/tortoise/fields/base.py b/tortoise/fields/base.py index 3e61b426..1887e63a 100644 --- a/tortoise/fields/base.py +++ b/tortoise/fields/base.py @@ -1,3 +1,4 @@ +import sys import warnings from enum import Enum from typing import ( @@ -23,11 +24,15 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model -VALUE = TypeVar("VALUE") +if sys.version_info >= (3, 11): + from enum import StrEnum +else: # pragma: no cover + + class StrEnum(str, Enum): + __str__ = str.__str__ -class StrEnum(str, Enum): - __str__ = str.__str__ +VALUE = TypeVar("VALUE") class OnDelete(StrEnum): diff --git a/tortoise/models.py b/tortoise/models.py index 451fa4eb..5495a2f7 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -323,11 +323,11 @@ def finalise_fields(self) -> None: | self.o2o_fields ) - generated_fields = [] - for field in self.fields_map.values(): - if not field.generated: - continue - generated_fields.append(field.source_field or field.model_field_name) + generated_fields = [ + (field.source_field or field.model_field_name) + for field in self.fields_map.values() + if field.generated + ] self.generated_db_fields = tuple(generated_fields) self._ordering_validated = True @@ -343,29 +343,28 @@ def _generate_lazy_fk_m2m_fields(self) -> None: fk_field_object: ForeignKeyFieldInstance = self.fields_map[key] # type: ignore relation_field = fk_field_object.source_field to_field = fk_field_object.to_field_instance.model_field_name + property_kwargs = dict( + _key=_key, + relation_field=relation_field, + to_field=to_field, + ) setattr( self._model, key, property( partial( _fk_getter, - _key=_key, ftype=fk_field_object.related_model, - relation_field=relation_field, - to_field=to_field, + **property_kwargs, ), partial( _fk_setter, - _key=_key, - relation_field=relation_field, - to_field=to_field, + **property_kwargs, ), partial( _fk_setter, value=None, - _key=_key, - relation_field=relation_field, - to_field=to_field, + **property_kwargs, ), ), ) @@ -394,29 +393,28 @@ def _generate_lazy_fk_m2m_fields(self) -> None: o2o_field_object: OneToOneFieldInstance = self.fields_map[key] # type: ignore relation_field = o2o_field_object.source_field to_field = o2o_field_object.to_field_instance.model_field_name + property_kwargs = dict( + _key=_key, + relation_field=relation_field, + to_field=to_field, + ) setattr( self._model, key, property( partial( _fk_getter, - _key=_key, ftype=o2o_field_object.related_model, - relation_field=relation_field, - to_field=to_field, + **property_kwargs, ), partial( _fk_setter, - _key=_key, - relation_field=relation_field, - to_field=to_field, + **property_kwargs, ), partial( _fk_setter, value=None, - _key=_key, - relation_field=relation_field, - to_field=to_field, + **property_kwargs, ), ), ) @@ -862,48 +860,23 @@ async def _set_async_default_field(self) -> None: setattr(self, k, await v()) self._await_when_save = {} - async def _pre_delete( - self, - using_db: Optional[BaseDBAsyncClient] = None, - ) -> None: - listeners = [] - cls_listeners = self._listeners.get(Signals.pre_delete, {}).get(self.__class__, []) - for listener in cls_listeners: - listeners.append( - listener( - self.__class__, - self, - using_db, - ) - ) + async def _wait_for_listeners(self, signal: Signals, *listener_args) -> None: + cls_listeners = self._listeners.get(signal, {}).get(self.__class__, []) + listeners = [listener(self.__class__, self, *listener_args) for listener in cls_listeners] await asyncio.gather(*listeners) - async def _post_delete( - self, - using_db: Optional[BaseDBAsyncClient] = None, - ) -> None: - listeners = [] - cls_listeners = self._listeners.get(Signals.post_delete, {}).get(self.__class__, []) - for listener in cls_listeners: - listeners.append( - listener( - self.__class__, - self, - using_db, - ) - ) - await asyncio.gather(*listeners) + async def _pre_delete(self, using_db: Optional[BaseDBAsyncClient] = None) -> None: + await self._wait_for_listeners(Signals.pre_delete, using_db) + + async def _post_delete(self, using_db: Optional[BaseDBAsyncClient] = None) -> None: + await self._wait_for_listeners(Signals.post_delete, using_db) async def _pre_save( self, using_db: Optional[BaseDBAsyncClient] = None, update_fields: Optional[Iterable[str]] = None, ) -> None: - listeners = [] - cls_listeners = self._listeners.get(Signals.pre_save, {}).get(self.__class__, []) - for listener in cls_listeners: - listeners.append(listener(self.__class__, self, using_db, update_fields)) - await asyncio.gather(*listeners) + await self._wait_for_listeners(Signals.pre_save, using_db, update_fields) async def _post_save( self, @@ -911,11 +884,7 @@ async def _post_save( created: bool = False, update_fields: Optional[Iterable[str]] = None, ) -> None: - listeners = [] - cls_listeners = self._listeners.get(Signals.post_save, {}).get(self.__class__, []) - for listener in cls_listeners: - listeners.append(listener(self.__class__, self, created, using_db, update_fields)) - await asyncio.gather(*listeners) + await self._wait_for_listeners(Signals.post_save, created, using_db, update_fields) async def save( self,