Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use db queryset to reduce duplicated code #1652

Merged
merged 5 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Fixed
^^^^^
- Fix `update_or_create` errors when field value changed. (#1584)
- Fix bandit check error (#1643)
- Fix `update_or_create` errors when field value changed. (#1584)
- Fix potential race condition in ConnectionWrapper (#1656)

Changed
^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion tests/test_model_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
DoesNotExist,
IntegrityError,
MultipleObjectsReturned,
ObjectDoesNotExistError,
OperationalError,
ParamsError,
ObjectDoesNotExistError,
ValidationError,
)
from tortoise.expressions import F, Q
Expand Down
2 changes: 1 addition & 1 deletion tests/test_queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
FieldError,
IntegrityError,
MultipleObjectsReturned,
ParamsError,
NotExistOrMultiple,
ParamsError,
)
from tortoise.expressions import F, RawSQL, Subquery

Expand Down
2 changes: 1 addition & 1 deletion tortoise/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Union, Optional
from typing import TYPE_CHECKING, Any, Optional, Union

if TYPE_CHECKING:
from tortoise import Model, Type
Expand Down
47 changes: 19 additions & 28 deletions tortoise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
DoesNotExist,
IncompleteInstanceError,
IntegrityError,
ObjectDoesNotExistError,
OperationalError,
ParamsError,
ObjectDoesNotExistError,
)
from tortoise.fields.base import Field
from tortoise.fields.data import IntField
Expand Down Expand Up @@ -1056,6 +1056,13 @@ async def _create_or_get(
pass
raise exc

@classmethod
def _db_queryset(
cls, using_db: Optional[BaseDBAsyncClient] = None, for_write: bool = False
) -> QuerySet[Self]:
db = using_db or cls._choose_db(for_write)
return cls._meta.manager.get_queryset().using_db(db)

@classmethod
def select_for_update(
cls,
Expand All @@ -1070,10 +1077,7 @@ def select_for_update(
Returns a queryset that will lock rows until the end of the transaction,
generating a SELECT ... FOR UPDATE SQL statement on supported databases.
"""
db = using_db or cls._choose_db(True)
return (
cls._meta.manager.get_queryset().using_db(db).select_for_update(nowait, skip_locked, of)
)
return cls._db_queryset(using_db, for_write=True).select_for_update(nowait, skip_locked, of)

@classmethod
async def update_or_create(
Expand Down Expand Up @@ -1154,10 +1158,7 @@ def bulk_update(
:param batch_size: How many objects are created in a single query
:param using_db: Specific DB connection to use instead of default bound
"""
db = using_db or cls._choose_db(True)
waketzheng marked this conversation as resolved.
Show resolved Hide resolved
return (
cls._meta.manager.get_queryset().using_db(db).bulk_update(objects, fields, batch_size)
)
return cls._db_queryset(using_db, for_write=True).bulk_update(objects, fields, batch_size)

@classmethod
async def in_bulk(
Expand All @@ -1174,8 +1175,7 @@ async def in_bulk(
:param field_name: Must be a unique field
:param using_db: Specific DB connection to use instead of default bound
"""
db = using_db or cls._choose_db()
return await cls._meta.manager.get_queryset().using_db(db).in_bulk(id_list, field_name)
return await cls._db_queryset(using_db).in_bulk(id_list, field_name)

@classmethod
def bulk_create(
Expand Down Expand Up @@ -1214,20 +1214,16 @@ def bulk_create(
:param batch_size: How many objects are created in a single query
:param using_db: Specific DB connection to use instead of default bound
"""
db = using_db or cls._choose_db(True)
return (
cls._meta.manager.get_queryset()
.using_db(db)
.bulk_create(objects, batch_size, ignore_conflicts, update_fields, on_conflict)
return cls._db_queryset(using_db, for_write=True).bulk_create(
objects, batch_size, ignore_conflicts, update_fields, on_conflict
)

@classmethod
def first(cls, using_db: Optional[BaseDBAsyncClient] = None) -> QuerySetSingle[Optional[Self]]:
"""
Generates a QuerySet that returns the first record.
"""
db = using_db or cls._choose_db()
return cls._meta.manager.get_queryset().using_db(db).first()
return cls._db_queryset(using_db).first()

@classmethod
def filter(cls, *args: Q, **kwargs: Any) -> QuerySet[Self]:
Expand Down Expand Up @@ -1263,8 +1259,7 @@ def all(cls, using_db: Optional[BaseDBAsyncClient] = None) -> QuerySet[Self]:
"""
Returns the complete QuerySet.
"""
db = using_db or cls._choose_db()
return cls._meta.manager.get_queryset().using_db(db)
return cls._db_queryset(using_db)

@classmethod
def get(
Expand All @@ -1284,8 +1279,7 @@ def get(
:raises MultipleObjectsReturned: If provided search returned more than one object.
:raises DoesNotExist: If object can not be found.
"""
db = using_db or cls._choose_db()
return cls._meta.manager.get_queryset().using_db(db).get(*args, **kwargs)
return cls._db_queryset(using_db).get(*args, **kwargs)

@classmethod
def raw(cls, sql: str, using_db: Optional[BaseDBAsyncClient] = None) -> "RawSQLQuery":
Expand All @@ -1299,8 +1293,7 @@ def raw(cls, sql: str, using_db: Optional[BaseDBAsyncClient] = None) -> "RawSQLQ
:param using_db: The specific DB connection to use
:param sql: The raw sql.
"""
db = using_db or cls._choose_db()
return cls._meta.manager.get_queryset().using_db(db).raw(sql)
return cls._db_queryset(using_db).raw(sql)

@classmethod
def exists(
Expand All @@ -1317,8 +1310,7 @@ def exists(
:param args: Q functions containing constraints. Will be AND'ed.
:param kwargs: Simple filter constraints.
"""
db = using_db or cls._choose_db()
return cls._meta.manager.get_queryset().using_db(db).filter(*args, **kwargs).exists()
return cls._db_queryset(using_db).filter(*args, **kwargs).exists()

@classmethod
def get_or_none(
Expand All @@ -1335,8 +1327,7 @@ def get_or_none(
:param args: Q functions containing constraints. Will be AND'ed.
:param kwargs: Simple filter constraints.
"""
db = using_db or cls._choose_db()
return cls._meta.manager.get_queryset().using_db(db).get_or_none(*args, **kwargs)
return cls._db_queryset(using_db).get_or_none(*args, **kwargs)

@classmethod
async def fetch_for_list(
Expand Down