Skip to content

Commit

Permalink
Merge pull request #1513 from guilledk/issue1244
Browse files Browse the repository at this point in the history
Fixed unhelpful error message in from_thread_run functions.
  • Loading branch information
oremanj committed May 21, 2020
2 parents 54151c5 + f77c200 commit a266dfc
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 161 deletions.
1 change: 1 addition & 0 deletions newsfragments/1244.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added a helpful error message if an async function is passed to `trio.from_thread.run_sync` or a sync function to `trio.from_thread.run`.
87 changes: 5 additions & 82 deletions trio/_core/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from sniffio import current_async_library_cvar

import attr
from async_generator import isasyncgen
from sortedcontainers import SortedDict
from outcome import Error, Value, capture

Expand All @@ -36,7 +35,7 @@
)
from .. import _core
from .._deprecate import deprecated
from .._util import Final, NoPublicConstructor
from .._util import Final, NoPublicConstructor, coroutine_or_error

_NO_SEND = object()

Expand Down Expand Up @@ -1247,86 +1246,7 @@ def spawn_impl(self, async_fn, args, nursery, name, *, system_task=False):
# Call the function and get the coroutine object, while giving helpful
# errors for common mistakes.
######

def _return_value_looks_like_wrong_library(value):
# Returned by legacy @asyncio.coroutine functions, which includes
# a surprising proportion of asyncio builtins.
if isinstance(value, collections.abc.Generator):
return True
# The protocol for detecting an asyncio Future-like object
if getattr(value, "_asyncio_future_blocking", None) is not None:
return True
# This janky check catches tornado Futures and twisted Deferreds.
# By the time we're calling this function, we already know
# something has gone wrong, so a heuristic is pretty safe.
if value.__class__.__name__ in ("Future", "Deferred"):
return True
return False

try:
coro = async_fn(*args)
except TypeError:
# Give good error for: nursery.start_soon(trio.sleep(1))
if isinstance(async_fn, collections.abc.Coroutine):
raise TypeError(
"Trio was expecting an async function, but instead it got "
"a coroutine object {async_fn!r}\n"
"\n"
"Probably you did something like:\n"
"\n"
" trio.run({async_fn.__name__}(...)) # incorrect!\n"
" nursery.start_soon({async_fn.__name__}(...)) # incorrect!\n"
"\n"
"Instead, you want (notice the parentheses!):\n"
"\n"
" trio.run({async_fn.__name__}, ...) # correct!\n"
" nursery.start_soon({async_fn.__name__}, ...) # correct!"
.format(async_fn=async_fn)
) from None

# Give good error for: nursery.start_soon(future)
if _return_value_looks_like_wrong_library(async_fn):
raise TypeError(
"Trio was expecting an async function, but instead it got "
"{!r} – are you trying to use a library written for "
"asyncio/twisted/tornado or similar? That won't work "
"without some sort of compatibility shim."
.format(async_fn)
) from None

raise

# We can't check iscoroutinefunction(async_fn), because that will fail
# for things like functools.partial objects wrapping an async
# function. So we have to just call it and then check whether the
# return value is a coroutine object.
if not isinstance(coro, collections.abc.Coroutine):
# Give good error for: nursery.start_soon(func_returning_future)
if _return_value_looks_like_wrong_library(coro):
raise TypeError(
"start_soon got unexpected {!r} – are you trying to use a "
"library written for asyncio/twisted/tornado or similar? "
"That won't work without some sort of compatibility shim."
.format(coro)
)

if isasyncgen(coro):
raise TypeError(
"start_soon expected an async function but got an async "
"generator {!r}".format(coro)
)

# Give good error for: nursery.start_soon(some_sync_fn)
raise TypeError(
"Trio expected an async function, but {!r} appears to be "
"synchronous".format(
getattr(async_fn, "__qualname__", async_fn)
)
)

######
# Set up the Task object
######
coro = coroutine_or_error(async_fn, *args)

if name is None:
name = async_fn
Expand All @@ -1353,6 +1273,9 @@ async def python_wrapper(orig_coro):
LOCALS_KEY_KI_PROTECTION_ENABLED, system_task
)

######
# Set up the Task object
######
task = Task._create(
coro=coro,
parent_nursery=nursery,
Expand Down
87 changes: 17 additions & 70 deletions trio/_core/tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
import sniffio
import pytest

from .tutil import slow, check_sequence_matches, gc_collect_harder
from .tutil import (
slow, check_sequence_matches, gc_collect_harder,
ignore_coroutine_never_awaited_warnings
)

from ... import _core
from ..._threads import to_thread_run_sync
from ..._timeouts import sleep, fail_after
Expand All @@ -33,24 +37,6 @@ async def sleep_forever():
return await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)


# Some of our tests need to leak coroutines, and thus trigger the
# "RuntimeWarning: coroutine '...' was never awaited" message. This context
# manager should be used anywhere this happens to hide those messages, because
# when expected they're clutter.
@contextmanager
def ignore_coroutine_never_awaited_warnings():
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="coroutine '.*' was never awaited"
)
try:
yield
finally:
# Make sure to trigger any coroutine __del__ methods now, before
# we leave the context manager.
gc_collect_harder()


def test_basic():
async def trivial(x):
return x
Expand Down Expand Up @@ -1696,8 +1682,6 @@ async def test_current_effective_deadline(mock_clock):
assert _core.current_effective_deadline() == inf


# @coroutine is deprecated since python 3.8, which is fine with us.
@pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning")
def test_nice_error_on_bad_calls_to_run_or_spawn():
def bad_call_run(*args):
_core.run(*args)
Expand All @@ -1709,59 +1693,22 @@ async def main():

