Skip to content

Commit

Permalink
Merge pull request #334 from fantix/t313
Browse files Browse the repository at this point in the history
Fixed #313, remove stack when empty
  • Loading branch information
fantix committed Sep 13, 2018
2 parents 7c95264 + 0a8a199 commit 4828084
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 22 deletions.
63 changes: 43 additions & 20 deletions gino/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,14 +270,9 @@ async def release(self, *, permanent=True):
"""
if permanent and self._stack is not None:
for i in range(len(self._stack)):
if self._stack[-1].gino_conn is self:
dbapi_conn = self._stack.pop()
self._stack.rotate(-i)
await dbapi_conn.release(True)
break
else:
self._stack.rotate()
dbapi_conn = self._stack.remove(lambda x: x.gino_conn is self)
if dbapi_conn:
await dbapi_conn.release(True)
else:
raise ValueError('This connection is already released.')
else:
Expand Down Expand Up @@ -493,6 +488,39 @@ async def prepare(self, clause):
clause, (_bypass_no_param,), {}).prepare(clause)


class _ContextualStack:
__slots__ = ('_ctx', '_stack')

def __init__(self, ctx):
self._ctx = ctx
self._stack = ctx.get()
if self._stack is None:
self._stack = collections.deque()
ctx.set(self._stack)

def __bool__(self):
return bool(self._stack)

@property
def top(self):
return self._stack[-1]

def push(self, value):
self._stack.append(value)

def remove(self, checker):
for i in range(len(self._stack)):
if checker(self._stack[-1]):
rv = self._stack.pop()
if self._stack:
self._stack.rotate(-i)
else:
self._ctx.set(None)
return rv
else:
self._stack.rotate(1)


class GinoEngine:
"""
Connects a :class:`~.dialects.base.Pool` and
Expand Down Expand Up @@ -522,7 +550,7 @@ def __init__(self, dialect, pool, loop,
self._dialect = dialect
self._pool = pool
self._loop = loop
self._ctx = ContextVar('gino')
self._ctx = ContextVar('gino', default=None)

@property
def dialect(self):
Expand Down Expand Up @@ -608,14 +636,10 @@ def acquire(self, *, timeout=None, reuse=False, lazy=False, reusable=True):
self._acquire, timeout, reuse, lazy, reusable))

async def _acquire(self, timeout, reuse, lazy, reusable):
try:
stack = self._ctx.get()
except LookupError:
stack = collections.deque()
self._ctx.set(stack)
stack = _ContextualStack(self._ctx)
if reuse and stack:
dbapi_conn = _ReusingDBAPIConnection(self._dialect.cursor_cls,
stack[-1])
stack.top)
reusable = False
else:
dbapi_conn = _DBAPIConnection(self._dialect.cursor_cls, self._pool)
Expand All @@ -626,7 +650,7 @@ async def _acquire(self, timeout, reuse, lazy, reusable):
if not lazy:
await dbapi_conn.acquire(timeout=timeout)
if reusable:
stack.append(dbapi_conn)
stack.push(dbapi_conn)
return rv

@property
Expand All @@ -638,10 +662,9 @@ def current_connection(self):
:return: :class:`.GinoConnection`
"""
try:
return self._ctx.get()[-1].gino_conn
except (LookupError, IndexError):
pass
stack = self._ctx.get()
if stack:
return stack[-1].gino_conn

async def close(self):
"""
Expand Down
29 changes: 27 additions & 2 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,15 +231,15 @@ async def test_lazy(mocker):
async with engine.acquire(lazy=True):
assert qsize(engine) == init_size
assert len(engine._ctx.get()) == 1
assert len(engine._ctx.get()) == 0
assert engine._ctx.get() is None
assert qsize(engine) == init_size
async with engine.acquire(lazy=True):
assert qsize(engine) == init_size
assert len(engine._ctx.get()) == 1
assert await engine.scalar('select 1')
assert qsize(engine) == init_size - 1
assert len(engine._ctx.get()) == 1
assert len(engine._ctx.get()) == 0
assert engine._ctx.get() is None
assert qsize(engine) == init_size

loop = asyncio.get_event_loop()
Expand Down Expand Up @@ -367,3 +367,28 @@ async def test_ssl():

e = await gino.create_engine(PG_URL, ssl=ctx)
await e.close()


async def test_issue_313(bind):
assert bind._ctx.get() is None

async with db.acquire():
pass

assert bind._ctx.get() is None

async def task():
async with db.acquire(reuse=True):
await db.scalar('SELECT now()')

await asyncio.gather(*[task() for _ in range(5)])

assert bind._ctx.get() is None

async def task():
async with db.transaction():
await db.scalar('SELECT now()')

await asyncio.gather(*[task() for _ in range(5)])

assert bind._ctx.get() is None

0 comments on commit 4828084

Please sign in to comment.