Skip to content

Commit

Permalink
refactor: use db queryset to reduce duplicated code (#1652)
Browse files Browse the repository at this point in the history
* refactor: use db queryset to reduce duplicated code

* fix missing for write

* Sort imports

* Update changelog
  • Loading branch information
waketzheng authored Jun 17, 2024
1 parent 230eb41 commit 3568aea
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 32 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Fixed
^^^^^
- Fix `update_or_create` errors when field value changed. (#1584)
- Fix bandit check error (#1643)
- Fix `update_or_create` errors when field value changed. (#1584)
- Fix potential race condition in ConnectionWrapper (#1656)

Changed
^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion tests/test_model_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
DoesNotExist,
IntegrityError,
MultipleObjectsReturned,
ObjectDoesNotExistError,
OperationalError,
ParamsError,
ObjectDoesNotExistError,
ValidationError,
)
from tortoise.expressions import F, Q
Expand Down
2 changes: 1 addition & 1 deletion tests/test_queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
FieldError,
IntegrityError,
MultipleObjectsReturned,
ParamsError,
NotExistOrMultiple,
ParamsError,
)
from tortoise.expressions import F, RawSQL, Subquery

Expand Down
2 changes: 1 addition & 1 deletion tortoise/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Union, Optional
from typing import TYPE_CHECKING, Any, Optional, Union

if TYPE_CHECKING:
from tortoise import Model, Type
Expand Down
47 changes: 19 additions & 28 deletions tortoise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
DoesNotExist,
IncompleteInstanceError,
IntegrityError,
ObjectDoesNotExistError,
OperationalError,
ParamsError,
ObjectDoesNotExistError,
)
from tortoise.fields.base import Field
from tortoise.fields.data import IntField
Expand Down Expand Up @@ -1056,6 +1056,13 @@ async def _create_or_get(
pass
raise exc

@classmethod
def _db_queryset(
cls, using_db: Optional[BaseDBAsyncClient] = None, for_write: bool = False
) -> QuerySet[Self]:
db = using_db or cls._choose_db(for_write)
return cls._meta.manager.get_queryset().using_db(db)

@classmethod
def select_for_update(
cls,
Expand All @@ -1070,10 +1077,7 @@ def select_for_update(
Returns a queryset that will lock rows until the end of the transaction,
generating a SELECT ... FOR UPDATE SQL statement on supported databases.
"""
db = using_db or cls._choose_db(True)
return (
cls._meta.manager.get_queryset().using_db(db).select_for_update(nowait, skip_locked, of)
)
return cls._db_queryset(using_db, for_write=True).select_for_update(nowait, skip_locked, of)

@classmethod
async def update_or_create(
Expand Down Expand Up @@ -1154,10 +1158,7 @@ def bulk_update(
:param batch_size: How many objects are created in a single query
:param using_db: Specific DB connection to use instead of default bound
"""
db = using_db or cls._choose_db(True)
return (
cls._meta.manager.get_queryset().using_db(db).bulk_update(objects, fields, batch_size)
)
return cls._db_queryset(using_db, for_write=True).bulk_update(objects, fields, batch_size)

@classmethod
async def in_bulk(
Expand All @@ -1174,8 +1175,7 @@ async def in_bulk(
:param field_name: Must be a unique field
:param using_db: Specific DB connection to use instead of default bound
"""
db = using_db or cls._choose_db()
return await cls._meta.manager.get_queryset().using_db(db).in_bulk(id_list, field_name)
return await cls._db_queryset(using_db).in_bulk(id_list, field_name)

@classmethod
def bulk_create(
Expand Down Expand Up @@ -1214,20 +1214,16 @@ def bulk_create(
:param batch_size: How many objects are created in a single query
:param using_db: Specific DB connection to use instead of default bound
"""
db = using_db or cls._choose_db(True)
return (
cls._meta.manager.get_queryset()
.using_db(db)
.bulk_create(objects, batch_size, ignore_conflicts, update_fields, on_conflict)
return cls._db_queryset(using_db, for_write=True).bulk_create(
objects, batch_size, ignore_conflicts, update_fields, on_conflict
)

@classmethod
def first(cls, using_db: Optional[BaseDBAsyncClient] = None) -> QuerySetSingle[Optional[Self]]:
"""
Generates a QuerySet that returns the first record.
"""
db = using_db or cls._choose_db()
return cls._meta.manager.get_queryset().using_db(db).first()
return cls._db_queryset(using_db).first()

@classmethod
def filter(cls, *args: Q, **kwargs: Any) -> QuerySet[Self]:
Expand Down Expand Up @@ -1263,8 +1259,7 @@ def all(cls, using_db: Optional[BaseDBAsyncClient] = None) -> QuerySet[Self]:
"""
Returns the complete QuerySet.
"""
db = using_db or cls._choose_db()
return cls._meta.manager.get_queryset().using_db(db)
return cls._db_queryset(using_db)

@classmethod
def get(
Expand All @@ -1284,8 +1279,7 @@ def get(
:raises MultipleObjectsReturned: If provided search returned more than one object.
:raises DoesNotExist: If object can not be found.
"""
db = using_db or cls._choose_db()
return cls._meta.manager.get_queryset().using_db(db).get(*args, **kwargs)
return cls._db_queryset(using_db).get(*args, **kwargs)

@classmethod
def raw(cls, sql: str, using_db: Optional[BaseDBAsyncClient] = None) -> "RawSQLQuery":
Expand All @@ -1299,8 +1293,7 @@ def raw(cls, sql: str, using_db: Optional[BaseDBAsyncClient] = None) -> "RawSQLQ
:param using_db: The specific DB connection to use
:param sql: The raw sql.
"""
db = using_db or cls._choose_db()
return cls._meta.manager.get_queryset().using_db(db).raw(sql)
return cls._db_queryset(using_db).raw(sql)

@classmethod
def exists(
Expand All @@ -1317,8 +1310,7 @@ def exists(
:param args: Q functions containing constraints. Will be AND'ed.
:param kwargs: Simple filter constraints.
"""
db = using_db or cls._choose_db()
return cls._meta.manager.get_queryset().using_db(db).filter(*args, **kwargs).exists()
return cls._db_queryset(using_db).filter(*args, **kwargs).exists()

@classmethod
def get_or_none(
Expand All @@ -1335,8 +1327,7 @@ def get_or_none(
:param args: Q functions containing constraints. Will be AND'ed.
:param kwargs: Simple filter constraints.
"""
db = using_db or cls._choose_db()
return cls._meta.manager.get_queryset().using_db(db).get_or_none(*args, **kwargs)
return cls._db_queryset(using_db).get_or_none(*args, **kwargs)

@classmethod
async def fetch_for_list(
Expand Down

0 comments on commit 3568aea

Please sign in to comment.