diff --git a/gino/local.py b/gino/local.py index ede6d58c..119a21a1 100644 --- a/gino/local.py +++ b/gino/local.py @@ -80,9 +80,34 @@ def __getattr__(self, item): return getattr(self._task, item) -def reset_local(task): - m = getattr(task, 'with_local_reset', None) - return task if m is None else m() +def reset_local(coro_or_future, *, loop=None): + """Reset local to empty string within given routine. + + This works for: + + - newly created tasks + - coroutines + - awaitables + + For coroutines and awaitables, new tasks will be created with the give loop + or current loop if not given. + + This doesn't work if: + + - task local is disabled for current/given loop + - given task is already running + - given futures are not instances of the internal `TaskWrapper` or subclass + + It is only the "reset local" part that is not working, the given future is + returned untouched. + """ + if isinstance(coro_or_future, TaskWrapper): + return coro_or_future.with_local_reset() + elif not asyncio.isfuture(coro_or_future): + return reset_local(asyncio.ensure_future(coro_or_future, loop=loop)) + else: + # we don't know how to reset local for random Future objects + return coro_or_future def task_factory(loop, coro): diff --git a/tests/test_local.py b/tests/test_local.py index 0b149f96..c1b6a215 100644 --- a/tests/test_local.py +++ b/tests/test_local.py @@ -46,8 +46,10 @@ async def test_reset(task_local): result = [] await asyncio.ensure_future(sub(result)).with_local_reset() assert result[0] is None - await reset_local(asyncio.ensure_future(sub(result))) - assert result[1] is None + await sub(result) + assert result[1] == 123 + await reset_local(sub(result)) + assert result[2] is None async def test_reset_disabled():