_core.run(main)

class Deferred:
"Just kidding"

with ignore_coroutine_never_awaited_warnings():
for bad_call in bad_call_run, bad_call_spawn:

async def f(): # pragma: no cover
pass

with pytest.raises(TypeError) as excinfo:
bad_call(f())
assert "expecting an async function" in str(excinfo.value)

import asyncio

@asyncio.coroutine
def generator_based_coro(): # pragma: no cover
yield from asyncio.sleep(1)

with pytest.raises(TypeError) as excinfo:
bad_call(generator_based_coro())
assert "asyncio" in str(excinfo.value)
for bad_call in bad_call_run, bad_call_spawn:

with pytest.raises(TypeError) as excinfo:
bad_call(asyncio.Future())
assert "asyncio" in str(excinfo.value)

with pytest.raises(TypeError) as excinfo:
bad_call(lambda: asyncio.Future())
assert "asyncio" in str(excinfo.value)

with pytest.raises(TypeError) as excinfo:
bad_call(Deferred())
assert "twisted" in str(excinfo.value)

with pytest.raises(TypeError) as excinfo:
bad_call(lambda: Deferred())
assert "twisted" in str(excinfo.value)

with pytest.raises(TypeError) as excinfo:
bad_call(len, [1, 2, 3])
assert "appears to be synchronous" in str(excinfo.value)
async def f(): # pragma: no cover
pass

async def async_gen(arg): # pragma: no cover
yield
with pytest.raises(TypeError, match="expecting an async function"):
bad_call(f())

with pytest.raises(TypeError) as excinfo:
bad_call(async_gen, 0)
msg = "expected an async function but got an async generator"
assert msg in str(excinfo.value)
async def async_gen(arg): # pragma: no cover
yield arg

# Make sure no references are kept around to keep anything alive
del excinfo
with pytest.raises(
TypeError,
match="expected an async function but got an async generator"
):
bad_call(async_gen, 0)


def test_calling_asyncio_function_gives_nice_error():
Expand Down
20 changes: 20 additions & 0 deletions trio/_core/tests/tutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os

import pytest
import warnings
from contextlib import contextmanager

import gc

Expand Down Expand Up @@ -52,6 +54,24 @@ def gc_collect_harder():
gc.collect()


# Some of our tests need to leak coroutines, and thus trigger the
# "RuntimeWarning: coroutine '...' was never awaited" message. This context
# manager should be used anywhere this happens to hide those messages, because
# when expected they're clutter.
@contextmanager
def ignore_coroutine_never_awaited_warnings():
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="coroutine '.*' was never awaited"
)
try:
yield
finally:
# Make sure to trigger any coroutine __del__ methods now, before
# we leave the context manager.
gc_collect_harder()


# template is like:
# [1, {2.1, 2.2}, 3] -> matches [1, 2.1, 2.2, 3] or [1, 2.2, 2.1, 3]
def check_sequence_matches(seq, template):
Expand Down
22 changes: 17 additions & 5 deletions trio/_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from itertools import count

import attr
import inspect
import outcome

import trio

from ._sync import CapacityLimiter
from ._core import enable_ki_protection, disable_ki_protection, RunVar, TrioToken
from ._util import coroutine_or_error

# Global due to Threading API, thread local storage for trio token
TOKEN_LOCAL = threading.local()
Expand Down Expand Up @@ -365,6 +367,7 @@ def from_thread_run(afn, *args, trio_token=None):
which would otherwise cause a deadlock.
AttributeError: if no ``trio_token`` was provided, and we can't infer
one from context.
TypeError: if ``afn`` is not an asynchronous function.
**Locating a Trio Token**: There are two ways to specify which
`trio.run` loop to reenter:
Expand All @@ -380,7 +383,8 @@ def from_thread_run(afn, *args, trio_token=None):
def callback(q, afn, args):
@disable_ki_protection
async def unprotected_afn():
return await afn(*args)
coro = coroutine_or_error(afn, *args)
return await coro

async def await_in_trio_thread_task():
q.put_nowait(await outcome.acapture(unprotected_afn))
Expand All @@ -403,13 +407,11 @@ def from_thread_run_sync(fn, *args, trio_token=None):
Raises:
RunFinishedError: if the corresponding call to `trio.run` has
already completed.
Cancelled: if the corresponding call to `trio.run` completes
while ``afn(*args)`` is running, then ``afn`` is likely to raise
:exc:`trio.Cancelled`, and this will propagate out into
RuntimeError: if you try calling this from inside the Trio thread,
which would otherwise cause a deadlock.
AttributeError: if no ``trio_token`` was provided, and we can't infer
one from context.
TypeError: if ``fn`` is an async function.
**Locating a Trio Token**: There are two ways to specify which
`trio.run` loop to reenter:
Expand All @@ -425,7 +427,17 @@ def from_thread_run_sync(fn, *args, trio_token=None):
def callback(q, fn, args):
@disable_ki_protection
def unprotected_fn():
return fn(*args)
ret = fn(*args)

if inspect.iscoroutine(ret):
# Manually close coroutine to avoid RuntimeWarnings
ret.close()
raise TypeError(
"Trio expected a sync function, but {!r} appears to be "
"asynchronous".format(getattr(fn, "__qualname__", fn))
)

return ret

res = outcome.capture(unprotected_fn)
q.put_nowait(res)
Expand Down
Loading

0 comments on commit a266dfc

Please sign in to comment.