Skip to content

Commit

Permalink
Merge c3babda into 45ee1d1
Browse files Browse the repository at this point in the history
  • Loading branch information
grigi committed Feb 29, 2020
2 parents 45ee1d1 + c3babda commit 01dc552
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ future
- Fix default type of ``JSONField``
- Install on Windows does not require a C compiler any more.
- Fix ``IntegrityError`` with unique field and ``get_or_create``

0.15.17
-------
Expand Down
4 changes: 2 additions & 2 deletions tests/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def test_nonconcurrent_get_or_create(self):
for una in unas:
self.assertEqual(una[0], unas[0][0])

@test.expectedFailure
@test.skipIf(sys.version_info < (3, 7), "aiocontextvars backport not handling this well")
async def test_concurrent_get_or_create(self):
unas = await asyncio.gather(*[UniqueName.get_or_create(name="d") for _ in range(10)])
una_created = [una[1] for una in unas if una[1] is True]
Expand Down Expand Up @@ -76,7 +76,7 @@ async def test_nonconcurrent_get_or_create(self):
for una in unas:
self.assertEqual(una[0], unas[0][0])

@test.skip("Crashes with MySQL & PostgreSQL")
@test.skipIf(sys.version_info < (3, 7), "aiocontextvars backport not handling this well")
async def test_concurrent_get_or_create(self):
unas = await asyncio.gather(*[UniqueName.get_or_create(name="b") for _ in range(10)])
una_created = [una[1] for una in unas if una[1] is True]
Expand Down
3 changes: 0 additions & 3 deletions tests/test_queryset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import sys

from tests.testmodels import IntFields, MinRelation, Tournament
from tortoise.contrib import test
from tortoise.exceptions import DoesNotExist, FieldError, IntegrityError, MultipleObjectsReturned
Expand Down Expand Up @@ -278,7 +276,6 @@ async def test_all_flat_values_list(self):
):
await IntFields.all().values_list(flat=True)

@test.skipIf(sys.version_info < (3, 6), "Class fields not sorted in 3.5")
async def test_all_values_list(self):
data = await IntFields.all().order_by("id").values_list()
self.assertEqual(data[2], (self.intfields[2].id, 16, None))
Expand Down
5 changes: 3 additions & 2 deletions tortoise/backends/asyncpg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
BaseTransactionWrapper,
Capabilities,
ConnectionWrapper,
NestedTransactionContext,
NestedTransactionPooledContext,
PoolConnectionWrapper,
TransactionContext,
TransactionContextPooled,
Expand Down Expand Up @@ -192,14 +192,15 @@ class TransactionWrapper(AsyncpgDBClient, BaseTransactionWrapper):
def __init__(self, connection: AsyncpgDBClient) -> None:
self._connection: asyncpg.Connection = connection._connection
self._lock = asyncio.Lock()
self._trxlock = asyncio.Lock()
self.log = connection.log
self.connection_name = connection.connection_name
self.transaction: Transaction = None
self._finalized = False
self._parent = connection

def _in_transaction(self) -> "TransactionContext":
return NestedTransactionContext(self)
return NestedTransactionPooledContext(self)

def acquire_connection(self) -> "ConnectionWrapper":
return ConnectionWrapper(self._connection, self._lock)
Expand Down
17 changes: 14 additions & 3 deletions tortoise/backends/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,6 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:

class NestedTransactionContext(TransactionContext):
async def __aenter__(self):
await self.connection.start()
return self.connection

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
Expand All @@ -277,8 +276,20 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
# Can't rollback a transaction that already failed.
if exc_type is not TransactionManagementError:
await self.connection.rollback()
else:
await self.connection.commit()


class NestedTransactionPooledContext(TransactionContext):
async def __aenter__(self):
await self.lock.acquire()
return self.connection

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.lock.release()
if not self.connection._finalized:
if exc_type:
# Can't rollback a transaction that already failed.
if exc_type is not TransactionManagementError:
await self.connection.rollback()


class PoolConnectionWrapper:
Expand Down
5 changes: 3 additions & 2 deletions tortoise/backends/mysql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
BaseTransactionWrapper,
Capabilities,
ConnectionWrapper,
NestedTransactionContext,
NestedTransactionPooledContext,
PoolConnectionWrapper,
TransactionContext,
TransactionContextPooled,
Expand Down Expand Up @@ -208,13 +208,14 @@ def __init__(self, connection: MySQLClient) -> None:
self.connection_name = connection.connection_name
self._connection: aiomysql.Connection = connection._connection
self._lock = asyncio.Lock()
self._trxlock = asyncio.Lock()
self.log = connection.log
self._finalized: Optional[bool] = None
self.fetch_inserted = connection.fetch_inserted
self._parent = connection

def _in_transaction(self) -> "TransactionContext":
return NestedTransactionContext(self)
return NestedTransactionPooledContext(self)

def acquire_connection(self) -> ConnectionWrapper:
return ConnectionWrapper(self._connection, self._lock)
Expand Down
25 changes: 19 additions & 6 deletions tortoise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from pypika import Order, Query, Table

from tortoise.backends.base.client import BaseDBAsyncClient
from tortoise.exceptions import ConfigurationError, OperationalError
from tortoise.exceptions import (
ConfigurationError,
IntegrityError,
OperationalError,
TransactionManagementError,
)
from tortoise.fields.base import Field
from tortoise.fields.data import IntField
from tortoise.fields.relational import (
Expand All @@ -22,7 +27,7 @@
from tortoise.filters import get_filters_for_field
from tortoise.functions import Function
from tortoise.queryset import Q, QuerySet, QuerySetSingle
from tortoise.transactions import current_transaction_map
from tortoise.transactions import current_transaction_map, in_transaction

MODEL = TypeVar("MODEL", bound="Model")
# TODO: Define Filter type object. Possibly tuple?
Expand Down Expand Up @@ -752,10 +757,18 @@ async def get_or_create(
"""
if not defaults:
defaults = {}
instance = await cls.filter(**kwargs).first()
if instance:
return instance, False
return await cls.create(**defaults, **kwargs, using_db=using_db), True
db = using_db if using_db else cls._meta.db
async with in_transaction(connection_name=db.connection_name):
instance = await cls.filter(**kwargs).first()
if instance:
return instance, False
try:
return await cls.create(**defaults, **kwargs), True
except (IntegrityError, TransactionManagementError):
# Let transaction close
pass
# Try after transaction in case transaction error
return await cls.get(**kwargs), False

@classmethod
async def create(cls: Type[MODEL], **kwargs: Any) -> MODEL:
Expand Down

0 comments on commit 01dc552

Please sign in to comment.