Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise at next Checkpoint if Non-awaited coroutine found. #176

Closed
wants to merge 10 commits into from
11 changes: 11 additions & 0 deletions trio/_core/_exceptions.py
Expand Up @@ -7,6 +7,7 @@
"WouldBlock",
"Cancelled",
"ResourceBusyError",
"NonAwaitedCoroutines",
]


Expand Down Expand Up @@ -43,6 +44,16 @@ class RunFinishedError(RuntimeError):
pass


@pretend_module_is_trio
class NonAwaitedCoroutines(RuntimeError):
"""Raised by blocking calls if a non-awaited coroutine detected in current task
"""

def __init__(self, *args, coroutines=None, **kwargs):
self.coroutines = set(coroutines)
super().__init__(*args, **kwargs)


@pretend_module_is_trio
class WouldBlock(Exception):
"""Raised by ``X_nowait`` functions if ``X`` would block.
Expand Down
48 changes: 28 additions & 20 deletions trio/_core/_multierror.py
Expand Up @@ -2,8 +2,8 @@
import traceback
import textwrap
import warnings
import types
from contextlib import contextmanager

from ._non_awaited_coroutines import protector

import attr

Expand Down Expand Up @@ -127,25 +127,33 @@ def __enter__(self):
pass

def __exit__(self, etype, exc, tb):

original_exc = exc
if protector.has_unawaited_coroutines():
exc = protector.make_non_awaited_coroutines_error(
protector.pop_all_unawaited_coroutines()
)

if exc is not None:
filtered_exc = MultiError.filter(self._handler, exc)
if filtered_exc is exc:
# Let the interpreter re-raise it
return False
if filtered_exc is None:
# Swallow the exception
return True
# When we raise filtered_exc, Python will unconditionally blow
# away its __context__ attribute and replace it with the original
# exc we caught. So after we raise it, we have to pause it while
# it's in flight to put the correct __context__ back.
old_context = filtered_exc.__context__
try:
raise filtered_exc
finally:
_, value, _ = sys.exc_info()
assert value is filtered_exc
value.__context__ = old_context
exc = MultiError.filter(self._handler, exc)
if exc is None:
# Swallow the exception
return True
if exc is original_exc:
# Let the interpreter re-raise it
return False

# When we raise filtered_exc, Python will unconditionally blow
# away its __context__ attribute and replace it with the original
# exc we caught. So after we raise it, we have to pause it while
# it's in flight to put the correct __context__ back.
old_context = exc.__context__
try:
raise exc
finally:
_, value, _ = sys.exc_info()
assert value is exc
value.__context__ = old_context


class MultiError(BaseException):
Expand Down
146 changes: 146 additions & 0 deletions trio/_core/_non_awaited_coroutines.py
@@ -0,0 +1,146 @@
"""
This module provides utilities to protect against non-awaited coroutine.

Mostly it provide a protector which can install itself with
`sys.set_coroutine_wrapper` and track the creation of all coroutines.

Every now and then we can go over all the coroutines we have reference to, and
check their state. In trio, the trio-runner will do that, at least on every
checkpoint, but that's not the responsibility of this module.

If the coroutine have been awaited at least once, we discard them.

A :class:`CoroProtector` also provide a convenience method
:meth:`await_later(coro)` that return the coroutine unchanged but will ignore it
if not-awaited at next checkpoint.

A default instance of coroutine protector is provided under the attribute `protector`,
and is shared between `trio.run` and the `MultiError.catch`

