Skip to content

Commit

Permalink
refactor: share logic for signal listeners (#1630)
Browse files Browse the repository at this point in the history
  • Loading branch information
waketzheng committed May 26, 2024
1 parent 8f83a75 commit 88acc8d
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 67 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ check: deps build
ifneq ($(shell which black),)
black --check $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
endif
ruff $(checkfiles)
ruff check $(checkfiles)
mypy $(checkfiles)
#pylint -d C,W,R $(checkfiles)
#bandit -r $(checkfiles)make
Expand All @@ -34,7 +34,7 @@ lint: deps build
ifneq ($(shell which black),)
black --check $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
endif
ruff $(checkfiles)
ruff check $(checkfiles)
mypy $(checkfiles)
#pylint $(checkfiles)
bandit -r $(checkfiles)
Expand Down
11 changes: 8 additions & 3 deletions tortoise/fields/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import warnings
from enum import Enum
from typing import (
Expand All @@ -23,11 +24,15 @@
if TYPE_CHECKING: # pragma: nocoverage
from tortoise.models import Model

VALUE = TypeVar("VALUE")
if sys.version_info >= (3, 11):
from enum import StrEnum
else: # pragma: no cover

class StrEnum(str, Enum):
__str__ = str.__str__


class StrEnum(str, Enum):
__str__ = str.__str__
VALUE = TypeVar("VALUE")


class OnDelete(StrEnum):
Expand Down
93 changes: 31 additions & 62 deletions tortoise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,11 @@ def finalise_fields(self) -> None:
| self.o2o_fields
)

generated_fields = []
for field in self.fields_map.values():
if not field.generated:
continue
generated_fields.append(field.source_field or field.model_field_name)
generated_fields = [
(field.source_field or field.model_field_name)
for field in self.fields_map.values()
if field.generated
]
self.generated_db_fields = tuple(generated_fields)

self._ordering_validated = True
Expand All @@ -343,29 +343,28 @@ def _generate_lazy_fk_m2m_fields(self) -> None:
fk_field_object: ForeignKeyFieldInstance = self.fields_map[key] # type: ignore
relation_field = fk_field_object.source_field
to_field = fk_field_object.to_field_instance.model_field_name
property_kwargs = dict(
_key=_key,
relation_field=relation_field,
to_field=to_field,
)
setattr(
self._model,
key,
property(
partial(
_fk_getter,
_key=_key,
ftype=fk_field_object.related_model,
relation_field=relation_field,
to_field=to_field,
**property_kwargs,
),
partial(
_fk_setter,
_key=_key,
relation_field=relation_field,
to_field=to_field,
**property_kwargs,
),
partial(
_fk_setter,
value=None,
_key=_key,
relation_field=relation_field,
to_field=to_field,
**property_kwargs,
),
),
)
Expand Down Expand Up @@ -394,29 +393,28 @@ def _generate_lazy_fk_m2m_fields(self) -> None:
o2o_field_object: OneToOneFieldInstance = self.fields_map[key] # type: ignore
relation_field = o2o_field_object.source_field
to_field = o2o_field_object.to_field_instance.model_field_name
property_kwargs = dict(
_key=_key,
relation_field=relation_field,
to_field=to_field,
)
setattr(
self._model,
key,
property(
partial(
_fk_getter,
_key=_key,
ftype=o2o_field_object.related_model,
relation_field=relation_field,
to_field=to_field,
**property_kwargs,
),
partial(
_fk_setter,
_key=_key,
relation_field=relation_field,
to_field=to_field,
**property_kwargs,
),
partial(
_fk_setter,
value=None,
_key=_key,
relation_field=relation_field,
to_field=to_field,
**property_kwargs,
),
),
)
Expand Down Expand Up @@ -862,60 +860,31 @@ async def _set_async_default_field(self) -> None:
setattr(self, k, await v())
self._await_when_save = {}

async def _pre_delete(
self,
using_db: Optional[BaseDBAsyncClient] = None,
) -> None:
listeners = []
cls_listeners = self._listeners.get(Signals.pre_delete, {}).get(self.__class__, [])
for listener in cls_listeners:
listeners.append(
listener(
self.__class__,
self,
using_db,
)
)
async def _wait_for_listeners(self, signal: Signals, *listener_args) -> None:
cls_listeners = self._listeners.get(signal, {}).get(self.__class__, [])
listeners = [listener(self.__class__, self, *listener_args) for listener in cls_listeners]
await asyncio.gather(*listeners)

async def _post_delete(
self,
using_db: Optional[BaseDBAsyncClient] = None,
) -> None:
listeners = []
cls_listeners = self._listeners.get(Signals.post_delete, {}).get(self.__class__, [])
for listener in cls_listeners:
listeners.append(
listener(
self.__class__,
self,
using_db,
)
)
await asyncio.gather(*listeners)
async def _pre_delete(self, using_db: Optional[BaseDBAsyncClient] = None) -> None:
await self._wait_for_listeners(Signals.pre_delete, using_db)

async def _post_delete(self, using_db: Optional[BaseDBAsyncClient] = None) -> None:
await self._wait_for_listeners(Signals.post_delete, using_db)

async def _pre_save(
self,
using_db: Optional[BaseDBAsyncClient] = None,
update_fields: Optional[Iterable[str]] = None,
) -> None:
listeners = []
cls_listeners = self._listeners.get(Signals.pre_save, {}).get(self.__class__, [])
for listener in cls_listeners:
listeners.append(listener(self.__class__, self, using_db, update_fields))
await asyncio.gather(*listeners)
await self._wait_for_listeners(Signals.pre_save, using_db, update_fields)

async def _post_save(
self,
using_db: Optional[BaseDBAsyncClient] = None,
created: bool = False,
update_fields: Optional[Iterable[str]] = None,
) -> None:
listeners = []
cls_listeners = self._listeners.get(Signals.post_save, {}).get(self.__class__, [])
for listener in cls_listeners:
listeners.append(listener(self.__class__, self, created, using_db, update_fields))
await asyncio.gather(*listeners)
await self._wait_for_listeners(Signals.post_save, created, using_db, update_fields)

async def save(
self,
Expand Down

0 comments on commit 88acc8d

Please sign in to comment.