"""

import sys
import inspect
import textwrap
from ._exceptions import NonAwaitedCoroutines

try:
from tracemalloc import get_object_traceback as _get_tb
except ImportError: # Not available on, for example, PyPy

def _get_tb(obj):
return None


__all__ = ["CoroProtector", "protector"]

################################################################
# Protection against non-awaited coroutines
################################################################


class CoroProtector:
"""
Protector preventing the creation of non-awaited coroutines
between two checkpoints.
"""

def __init__(self):
self._enabled = True
self._pending_test = set()
self._key = object()
self._previous_coro_wrapper = None

def _coro_wrapper(self, coro):
"""
Coroutine wrapper to track creation of coroutines.
"""
if self._enabled:
self._pending_test.add(coro)
if not self._previous_coro_wrapper:
return coro
else:
return self._previous_coro_wrapper(coro)

def await_later(self, coro):
"""
Mark a coroutine as safe to no be awaited, and return it.
"""
self._pending_test.discard(coro)
return coro

def install(self) -> None:
"""install a coroutine wrapper to track created coroutines.

If a coroutine wrapper is already set wrap and call it.
"""
self._previous_coro_wrapper = sys.get_coroutine_wrapper()
sys.set_coroutine_wrapper(self._coro_wrapper)

def uninstall(self) -> None:
assert sys.get_coroutine_wrapper() == self._coro_wrapper
sys.set_coroutine_wrapper(self._previous_coro_wrapper)

def has_unawaited_coroutines(self) -> bool:
"""
Return whether there are unawaited coroutines.

Flush all internally tracked awaited coroutine. Does not discard non-awaited
ones. You need to call `pop_all_unawaited_coroutines` to do that.
"""
return len(self.get_all_unawaited_coroutines()) > 0

def get_all_unawaited_coroutines(self):
state = inspect.getcoroutinestate
self._pending_test = {
coro
for coro in self._pending_test if state(coro) == 'CORO_CREATED'
}
return set(self._pending_test)

def forget(self, coroutines) -> None:
self._pending_test.difference_update(coroutines)

def pop_all_unawaited_coroutines(self):
"""
Check that since last invocation no coroutine has been left unawaited.

Return a list of unawaited coroutines since last call to this function,
and stop tracking them.
"""
coros = self.get_all_unawaited_coroutines()
self._pending_test = set()
return coros

@staticmethod
def make_non_awaited_coroutines_error(coros):
"""
Construct a nice NonAwaitedCoroutines error messages with the origin of the
coroutine if possible.
"""
err = []
for coro in coros:
tb = _get_tb(coro)
if tb:
err.append(' - {coro} ({tb})'.format(coro=coro, tb=tb)
) # pragma: no cover
else:
err.append(' - {coro}'.format(coro=coro))
err = '\n'.join(err)
return NonAwaitedCoroutines(
textwrap.dedent(
'''
One or more coroutines where not awaited:

{err}

Trio has detected that at least a coroutine has not been between awaited
between this checkpoint point and previous one. This is may be due
to a missing `await`.
''' [1:]
).format(err=err),
coroutines=coros
)


protector = CoroProtector()
29 changes: 25 additions & 4 deletions trio/_core/_result.py
@@ -1,6 +1,8 @@
import abc
import attr

from ._non_awaited_coroutines import protector

__all__ = ["Result", "Value", "Error"]


Expand Down Expand Up @@ -30,9 +32,18 @@ def capture(sync_fn, *args):

"""
try:
return Value(sync_fn(*args))
result = Value(sync_fn(*args))
except BaseException as exc:
return Error(exc)
result = Error(exc)
finally:
if protector.has_unawaited_coroutines():
exc = protector.make_non_awaited_coroutines_error(
protector.pop_all_unawaited_coroutines()
)
if type(result) is Error:
exc.__context__ = result.error
result = Error(exc)
return result

@staticmethod
async def acapture(async_fn, *args):
Expand All @@ -43,9 +54,19 @@ async def acapture(async_fn, *args):

"""
try:
return Value(await async_fn(*args))
result = Value(await async_fn(*args))
except BaseException as exc:
return Error(exc)
result = Error(exc)
finally:
if protector.has_unawaited_coroutines():
exc = protector.make_non_awaited_coroutines_error(
protector.pop_all_unawaited_coroutines()
)

if type(result) is Error:
exc.__context__ = result.error
result = Error(exc)
return result

@abc.abstractmethod
def unwrap(self):
Expand Down